forked from OSchip/llvm-project
[mlir][python] Simplify python extension loading.
* Now that packaging has stabilized, removes old mechanisms for loading extensions, preferring direct importing. * Removes _cext_loader.py, _dlloader.py as unnecessary. * Fixes the path where the CAPI dll is written on Windows. This enables that path of least resistance loading behavior to work with no further drama (see: https://bugs.python.org/issue36085). * With this patch, `ninja check-mlir` on Windows with Python bindings works for me, modulo some failures that are actually due to a couple of pre-existing Windows bugs. I think this is the first time the Windows Python bindings have worked upstream. * Downstream changes needed: * If downstreams are using the now removed `load_extension`, `reexport_cext`, etc, then those should be replaced with normal import statements as done in this patch. Reviewed By: jdd, aartbik Differential Revision: https://reviews.llvm.org/D108489
This commit is contained in:
parent
78fbd1aa3d
commit
cb7b03819a
|
@ -371,6 +371,9 @@ function(add_mlir_python_common_capi_library name)
|
|||
set_target_properties(${name} PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
|
||||
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
|
||||
# Needed for windows (and don't hurt others).
|
||||
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
|
||||
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
|
||||
)
|
||||
mlir_python_setup_extension_rpath(${name}
|
||||
RELATIVE_INSTALL_ROOT "${ARG_RELATIVE_INSTALL_ROOT}"
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
@ -25,6 +27,9 @@ PyGlobals *PyGlobals::instance = nullptr;
|
|||
PyGlobals::PyGlobals() {
|
||||
assert(!instance && "PyGlobals already constructed");
|
||||
instance = this;
|
||||
// The default search path include {mlir.}dialects, where {mlir.} is the
|
||||
// package prefix configured at compile time.
|
||||
dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
|
||||
}
|
||||
|
||||
PyGlobals::~PyGlobals() { instance = nullptr; }
|
||||
|
|
|
@ -20,8 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core
|
|||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
ADD_TO_PARENT MLIRPythonSources
|
||||
SOURCES
|
||||
_cext_loader.py
|
||||
_dlloader.py
|
||||
_mlir_libs/__init__.py
|
||||
ir.py
|
||||
passmanager.py
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Common module for looking up and manipulating C-Extensions."""
|
||||
|
||||
# The normal layout is to have a nested _mlir_libs package that contains
|
||||
# all native libraries and extensions. If that exists, use it, but also fallback
|
||||
# to old behavior where extensions were at the top level as loose libraries.
|
||||
# TODO: Remove the fallback once downstreams adapt.
|
||||
try:
|
||||
from ._mlir_libs import *
|
||||
# TODO: Remove these aliases once everything migrates
|
||||
_preload_dependency = preload_dependency
|
||||
_load_extension = load_extension
|
||||
except ModuleNotFoundError:
|
||||
# Assume that we are in-tree.
|
||||
# The _dlloader takes care of platform specific setup before we try to
|
||||
# load a shared library.
|
||||
# TODO: Remove _dlloader once all consolidated on the _mlir_libs approach.
|
||||
from ._dlloader import preload_dependency
|
||||
|
||||
def load_extension(name):
|
||||
import importlib
|
||||
return importlib.import_module(name) # i.e. '_mlir' at the top level
|
||||
|
||||
preload_dependency("MLIRPythonCAPI")
|
||||
|
||||
# Expose the corresponding C-Extension module with a well-known name at this
|
||||
# top-level module. This allows relative imports like the following to
|
||||
# function:
|
||||
# from .._cext_loader import _cext
|
||||
# This reduces coupling, allowing embedding of the python sources into another
|
||||
# project that can just vary based on this top-level loader module.
|
||||
_cext = load_extension("_mlir")
|
||||
|
||||
|
||||
def _reexport_cext(cext_module_name, target_module_name):
|
||||
"""Re-exports a named sub-module of the C-Extension into another module.
|
||||
|
||||
Typically:
|
||||
from ._cext_loader import _reexport_cext
|
||||
_reexport_cext("ir", __name__)
|
||||
del _reexport_cext
|
||||
"""
|
||||
import sys
|
||||
target_module = sys.modules[target_module_name]
|
||||
submodule_names = cext_module_name.split(".")
|
||||
source_module = _cext
|
||||
for submodule_name in submodule_names:
|
||||
source_module = getattr(source_module, submodule_name)
|
||||
for attr_name in dir(source_module):
|
||||
if not attr_name.startswith("__"):
|
||||
setattr(target_module, attr_name, getattr(source_module, attr_name))
|
||||
|
||||
|
||||
# Add our 'dialects' parent module to the search path for implementations.
|
||||
_cext.globals.append_dialect_search_prefix("mlir.dialects")
|
|
@ -1,59 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
_is_windows = platform.system() == "Windows"
|
||||
_this_directory = os.path.dirname(__file__)
|
||||
|
||||
# The standard LLVM build/install tree for Windows is laid out as:
|
||||
# bin/
|
||||
# MLIRPublicAPI.dll
|
||||
# python/
|
||||
# _mlir.*.pyd (dll extension)
|
||||
# mlir/
|
||||
# _dlloader.py (this file)
|
||||
# First check the python/ directory level for DLLs co-located with the pyd
|
||||
# file, and then fall back to searching the bin/ directory.
|
||||
# TODO: This should be configurable at some point.
|
||||
_dll_search_path = [
|
||||
os.path.join(_this_directory, ".."),
|
||||
os.path.join(_this_directory, "..", "..", "bin"),
|
||||
]
|
||||
|
||||
# Stash loaded DLLs to keep them alive.
|
||||
_loaded_dlls = []
|
||||
|
||||
def preload_dependency(public_name):
|
||||
"""Preloads a dylib by its soname or DLL name.
|
||||
|
||||
On Windows and Linux, doing this prior to loading a dependency will populate
|
||||
the library in the flat namespace so that a subsequent library that depend
|
||||
on it will resolve to this preloaded version.
|
||||
|
||||
On OSX, resolution is completely path based so this facility no-ops. On
|
||||
Linux, as long as RPATHs are setup properly, resolution is path based but
|
||||
this facility can still act as an escape hatch for relocatable distributions.
|
||||
"""
|
||||
if _is_windows:
|
||||
_preload_dependency_windows(public_name)
|
||||
|
||||
|
||||
def _preload_dependency_windows(public_name):
|
||||
dll_basename = public_name + ".dll"
|
||||
found_path = None
|
||||
for search_dir in _dll_search_path:
|
||||
candidate_path = os.path.join(search_dir, dll_basename)
|
||||
if os.path.exists(candidate_path):
|
||||
found_path = candidate_path
|
||||
break
|
||||
|
||||
if found_path is None:
|
||||
raise RuntimeError(
|
||||
f"Unable to find dependency DLL {dll_basename} in search "
|
||||
f"path {_dll_search_path}")
|
||||
|
||||
import ctypes
|
||||
_loaded_dlls.append(ctypes.CDLL(found_path))
|
|
@ -4,24 +4,10 @@
|
|||
|
||||
from typing import Sequence
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
__all__ = [
|
||||
"load_extension",
|
||||
"preload_dependency",
|
||||
]
|
||||
|
||||
_this_dir = os.path.dirname(__file__)
|
||||
|
||||
def load_extension(name):
|
||||
return importlib.import_module(f".{name}", __package__)
|
||||
|
||||
|
||||
def preload_dependency(public_name):
|
||||
# TODO: Implement this hook to pre-load DLLs with ctypes on Windows.
|
||||
pass
|
||||
|
||||
|
||||
def get_lib_dirs() -> Sequence[str]:
|
||||
"""Gets the lib directory for linking to shared libraries.
|
||||
|
|
|
@ -2,7 +2,4 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._cext_loader import _load_extension
|
||||
|
||||
_cextAllPasses = _load_extension("_mlirAllPassesRegistration")
|
||||
del _load_extension
|
||||
from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses
|
||||
|
|
|
@ -4,5 +4,4 @@
|
|||
|
||||
# Expose the corresponding C-Extension module with a well-known name at this
|
||||
# level.
|
||||
from .._cext_loader import _load_extension
|
||||
_cextConversions = _load_extension("_mlirConversions")
|
||||
from .._mlir_libs import _mlirConversions as _cextConversions
|
||||
|
|
|
@ -6,10 +6,7 @@ try:
|
|||
from typing import Optional, Sequence, Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context
|
||||
# TODO: resolve name collision for Linalg functionality that is injected inside
|
||||
# the _mlir.dialects.linalg directly via pybind.
|
||||
from .._cext_loader import _cext
|
||||
fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
|
||||
from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
@ -29,12 +26,11 @@ class FillOp:
|
|||
results = []
|
||||
if isa(RankedTensorType, output.type):
|
||||
results = [output.type]
|
||||
op = self.build_generic(
|
||||
results=results,
|
||||
operands=[value, output],
|
||||
attributes=None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
op = self.build_generic(results=results,
|
||||
operands=[value, output],
|
||||
attributes=None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
OpView.__init__(self, op)
|
||||
linalgDialect = Context.current.get_dialect_descriptor("linalg")
|
||||
fill_builtin_region(linalgDialect, self.operation)
|
||||
|
@ -78,12 +74,11 @@ class InitTensorOp:
|
|||
attributes["static_sizes"] = ArrayAttr.get(
|
||||
[IntegerAttr.get(i64_type, s) for s in static_size_ints],
|
||||
context=context)
|
||||
op = self.build_generic(
|
||||
results=[result_type],
|
||||
operands=operands,
|
||||
attributes=attributes,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
op = self.build_generic(results=[result_type],
|
||||
operands=operands,
|
||||
attributes=attributes,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
OpView.__init__(self, op)
|
||||
|
||||
|
||||
|
@ -92,11 +87,10 @@ class StructuredOpMixin:
|
|||
|
||||
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
results=list(results),
|
||||
operands=[list(inputs), list(outputs)],
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
self.build_generic(results=list(results),
|
||||
operands=[list(inputs), list(outputs)],
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
|
||||
|
||||
def select_opview_mixin(parent_opview_cls):
|
||||
|
|
|
@ -2,8 +2,9 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Re-export the parent _cext so that every level of the API can get it locally.
|
||||
from .._cext_loader import _cext
|
||||
# Provide a convenient name for sub-packages to resolve the main C-extension
|
||||
# with a relative import.
|
||||
from .._mlir_libs import _mlir as _cext
|
||||
|
||||
__all__ = [
|
||||
"equally_sized_accessor",
|
||||
|
|
|
@ -2,5 +2,4 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ...._cext_loader import _load_extension
|
||||
_cextAsyncPasses = _load_extension("_mlirAsyncPasses")
|
||||
from ...._mlir_libs import _mlirAsyncPasses as _cextAsyncPasses
|
||||
|
|
|
@ -2,5 +2,4 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ...._cext_loader import _load_extension
|
||||
_cextGPUPasses = _load_extension("_mlirGPUPasses")
|
||||
from ...._mlir_libs import _mlirGPUPasses as _cextGPUPasses
|
||||
|
|
|
@ -5,13 +5,11 @@
|
|||
from typing import Dict, Sequence
|
||||
|
||||
from .....ir import *
|
||||
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
|
||||
|
||||
from .... import linalg
|
||||
from .... import std
|
||||
from .... import math
|
||||
# TODO: resolve name collision for Linalg functionality that is injected inside
|
||||
# the _mlir.dialects.linalg directly via pybind.
|
||||
from ....._cext_loader import _cext
|
||||
fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
|
||||
|
||||
from .scalar_expr import *
|
||||
from .config import *
|
||||
|
@ -216,8 +214,8 @@ class _BodyBuilder:
|
|||
value_attr = Attribute.parse(expr.scalar_const.value)
|
||||
return std.ConstantOp(value_attr.type, value_attr).result
|
||||
elif expr.scalar_index:
|
||||
dim_attr = IntegerAttr.get(
|
||||
IntegerType.get_signless(64), expr.scalar_index.dim)
|
||||
dim_attr = IntegerAttr.get(IntegerType.get_signless(64),
|
||||
expr.scalar_index.dim)
|
||||
return linalg.IndexOp(IndexType.get(), dim_attr).result
|
||||
elif expr.scalar_apply:
|
||||
try:
|
||||
|
|
|
@ -2,5 +2,4 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ...._cext_loader import _load_extension
|
||||
_cextLinalgPasses = _load_extension("_mlirLinalgPasses")
|
||||
from ...._mlir_libs import _mlirLinalgPasses as _cextLinalgPasses
|
||||
|
|
|
@ -2,11 +2,5 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._cext_loader import _reexport_cext
|
||||
from .._cext_loader import _load_extension
|
||||
|
||||
_reexport_cext("dialects.sparse_tensor", __name__)
|
||||
_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
|
||||
|
||||
del _reexport_cext
|
||||
del _load_extension
|
||||
from .._mlir_libs._mlir.dialects.sparse_tensor import *
|
||||
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
|
||||
|
|
|
@ -3,8 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Simply a wrapper around the extension module of the same name.
|
||||
from ._cext_loader import load_extension
|
||||
_execution_engine = load_extension("_mlirExecutionEngine")
|
||||
from ._mlir_libs import _mlirExecutionEngine as _execution_engine
|
||||
import ctypes
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -2,8 +2,5 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Simply a wrapper around the extension module of the same name.
|
||||
from ._cext_loader import _reexport_cext
|
||||
_reexport_cext("ir", __name__)
|
||||
del _reexport_cext
|
||||
|
||||
from ._mlir_libs._mlir.ir import *
|
||||
from ._mlir_libs._mlir.ir import _GlobalDebug
|
||||
|
|
|
@ -2,7 +2,4 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Simply a wrapper around the extension module of the same name.
|
||||
from ._cext_loader import _reexport_cext
|
||||
_reexport_cext("passmanager", __name__)
|
||||
del _reexport_cext
|
||||
from ._mlir_libs._mlir.passmanager import *
|
||||
|
|
|
@ -4,5 +4,4 @@
|
|||
|
||||
# Expose the corresponding C-Extension module with a well-known name at this
|
||||
# level.
|
||||
from .._cext_loader import _load_extension
|
||||
_cextTransforms = _load_extension("_mlirTransforms")
|
||||
from .._mlir_libs import _mlirTransforms as _cextTransforms
|
||||
|
|
Loading…
Reference in New Issue