NFC: Replace Module::getNamedFunction with lookupSymbol<FuncOp>.

This allows for removing the last direct reference to FuncOp from ModuleOp.

PiperOrigin-RevId: 257498296
This commit is contained in:
River Riddle 2019-07-10 15:49:27 -07:00 committed by jpienaar
parent 122cab6770
commit 6da343ecfc
16 changed files with 45 additions and 41 deletions

View File

@ -225,7 +225,7 @@ struct PythonMLIRModule {
} }
PythonFunction getNamedFunction(const std::string &name) { PythonFunction getNamedFunction(const std::string &name) {
return moduleManager.getNamedFunction(name); return moduleManager.lookupSymbol<FuncOp>(name);
} }
PythonFunctionContext PythonFunctionContext

View File

@ -121,7 +121,7 @@ public:
void runOnModule() override { void runOnModule() override {
auto module = getModule(); auto module = getModule();
auto main = module.getNamedFunction("main"); auto main = module.lookupSymbol<mlir::FuncOp>("main");
if (!main) { if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()), emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n"); "Shape inference failed: can't find a main function\n");
@ -161,7 +161,8 @@ public:
// We will create a new function with the concrete types for the parameters // We will create a new function with the concrete types for the parameters
// and clone the body into it. // and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) { if (!functionToSpecialize.mangledName.empty()) {
if (getModule().getNamedFunction(functionToSpecialize.mangledName)) { if (getModule().lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back(); funcWorklist.pop_back();
// Function already specialized, move on. // Function already specialized, move on.
return mlir::success(); return mlir::success();
@ -295,7 +296,7 @@ public:
// restart after the callee is processed. // restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName(); auto calleeName = callOp.getCalleeName();
auto callee = getModule().getNamedFunction(calleeName); auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) { if (!callee) {
f.emitError("Shape inference failed, call to unknown '") f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'"; << calleeName << "'";
@ -305,7 +306,8 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands()); auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n"); << "', mangled: '" << mangledName << "'\n");
auto mangledCallee = getModule().getNamedFunction(mangledName); auto mangledCallee =
getModule().lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) { if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the // Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now. // callee and stop the inference for the current function now.

View File

@ -206,7 +206,7 @@ private:
/// Return the prototype declaration for printf in the module, create it if /// Return the prototype declaration for printf in the module, create it if
/// necessary. /// necessary.
FuncOp getPrintf(ModuleOp module) const { FuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.getNamedFunction("printf"); auto printfFunc = module.lookupSymbol<FuncOp>("printf");
if (printfFunc) if (printfFunc)
return printfFunc; return printfFunc;

View File

@ -122,7 +122,7 @@ public:
void runOnModule() override { void runOnModule() override {
auto module = getModule(); auto module = getModule();
mlir::ModuleManager moduleManager(module); mlir::ModuleManager moduleManager(module);
auto main = moduleManager.getNamedFunction("main"); auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
if (!main) { if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()), emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n"); "Shape inference failed: can't find a main function\n");
@ -163,7 +163,8 @@ public:
// We will create a new function with the concrete types for the parameters // We will create a new function with the concrete types for the parameters
// and clone the body into it. // and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) { if (!functionToSpecialize.mangledName.empty()) {
if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) { if (moduleManager.lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back(); funcWorklist.pop_back();
// FuncOp already specialized, move on. // FuncOp already specialized, move on.
return mlir::success(); return mlir::success();
@ -298,7 +299,7 @@ public:
// restart after the callee is processed. // restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName(); auto calleeName = callOp.getCalleeName();
auto callee = moduleManager.getNamedFunction(calleeName); auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) { if (!callee) {
signalPassFailure(); signalPassFailure();
return f.emitError("Shape inference failed, call to unknown '") return f.emitError("Shape inference failed, call to unknown '")
@ -307,7 +308,8 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands()); auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n"); << "', mangled: '" << mangledName << "'\n");
auto mangledCallee = moduleManager.getNamedFunction(mangledName); auto mangledCallee =
moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) { if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the // Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now. // callee and stop the inference for the current function now.

View File

@ -22,7 +22,6 @@
#ifndef MLIR_IR_MODULE_H #ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H #define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
namespace mlir { namespace mlir {
@ -93,11 +92,6 @@ public:
insertPt = Block::iterator(body->getTerminator()); insertPt = Block::iterator(body->getTerminator());
body->getOperations().insert(insertPt, op); body->getOperations().insert(insertPt, op);
} }
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
FuncOp getNamedFunction(StringRef name) { return lookupSymbol<FuncOp>(name); }
}; };
/// The ModuleTerminatorOp is a special terminator operation for the body of a /// The ModuleTerminatorOp is a special terminator operation for the body of a
@ -130,8 +124,8 @@ public:
/// Look up a symbol with the specified name, returning null if no such /// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them. /// name exists. Names must never include the @ on them.
template <typename NameTy> FuncOp getNamedFunction(NameTy &&name) const { template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
return symbolTable.lookup<FuncOp>(name); return symbolTable.lookup<T>(name);
} }
/// Insert a new symbol into the module, auto-renaming it as necessary. /// Insert a new symbol into the module, auto-renaming it as necessary.

View File

@ -18,6 +18,7 @@
#ifndef MLIR_PASS_ANALYSISMANAGER_H #ifndef MLIR_PASS_ANALYSISMANAGER_H
#define MLIR_PASS_ANALYSISMANAGER_H #define MLIR_PASS_ANALYSISMANAGER_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/Pass/PassInstrumentation.h" #include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"

View File

@ -34,6 +34,7 @@
namespace mlir { namespace mlir {
class Attribute; class Attribute;
class FuncOp;
class Location; class Location;
class ModuleOp; class ModuleOp;
class Operation; class Operation;

View File

@ -152,7 +152,7 @@ private:
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
ModuleOp module = getModule(); ModuleOp module = getModule();
Builder builder(module); Builder builder(module);
if (!module.getNamedFunction(cuModuleLoadName)) { if (!module.lookupSymbol<FuncOp>(cuModuleLoadName)) {
module.push_back( module.push_back(
FuncOp::create(loc, cuModuleLoadName, FuncOp::create(loc, cuModuleLoadName,
builder.getFunctionType( builder.getFunctionType(
@ -162,7 +162,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
}, },
getCUResultType()))); getCUResultType())));
} }
if (!module.getNamedFunction(cuModuleGetFunctionName)) { if (!module.lookupSymbol<FuncOp>(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and // The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction. // CUfunction.
module.push_back( module.push_back(
@ -175,7 +175,7 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
}, },
getCUResultType()))); getCUResultType())));
} }
if (!module.getNamedFunction(cuLaunchKernelName)) { if (!module.lookupSymbol<FuncOp>(cuLaunchKernelName)) {
// Other than the CUDA api, the wrappers use uintptr_t to match the // Other than the CUDA api, the wrappers use uintptr_t to match the
// LLVM type if MLIR's index type, which the GPU dialect uses. // LLVM type if MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and // Furthermore, they use void* instead of CUDA's opaque CUfunction and
@ -198,14 +198,14 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
}, },
getCUResultType()))); getCUResultType())));
} }
if (!module.getNamedFunction(cuGetStreamHelperName)) { if (!module.lookupSymbol<FuncOp>(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of // Helper function to get the current CUDA stream. Uses void* instead of
// CUDAs opaque CUstream. // CUDAs opaque CUstream.
module.push_back(FuncOp::create( module.push_back(FuncOp::create(
loc, cuGetStreamHelperName, loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */))); builder.getFunctionType({}, getPointerType() /* void *stream */)));
} }
if (!module.getNamedFunction(cuStreamSynchronizeName)) { if (!module.lookupSymbol<FuncOp>(cuStreamSynchronizeName)) {
module.push_back( module.push_back(
FuncOp::create(loc, cuStreamSynchronizeName, FuncOp::create(loc, cuStreamSynchronizeName,
builder.getFunctionType( builder.getFunctionType(
@ -322,7 +322,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit a call to the cubin getter to retrieve a pointer to the data that // Emit a call to the cubin getter to retrieve a pointer to the data that
// represents the cubin at runtime. // represents the cubin at runtime.
// TODO(herhut): This should rather be a static global once supported. // TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel()); auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
auto cubinGetter = auto cubinGetter =
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation); kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
if (!cubinGetter) { if (!cubinGetter) {
@ -335,7 +335,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done // Emit the load module call to load the module data. Error checking is done
// in the called helper function. // in the called helper function.
auto cuModule = allocatePointer(builder, loc); auto cuModule = allocatePointer(builder, loc);
FuncOp cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad), builder.getFunctionAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)}); ArrayRef<Value *>{cuModule, data.getResult(0)});
@ -346,19 +346,19 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc); auto cuFunction = allocatePointer(builder, loc);
FuncOp cuModuleGetFunction = FuncOp cuModuleGetFunction =
getModule().getNamedFunction(cuModuleGetFunctionName); getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>( builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()}, loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction), builder.getFunctionAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName}); ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution. // Grab the global stream needed for execution.
FuncOp cuGetStreamHelper = FuncOp cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName); getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>( auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()}, loc, ArrayRef<Type>{getPointerType()},
builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{}); builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{});
// Invoke the function with required arguments. // Invoke the function with required arguments.
auto cuLaunchKernel = getModule().getNamedFunction(cuLaunchKernelName); auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
auto cuFunctionRef = auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction); builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder); auto paramsArray = setupParamsArray(launchOp, builder);
@ -375,7 +375,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
paramsArray, /* kernel params */ paramsArray, /* kernel params */
nullpointer /* extra */}); nullpointer /* extra */});
// Sync on the stream to make it synchronous. // Sync on the stream to make it synchronous.
auto cuStreamSync = getModule().getNamedFunction(cuStreamSynchronizeName); auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuStreamSync), builder.getFunctionAttr(cuStreamSync),
ArrayRef<Value *>(cuStream.getResult(0))); ArrayRef<Value *>(cuStream.getResult(0)));

View File

@ -60,7 +60,7 @@ private:
} }
FuncOp getMallocHelper(Location loc, Builder &builder) { FuncOp getMallocHelper(Location loc, Builder &builder) {
FuncOp result = getModule().getNamedFunction(kMallocHelperName); FuncOp result = getModule().lookupSymbol<FuncOp>(kMallocHelperName);
if (!result) { if (!result) {
result = FuncOp::create( result = FuncOp::create(
loc, kMallocHelperName, loc, kMallocHelperName,

View File

@ -442,7 +442,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Insert the `malloc` declaration if it is not already present. // Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>(); auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.getNamedFunction("malloc"); FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
if (!mallocFunc) { if (!mallocFunc) {
auto mallocType = auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType()); rewriter.getFunctionType(getIndexType(), getVoidPtrType());
@ -503,7 +503,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands); OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present. // Insert the `free` declaration if it is not already present.
FuncOp freeFunc = op->getParentOfType<ModuleOp>().getNamedFunction("free"); FuncOp freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
if (!freeFunc) { if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);

View File

@ -426,7 +426,7 @@ LogicalResult LaunchFuncOp::verify() {
} }
auto module = getParentOfType<ModuleOp>(); auto module = getParentOfType<ModuleOp>();
FuncOp kernelFunc = module.getNamedFunction(kernel()); FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
if (!kernelFunc) if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined"; return emitError() << "kernel function '" << kernelAttr << "' is undefined";

View File

@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"

View File

@ -171,7 +171,7 @@ public:
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present. // Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>(); auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.getNamedFunction("malloc"); FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
if (!mallocFunc) { if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc = mallocFunc =
@ -232,7 +232,7 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present. // Insert the `free` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>(); auto module = op->getParentOfType<ModuleOp>();
FuncOp freeFunc = module.getNamedFunction("free"); FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
if (!freeFunc) { if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {}); auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
@ -576,7 +576,7 @@ public:
static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
auto implFnName = (libFn.getName().str() + "_impl"); auto implFnName = (libFn.getName().str() + "_impl");
auto module = libFn.getParentOfType<ModuleOp>(); auto module = libFn.getParentOfType<ModuleOp>();
if (auto f = module.getNamedFunction(implFnName)) { if (auto f = module.lookupSymbol<FuncOp>(implFnName)) {
return f; return f;
} }
SmallVector<Type, 4> fnArgTypes; SmallVector<Type, 4> fnArgTypes;
@ -603,7 +603,7 @@ static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
assert(isa<LinalgOp>(op)); assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName(); auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getParentOfType<ModuleOp>(); auto module = op->getParentOfType<ModuleOp>();
if (auto f = module.getNamedFunction(fnName)) { if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
return f; return f;
} }

View File

@ -20,6 +20,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/SPIRV/SPIRVOps.h" #include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/Serialization.h" #include "mlir/SPIRV/Serialization.h"

View File

@ -431,7 +431,8 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee"); auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr) if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute"); return op.emitOpError("requires a 'callee' function attribute");
auto fn = op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue()); auto fn =
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
if (!fn) if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue() return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function"; << "' does not reference a valid function";
@ -1098,7 +1099,7 @@ static LogicalResult verify(ConstantOp &op) {
// Try to find the referenced function. // Try to find the referenced function.
auto fn = auto fn =
op.getParentOfType<ModuleOp>().getNamedFunction(fnAttr.getValue()); op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
if (!fn) if (!fn)
return op.emitOpError("reference to undefined function 'bar'"); return op.emitOpError("reference to undefined function 'bar'");

View File

@ -164,7 +164,7 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
static Error compileAndExecuteFunctionWithMemRefs( static Error compileAndExecuteFunctionWithMemRefs(
ModuleOp module, StringRef entryPoint, ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) { std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.getNamedFunction(entryPoint); FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.getBlocks().empty()) { if (!mainFunction || mainFunction.getBlocks().empty()) {
return make_string_error("entry point not found"); return make_string_error("entry point not found");
} }
@ -207,7 +207,7 @@ static Error compileAndExecuteFunctionWithMemRefs(
static Error compileAndExecuteSingleFloatReturnFunction( static Error compileAndExecuteSingleFloatReturnFunction(
ModuleOp module, StringRef entryPoint, ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) { std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.getNamedFunction(entryPoint); FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal()) { if (!mainFunction || mainFunction.isExternal()) {
return make_string_error("entry point not found"); return make_string_error("entry point not found");
} }