NFC: Replace Module::getNamedFunction with lookupSymbol<FuncOp>.

This allows for removing the last direct reference to FuncOp from ModuleOp.

PiperOrigin-RevId: 257498296
This commit is contained in:
River Riddle 2019-07-10 15:49:27 -07:00 committed by jpienaar
parent 122cab6770
commit 6da343ecfc
16 changed files with 45 additions and 41 deletions

View File

@ -225,7 +225,7 @@ struct PythonMLIRModule {
}
PythonFunction getNamedFunction(const std::string &name) {
return moduleManager.getNamedFunction(name);
return moduleManager.lookupSymbol<FuncOp>(name);
}
PythonFunctionContext

View File

@ -121,7 +121,7 @@ public:
void runOnModule() override {
auto module = getModule();
auto main = module.getNamedFunction("main");
auto main = module.lookupSymbol<mlir::FuncOp>("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@ -161,7 +161,8 @@ public:
// We will create a new function with the concrete types for the parameters
// and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) {
if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
if (getModule().lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back();
// Function already specialized, move on.
return mlir::success();
@ -295,7 +296,7 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto callee = getModule().getNamedFunction(calleeName);
auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) {
f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
@ -305,7 +306,8 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto mangledCallee = getModule().getNamedFunction(mangledName);
auto mangledCallee =
getModule().lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.

View File

@ -206,7 +206,7 @@ private:
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
FuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.getNamedFunction("printf");
auto printfFunc = module.lookupSymbol<FuncOp>("printf");
if (printfFunc)
return printfFunc;

View File

@ -122,7 +122,7 @@ public:
void runOnModule() override {
auto module = getModule();
mlir::ModuleManager moduleManager(module);
auto main = moduleManager.getNamedFunction("main");
auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@ -163,7 +163,8 @@ public:
// We will create a new function with the concrete types for the parameters
// and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) {
if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) {
if (moduleManager.lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back();
// FuncOp already specialized, move on.
return mlir::success();
@ -298,7 +299,7 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto callee = moduleManager.getNamedFunction(calleeName);
auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) {
signalPassFailure();
return f.emitError("Shape inference failed, call to unknown '")
@ -307,7 +308,8 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto mangledCallee = moduleManager.getNamedFunction(mangledName);
auto mangledCallee =
moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.

View File

@ -22,7 +22,6 @@
#ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
@ -93,11 +92,6 @@ public:
insertPt = Block::iterator(body->getTerminator());
body->getOperations().insert(insertPt, op);
}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
FuncOp getNamedFunction(StringRef name) { return lookupSymbol<FuncOp>(name); }
};
/// The ModuleTerminatorOp is a special terminator operation for the body of a
@ -130,8 +124,8 @@ public:
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
template <typename NameTy> FuncOp getNamedFunction(NameTy &&name) const {
return symbolTable.lookup<FuncOp>(name);
template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
return symbolTable.lookup<T>(name);
}
/// Insert a new symbol into the module, auto-renaming it as necessary.

View File

@ -18,6 +18,7 @@
#ifndef MLIR_PASS_ANALYSISMANAGER_H
#define MLIR_PASS_ANALYSISMANAGER_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Support/LLVM.h"

View File

@ -34,6 +34,7 @@
namespace mlir {
class Attribute;
class FuncOp;
class Location;
class ModuleOp;
class Operation;

View File

@ -152,7 +152,7 @@ private:
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
ModuleOp module = getModule();
Builder builder(module);
if (!module.getNamedFunction(cuModuleLoadName)) {
if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
module.push_back(
FuncOp::create(loc, cuModuleLoadName,
builder.getFunctionType(
@ -162,7 +162,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
},
getCUResultType())));
}
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction.
module.push_back(
@ -175,7 +175,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
},
getCUResultType())));
}
if (!module.getNamedFunction(cuLaunchKernelName)) {
if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
// Other than the CUDA api, the wrappers use uintptr_t to match the
// LLVM type if MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
@ -198,14 +198,14 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
},
getCUResultType())));
}
if (!module.getNamedFunction(cuGetStreamHelperName)) {
if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of
// CUDAs opaque CUstream.
module.push_back(FuncOp::create(
loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */)));
}
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
module.push_back(
FuncOp::create(loc, cuStreamSynchronizeName,
builder.getFunctionType(
@ -322,7 +322,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit a call to the cubin getter to retrieve a pointer to the data that
// represents the cubin at runtime.
// TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
auto cubinGetter =
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
if (!cubinGetter) {
@ -335,7 +335,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
FuncOp cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)});
@ -346,19 +346,19 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc);
FuncOp cuModuleGetFunction =
getModule().getNamedFunction(cuModuleGetFunctionName);
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
FuncOp cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName);
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{});
// Invoke the function with required arguments.
auto cuLaunchKernel = getModule().getNamedFunction(cuLaunchKernelName);
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
@ -375,7 +375,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
paramsArray, /* kernel params */
nullpointer /* extra */});
// Sync on the stream to make it synchronous.
auto cuStreamSync = getModule().getNamedFunction(cuStreamSynchronizeName);
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuStreamSync),
ArrayRef<Value *>(cuStream.getResult(0)));

View File

@ -60,7 +60,7 @@ private:
}
FuncOp getMallocHelper(Location loc, Builder &builder) {
FuncOp result = getModule().getNamedFunction(kMallocHelperName);
FuncOp result = getModule().lookupSymbol<FuncOp>(kMallocHelperName);
if (!result) {
result = FuncOp::create(
loc, kMallocHelperName,

View File

@ -442,7 +442,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.getNamedFunction("malloc");
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
@ -503,7 +503,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
FuncOp freeFunc = op->getParentOfType<ModuleOp>().getNamedFunction("free");
FuncOp freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);

View File

@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() {
}
auto module = getParentOfType<ModuleOp>();
FuncOp kernelFunc = module.getNamedFunction(kernel());
FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";

View File

@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"

View File

@ -171,7 +171,7 @@ public:
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.getNamedFunction("malloc");
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc =
@ -232,7 +232,7 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp freeFunc = module.getNamedFunction("free");
FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
@ -576,7 +576,7 @@ public:
static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
auto implFnName = (libFn.getName().str() + "_impl");
auto module = libFn.getParentOfType<ModuleOp>();
if (auto f = module.getNamedFunction(implFnName)) {
if (auto f = module.lookupSymbol<FuncOp>(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
@ -603,7 +603,7 @@ static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getParentOfType<ModuleOp>();
if (auto f = module.getNamedFunction(fnName)) {
if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
return f;
}

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/Serialization.h"

View File

@ -431,7 +431,8 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
auto fn = op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
auto fn =
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
@ -1098,7 +1099,7 @@ static LogicalResult verify(ConstantOp &op) {
// Try to find the referenced function.
auto fn =
op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue());
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");

View File

@ -164,7 +164,7 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
static Error compileAndExecuteFunctionWithMemRefs(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.getNamedFunction(entryPoint);
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.getBlocks().empty()) {
return make_string_error("entry point not found");
}
@ -207,7 +207,7 @@ static Error compileAndExecuteFunctionWithMemRefs(
static Error compileAndExecuteSingleFloatReturnFunction(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.getNamedFunction(entryPoint);
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal()) {
return make_string_error("entry point not found");
}