forked from OSchip/llvm-project
Linalg to LLVM lowering: decrease the reliance on symbol lookup in a module
During the conversion, both the original and the converted function may coexist in the module and have the same symbol name. There is no guarantee which of the two will be found by the symbol lookup. Avoid returning the result of the library function lookup when lowering Linalg to Standard or LLVM. Use the symbol reference instead. After the conversion completes, only one symbol will remain and the Ops using SymbolRefAttrs will be referring to the correct one. PiperOrigin-RevId: 273510079
This commit is contained in:
parent
11d12670da
commit
0cdc53a762
|
@ -559,20 +559,23 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Get function definition for the LinalgOp. If it doesn't exist, insert a
|
||||
// definition.
|
||||
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
|
||||
// If the library function does not exist, insert a declaration.
|
||||
template <typename LinalgOp>
|
||||
static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
auto fnName = linalgOp.getLibraryCallName();
|
||||
if (fnName.empty()) {
|
||||
op->emitWarning("No library call defined for: ") << *op;
|
||||
return FuncOp();
|
||||
return {};
|
||||
}
|
||||
|
||||
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
|
||||
SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
|
||||
return f;
|
||||
if (module.lookupSymbol(fnName)) {
|
||||
return fnNameAttr;
|
||||
}
|
||||
|
||||
SmallVector<Type, 4> inputTypes(op->getOperandTypes());
|
||||
|
@ -580,14 +583,14 @@ static FuncOp getLLVMLibraryCallDeclaration(Operation *op,
|
|||
"Library call for linalg operation can be generated only for ops that "
|
||||
"have void return types");
|
||||
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
|
||||
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
|
||||
SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
// Insert before module terminator.
|
||||
rewriter.setInsertionPoint(module.getBody(),
|
||||
std::prev(module.getBody()->end()));
|
||||
return rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
|
||||
ArrayRef<NamedAttribute>{});
|
||||
rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
|
||||
ArrayRef<NamedAttribute>{});
|
||||
return fnNameAttr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -615,14 +618,13 @@ public:
|
|||
|
||||
PatternMatchResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, rewriter);
|
||||
if (!f)
|
||||
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return this->matchFailure();
|
||||
|
||||
auto fAttr = rewriter.getSymbolRefAttr(f);
|
||||
SmallVector<Value *, 4> operands(op.getOperands().begin(),
|
||||
op.getOperands().end());
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, fAttr.getValue(),
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
|
||||
ArrayRef<Type>{}, operands);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
@ -643,14 +645,13 @@ public:
|
|||
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
||||
return matchFailure();
|
||||
|
||||
auto f = getLLVMLibraryCallDeclaration<CopyOp>(op, rewriter);
|
||||
if (!f)
|
||||
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return matchFailure();
|
||||
|
||||
auto fAttr = rewriter.getSymbolRefAttr(f);
|
||||
SmallVector<Value *, 4> operands(op.getOperands().begin(),
|
||||
op.getOperands().end());
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, fAttr.getValue(),
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
|
||||
ArrayRef<Type>{}, operands);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue