forked from OSchip/llvm-project
NFC: Replace Module::getNamedFunction with lookupSymbol<FuncOp>.
This allows for removing the last direct reference to FuncOp from ModuleOp. PiperOrigin-RevId: 257498296
This commit is contained in:
parent
122cab6770
commit
6da343ecfc
|
@ -225,7 +225,7 @@ struct PythonMLIRModule {
|
|||
}
|
||||
|
||||
PythonFunction getNamedFunction(const std::string &name) {
|
||||
return moduleManager.getNamedFunction(name);
|
||||
return moduleManager.lookupSymbol<FuncOp>(name);
|
||||
}
|
||||
|
||||
PythonFunctionContext
|
||||
|
|
|
@ -121,7 +121,7 @@ public:
|
|||
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
auto main = module.getNamedFunction("main");
|
||||
auto main = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
"Shape inference failed: can't find a main function\n");
|
||||
|
@ -161,7 +161,8 @@ public:
|
|||
// We will create a new function with the concrete types for the parameters
|
||||
// and clone the body into it.
|
||||
if (!functionToSpecialize.mangledName.empty()) {
|
||||
if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
|
||||
if (getModule().lookupSymbol<mlir::FuncOp>(
|
||||
functionToSpecialize.mangledName)) {
|
||||
funcWorklist.pop_back();
|
||||
// Function already specialized, move on.
|
||||
return mlir::success();
|
||||
|
@ -295,7 +296,7 @@ public:
|
|||
// restart after the callee is processed.
|
||||
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
|
||||
auto calleeName = callOp.getCalleeName();
|
||||
auto callee = getModule().getNamedFunction(calleeName);
|
||||
auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
|
||||
if (!callee) {
|
||||
f.emitError("Shape inference failed, call to unknown '")
|
||||
<< calleeName << "'";
|
||||
|
@ -305,7 +306,8 @@ public:
|
|||
auto mangledName = mangle(calleeName, op->getOpOperands());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
|
||||
<< "', mangled: '" << mangledName << "'\n");
|
||||
auto mangledCallee = getModule().getNamedFunction(mangledName);
|
||||
auto mangledCallee =
|
||||
getModule().lookupSymbol<mlir::FuncOp>(mangledName);
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
|
|
|
@ -206,7 +206,7 @@ private:
|
|||
/// Return the prototype declaration for printf in the module, create it if
|
||||
/// necessary.
|
||||
FuncOp getPrintf(ModuleOp module) const {
|
||||
auto printfFunc = module.getNamedFunction("printf");
|
||||
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
|
||||
if (printfFunc)
|
||||
return printfFunc;
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ public:
|
|||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
mlir::ModuleManager moduleManager(module);
|
||||
auto main = moduleManager.getNamedFunction("main");
|
||||
auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
"Shape inference failed: can't find a main function\n");
|
||||
|
@ -163,7 +163,8 @@ public:
|
|||
// We will create a new function with the concrete types for the parameters
|
||||
// and clone the body into it.
|
||||
if (!functionToSpecialize.mangledName.empty()) {
|
||||
if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) {
|
||||
if (moduleManager.lookupSymbol<mlir::FuncOp>(
|
||||
functionToSpecialize.mangledName)) {
|
||||
funcWorklist.pop_back();
|
||||
// FuncOp already specialized, move on.
|
||||
return mlir::success();
|
||||
|
@ -298,7 +299,7 @@ public:
|
|||
// restart after the callee is processed.
|
||||
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
|
||||
auto calleeName = callOp.getCalleeName();
|
||||
auto callee = moduleManager.getNamedFunction(calleeName);
|
||||
auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
|
||||
if (!callee) {
|
||||
signalPassFailure();
|
||||
return f.emitError("Shape inference failed, call to unknown '")
|
||||
|
@ -307,7 +308,8 @@ public:
|
|||
auto mangledName = mangle(calleeName, op->getOpOperands());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
|
||||
<< "', mangled: '" << mangledName << "'\n");
|
||||
auto mangledCallee = moduleManager.getNamedFunction(mangledName);
|
||||
auto mangledCallee =
|
||||
moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#ifndef MLIR_IR_MODULE_H
|
||||
#define MLIR_IR_MODULE_H
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -93,11 +92,6 @@ public:
|
|||
insertPt = Block::iterator(body->getTerminator());
|
||||
body->getOperations().insert(insertPt, op);
|
||||
}
|
||||
|
||||
/// Look up a function with the specified name, returning null if no such
|
||||
/// name exists. Function names never include the @ on them. Note: This
|
||||
/// performs a linear scan of held symbols.
|
||||
FuncOp getNamedFunction(StringRef name) { return lookupSymbol<FuncOp>(name); }
|
||||
};
|
||||
|
||||
/// The ModuleTerminatorOp is a special terminator operation for the body of a
|
||||
|
@ -130,8 +124,8 @@ public:
|
|||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names must never include the @ on them.
|
||||
template <typename NameTy> FuncOp getNamedFunction(NameTy &&name) const {
|
||||
return symbolTable.lookup<FuncOp>(name);
|
||||
template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
|
||||
return symbolTable.lookup<T>(name);
|
||||
}
|
||||
|
||||
/// Insert a new symbol into the module, auto-renaming it as necessary.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#ifndef MLIR_PASS_ANALYSISMANAGER_H
|
||||
#define MLIR_PASS_ANALYSISMANAGER_H
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Pass/PassInstrumentation.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class FuncOp;
|
||||
class Location;
|
||||
class ModuleOp;
|
||||
class Operation;
|
||||
|
|
|
@ -152,7 +152,7 @@ private:
|
|||
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||
ModuleOp module = getModule();
|
||||
Builder builder(module);
|
||||
if (!module.getNamedFunction(cuModuleLoadName)) {
|
||||
if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
|
||||
module.push_back(
|
||||
FuncOp::create(loc, cuModuleLoadName,
|
||||
builder.getFunctionType(
|
||||
|
@ -162,7 +162,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
|||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
|
||||
if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
|
||||
// The helper uses void* instead of CUDA's opaque CUmodule and
|
||||
// CUfunction.
|
||||
module.push_back(
|
||||
|
@ -175,7 +175,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
|||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuLaunchKernelName)) {
|
||||
if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
|
||||
// Other than the CUDA api, the wrappers use uintptr_t to match the
|
||||
// LLVM type if MLIR's index type, which the GPU dialect uses.
|
||||
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
|
||||
|
@ -198,14 +198,14 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
|||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuGetStreamHelperName)) {
|
||||
if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
|
||||
// Helper function to get the current CUDA stream. Uses void* instead of
|
||||
// CUDAs opaque CUstream.
|
||||
module.push_back(FuncOp::create(
|
||||
loc, cuGetStreamHelperName,
|
||||
builder.getFunctionType({}, getPointerType() /* void *stream */)));
|
||||
}
|
||||
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
|
||||
if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
|
||||
module.push_back(
|
||||
FuncOp::create(loc, cuStreamSynchronizeName,
|
||||
builder.getFunctionType(
|
||||
|
@ -322,7 +322,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// Emit a call to the cubin getter to retrieve a pointer to the data that
|
||||
// represents the cubin at runtime.
|
||||
// TODO(herhut): This should rather be a static global once supported.
|
||||
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
|
||||
auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
|
||||
auto cubinGetter =
|
||||
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
|
||||
if (!cubinGetter) {
|
||||
|
@ -335,7 +335,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// Emit the load module call to load the module data. Error checking is done
|
||||
// in the called helper function.
|
||||
auto cuModule = allocatePointer(builder, loc);
|
||||
FuncOp cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
|
||||
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleLoad),
|
||||
ArrayRef<Value *>{cuModule, data.getResult(0)});
|
||||
|
@ -346,19 +346,19 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
|
||||
auto cuFunction = allocatePointer(builder, loc);
|
||||
FuncOp cuModuleGetFunction =
|
||||
getModule().getNamedFunction(cuModuleGetFunctionName);
|
||||
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
FuncOp cuGetStreamHelper =
|
||||
getModule().getNamedFunction(cuGetStreamHelperName);
|
||||
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
|
||||
auto cuStream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
||||
// Invoke the function with required arguments.
|
||||
auto cuLaunchKernel = getModule().getNamedFunction(cuLaunchKernelName);
|
||||
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
|
||||
auto cuFunctionRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
|
||||
auto paramsArray = setupParamsArray(launchOp, builder);
|
||||
|
@ -375,7 +375,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
paramsArray, /* kernel params */
|
||||
nullpointer /* extra */});
|
||||
// Sync on the stream to make it synchronous.
|
||||
auto cuStreamSync = getModule().getNamedFunction(cuStreamSynchronizeName);
|
||||
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuStreamSync),
|
||||
ArrayRef<Value *>(cuStream.getResult(0)));
|
||||
|
|
|
@ -60,7 +60,7 @@ private:
|
|||
}
|
||||
|
||||
FuncOp getMallocHelper(Location loc, Builder &builder) {
|
||||
FuncOp result = getModule().getNamedFunction(kMallocHelperName);
|
||||
FuncOp result = getModule().lookupSymbol<FuncOp>(kMallocHelperName);
|
||||
if (!result) {
|
||||
result = FuncOp::create(
|
||||
loc, kMallocHelperName,
|
||||
|
|
|
@ -442,7 +442,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp mallocFunc = module.getNamedFunction("malloc");
|
||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType =
|
||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
||||
|
@ -503,7 +503,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
FuncOp freeFunc = op->getParentOfType<ModuleOp>().getNamedFunction("free");
|
||||
FuncOp freeFunc =
|
||||
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
|
|
|
@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
}
|
||||
|
||||
auto module = getParentOfType<ModuleOp>();
|
||||
FuncOp kernelFunc = module.getNamedFunction(kernel());
|
||||
FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
|
||||
if (!kernelFunc)
|
||||
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
|
|
|
@ -171,7 +171,7 @@ public:
|
|||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp mallocFunc = module.getNamedFunction("malloc");
|
||||
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
||||
mallocFunc =
|
||||
|
@ -232,7 +232,7 @@ public:
|
|||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
FuncOp freeFunc = module.getNamedFunction("free");
|
||||
FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
||||
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
|
@ -576,7 +576,7 @@ public:
|
|||
static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
|
||||
auto implFnName = (libFn.getName().str() + "_impl");
|
||||
auto module = libFn.getParentOfType<ModuleOp>();
|
||||
if (auto f = module.getNamedFunction(implFnName)) {
|
||||
if (auto f = module.lookupSymbol<FuncOp>(implFnName)) {
|
||||
return f;
|
||||
}
|
||||
SmallVector<Type, 4> fnArgTypes;
|
||||
|
@ -603,7 +603,7 @@ static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
|
|||
assert(isa<LinalgOp>(op));
|
||||
auto fnName = LinalgOp::getLibraryCallName();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
if (auto f = module.getNamedFunction(fnName)) {
|
||||
if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
|
||||
return f;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/SPIRV/Serialization.h"
|
||||
|
|
|
@ -431,7 +431,8 @@ static LogicalResult verify(CallOp op) {
|
|||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
auto fn = op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
|
||||
auto fn =
|
||||
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
@ -1098,7 +1099,7 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
|
||||
// Try to find the referenced function.
|
||||
auto fn =
|
||||
op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
|
||||
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError("reference to undefined function 'bar'");
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
|
|||
static Error compileAndExecuteFunctionWithMemRefs(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
FuncOp mainFunction = module.getNamedFunction(entryPoint);
|
||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.getBlocks().empty()) {
|
||||
return make_string_error("entry point not found");
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ static Error compileAndExecuteFunctionWithMemRefs(
|
|||
static Error compileAndExecuteSingleFloatReturnFunction(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
FuncOp mainFunction = module.getNamedFunction(entryPoint);
|
||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.isExternal()) {
|
||||
return make_string_error("entry point not found");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue