[mlir][python] Simplify python extension loading.

* Now that packaging has stabilized, removes old mechanisms for loading extensions, preferring direct importing.
* Removes _cext_loader.py, _dlloader.py as unnecessary.
* Fixes the path where the CAPI dll is written on Windows. This enables that path of least resistance loading behavior to work with no further drama (see: https://bugs.python.org/issue36085).
* With this patch, `ninja check-mlir` on Windows with Python bindings works for me, modulo some failures that are actually due to a couple of pre-existing Windows bugs. I think this is the first time the Windows Python bindings have worked upstream.
* Downstream changes needed:
  * If downstreams are using the now removed `load_extension`, `reexport_cext`, etc, then those should be replaced with normal import statements as done in this patch.

Reviewed By: jdd, aartbik

Differential Revision: https://reviews.llvm.org/D108489
This commit is contained in:
Stella Laurenzo 2021-09-03 00:37:00 +00:00
parent 78fbd1aa3d
commit cb7b03819a
19 changed files with 42 additions and 194 deletions

View File

@ -371,6 +371,9 @@ function(add_mlir_python_common_capi_library name)
set_target_properties(${name} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
# Needed for windows (and don't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
mlir_python_setup_extension_rpath(${name}
RELATIVE_INSTALL_ROOT "${ARG_RELATIVE_INSTALL_ROOT}"

View File

@ -12,6 +12,8 @@
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
@ -25,6 +27,9 @@ PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
assert(!instance && "PyGlobals already constructed");
instance = this;
// The default search path include {mlir.}dialects, where {mlir.} is the
// package prefix configured at compile time.
dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
PyGlobals::~PyGlobals() { instance = nullptr; }

View File

@ -20,8 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources
SOURCES
_cext_loader.py
_dlloader.py
_mlir_libs/__init__.py
ir.py
passmanager.py

View File

@ -1,57 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Common module for looking up and manipulating C-Extensions."""
# The normal layout is to have a nested _mlir_libs package that contains
# all native libraries and extensions. If that exists, use it, but also fallback
# to old behavior where extensions were at the top level as loose libraries.
# TODO: Remove the fallback once downstreams adapt.
try:
from ._mlir_libs import *
# TODO: Remove these aliases once everything migrates
_preload_dependency = preload_dependency
_load_extension = load_extension
except ModuleNotFoundError:
# Assume that we are in-tree.
# The _dlloader takes care of platform specific setup before we try to
# load a shared library.
# TODO: Remove _dlloader once all consolidated on the _mlir_libs approach.
from ._dlloader import preload_dependency
def load_extension(name):
import importlib
return importlib.import_module(name) # i.e. '_mlir' at the top level
preload_dependency("MLIRPythonCAPI")
# Expose the corresponding C-Extension module with a well-known name at this
# top-level module. This allows relative imports like the following to
# function:
# from .._cext_loader import _cext
# This reduces coupling, allowing embedding of the python sources into another
# project that can just vary based on this top-level loader module.
_cext = load_extension("_mlir")
def _reexport_cext(cext_module_name, target_module_name):
"""Re-exports a named sub-module of the C-Extension into another module.
Typically:
from ._cext_loader import _reexport_cext
_reexport_cext("ir", __name__)
del _reexport_cext
"""
import sys
target_module = sys.modules[target_module_name]
submodule_names = cext_module_name.split(".")
source_module = _cext
for submodule_name in submodule_names:
source_module = getattr(source_module, submodule_name)
for attr_name in dir(source_module):
if not attr_name.startswith("__"):
setattr(target_module, attr_name, getattr(source_module, attr_name))
# Add our 'dialects' parent module to the search path for implementations.
_cext.globals.append_dialect_search_prefix("mlir.dialects")

View File

@ -1,59 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import platform
_is_windows = platform.system() == "Windows"
_this_directory = os.path.dirname(__file__)
# The standard LLVM build/install tree for Windows is laid out as:
# bin/
# MLIRPublicAPI.dll
# python/
# _mlir.*.pyd (dll extension)
# mlir/
# _dlloader.py (this file)
# First check the python/ directory level for DLLs co-located with the pyd
# file, and then fall back to searching the bin/ directory.
# TODO: This should be configurable at some point.
_dll_search_path = [
os.path.join(_this_directory, ".."),
os.path.join(_this_directory, "..", "..", "bin"),
]
# Stash loaded DLLs to keep them alive.
_loaded_dlls = []
def preload_dependency(public_name):
"""Preloads a dylib by its soname or DLL name.
On Windows and Linux, doing this prior to loading a dependency will populate
the library in the flat namespace so that a subsequent library that depend
on it will resolve to this preloaded version.
On OSX, resolution is completely path based so this facility no-ops. On
Linux, as long as RPATHs are setup properly, resolution is path based but
this facility can still act as an escape hatch for relocatable distributions.
"""
if _is_windows:
_preload_dependency_windows(public_name)
def _preload_dependency_windows(public_name):
dll_basename = public_name + ".dll"
found_path = None
for search_dir in _dll_search_path:
candidate_path = os.path.join(search_dir, dll_basename)
if os.path.exists(candidate_path):
found_path = candidate_path
break
if found_path is None:
raise RuntimeError(
f"Unable to find dependency DLL {dll_basename} in search "
f"path {_dll_search_path}")
import ctypes
_loaded_dlls.append(ctypes.CDLL(found_path))

View File

@ -4,24 +4,10 @@
from typing import Sequence
import importlib
import os
__all__ = [
"load_extension",
"preload_dependency",
]
_this_dir = os.path.dirname(__file__)
def load_extension(name):
return importlib.import_module(f".{name}", __package__)
def preload_dependency(public_name):
# TODO: Implement this hook to pre-load DLLs with ctypes on Windows.
pass
def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.

View File

@ -2,7 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._cext_loader import _load_extension
_cextAllPasses = _load_extension("_mlirAllPassesRegistration")
del _load_extension
from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses

View File

@ -4,5 +4,4 @@
# Expose the corresponding C-Extension module with a well-known name at this
# level.
from .._cext_loader import _load_extension
_cextConversions = _load_extension("_mlirConversions")
from .._mlir_libs import _mlirConversions as _cextConversions

View File

@ -6,10 +6,7 @@ try:
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
# TODO: resolve name collision for Linalg functionality that is injected inside
# the _mlir.dialects.linalg directly via pybind.
from .._cext_loader import _cext
fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@ -29,12 +26,11 @@ class FillOp:
results = []
if isa(RankedTensorType, output.type):
results = [output.type]
op = self.build_generic(
results=results,
operands=[value, output],
attributes=None,
loc=loc,
ip=ip)
op = self.build_generic(results=results,
operands=[value, output],
attributes=None,
loc=loc,
ip=ip)
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)
@ -78,12 +74,11 @@ class InitTensorOp:
attributes["static_sizes"] = ArrayAttr.get(
[IntegerAttr.get(i64_type, s) for s in static_size_ints],
context=context)
op = self.build_generic(
results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip)
op = self.build_generic(results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip)
OpView.__init__(self, op)
@ -92,11 +87,10 @@ class StructuredOpMixin:
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(
results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip))
self.build_generic(results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip))
def select_opview_mixin(parent_opview_cls):

View File

@ -2,8 +2,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Re-export the parent _cext so that every level of the API can get it locally.
from .._cext_loader import _cext
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
__all__ = [
"equally_sized_accessor",

View File

@ -2,5 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ...._cext_loader import _load_extension
_cextAsyncPasses = _load_extension("_mlirAsyncPasses")
from ...._mlir_libs import _mlirAsyncPasses as _cextAsyncPasses

View File

@ -2,5 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ...._cext_loader import _load_extension
_cextGPUPasses = _load_extension("_mlirGPUPasses")
from ...._mlir_libs import _mlirGPUPasses as _cextGPUPasses

View File

@ -5,13 +5,11 @@
from typing import Dict, Sequence
from .....ir import *
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
from .... import linalg
from .... import std
from .... import math
# TODO: resolve name collision for Linalg functionality that is injected inside
# the _mlir.dialects.linalg directly via pybind.
from ....._cext_loader import _cext
fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
from .scalar_expr import *
from .config import *
@ -216,8 +214,8 @@ class _BodyBuilder:
value_attr = Attribute.parse(expr.scalar_const.value)
return std.ConstantOp(value_attr.type, value_attr).result
elif expr.scalar_index:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
dim_attr = IntegerAttr.get(IntegerType.get_signless(64),
expr.scalar_index.dim)
return linalg.IndexOp(IndexType.get(), dim_attr).result
elif expr.scalar_apply:
try:

View File

@ -2,5 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ...._cext_loader import _load_extension
_cextLinalgPasses = _load_extension("_mlirLinalgPasses")
from ...._mlir_libs import _mlirLinalgPasses as _cextLinalgPasses

View File

@ -2,11 +2,5 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._cext_loader import _reexport_cext
from .._cext_loader import _load_extension
_reexport_cext("dialects.sparse_tensor", __name__)
_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
del _reexport_cext
del _load_extension
from .._mlir_libs._mlir.dialects.sparse_tensor import *
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses

View File

@ -3,8 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Simply a wrapper around the extension module of the same name.
from ._cext_loader import load_extension
_execution_engine = load_extension("_mlirExecutionEngine")
from ._mlir_libs import _mlirExecutionEngine as _execution_engine
import ctypes
__all__ = [

View File

@ -2,8 +2,5 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Simply a wrapper around the extension module of the same name.
from ._cext_loader import _reexport_cext
_reexport_cext("ir", __name__)
del _reexport_cext
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug

View File

@ -2,7 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Simply a wrapper around the extension module of the same name.
from ._cext_loader import _reexport_cext
_reexport_cext("passmanager", __name__)
del _reexport_cext
from ._mlir_libs._mlir.passmanager import *

View File

@ -4,5 +4,4 @@
# Expose the corresponding C-Extension module with a well-known name at this
# level.
from .._cext_loader import _load_extension
_cextTransforms = _load_extension("_mlirTransforms")
from .._mlir_libs import _mlirTransforms as _cextTransforms