diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py index 4e2e8ba43930..909738a61760 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py @@ -16,7 +16,8 @@ from mlir import execution_engine from mlir import ir from mlir import runtime from mlir.dialects import sparse_tensor -from mlir.passmanager import PassManager + +from . import mlir_sparse_compiler # Type aliases for type annotation. _SupportFunc = Callable[..., None] @@ -40,6 +41,13 @@ def _get_support_lib_name() -> str: return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB) +@functools.lru_cache() +def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler: + """Gets the MLIR sparse compiler with default setting.""" + return mlir_sparse_compiler.SparseCompiler( + options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]) + + def _record_support_funcs( ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc, ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None: @@ -184,10 +192,7 @@ def compile_and_build_engine( A JIT execution engine for the MLIR module. """ - pipeline = f"sparse-compiler" - PassManager.parse(pipeline).run(module) - return execution_engine.ExecutionEngine( - module, opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]) + return _get_sparse_compiler().compile_and_jit(module) class _SparseTensorDescriptor(ctypes.Structure): diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py new file mode 100644 index 000000000000..58e08d9a4e9a --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py @@ -0,0 +1,42 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This file contains the sparse compiler class. It is copied from +# test/Integration/Dialect/SparseTensor/python/ until we have a better +# solution. + +from mlir import all_passes_registration +from mlir import execution_engine +from mlir import ir +from mlir import passmanager +from typing import Sequence + + +class SparseCompiler: + """Sparse compiler class for compiling and building MLIR modules.""" + + def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]): + pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}' + self.pipeline = pipeline + self.opt_level = opt_level + self.shared_libs = shared_libs + + def __call__(self, module: ir.Module): + """Convenience application method.""" + self.compile(module) + + def compile(self, module: ir.Module): + """Compiles the module by invoking the sparse copmiler pipeline.""" + passmanager.PassManager.parse(self.pipeline).run(module) + + def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + """Wraps the module in a JIT execution engine.""" + return execution_engine.ExecutionEngine( + module, opt_level=self.opt_level, shared_libs=self.shared_libs) + + def compile_and_jit(self, + module: ir.Module) -> execution_engine.ExecutionEngine: + """Compiles and jits the module.""" + self.compile(module) + return self.jit(module)