forked from OSchip/llvm-project
[mlir][sparse][taco] Use the SparseCompiler from python/tools.
Copy the implementation of SparseCompiler from python/tools to taco/tools until we have a common place to install it. Modify TACO to use this SparseCompiler for compilation and jitting. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D123696
This commit is contained in:
parent
de026aeb8e
commit
cb6f8d77a2
|
@ -16,7 +16,8 @@ from mlir import execution_engine
|
|||
from mlir import ir
|
||||
from mlir import runtime
|
||||
from mlir.dialects import sparse_tensor
|
||||
from mlir.passmanager import PassManager
|
||||
|
||||
from . import mlir_sparse_compiler
|
||||
|
||||
# Type aliases for type annotation.
|
||||
_SupportFunc = Callable[..., None]
|
||||
|
@ -40,6 +41,13 @@ def _get_support_lib_name() -> str:
|
|||
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
|
||||
"""Gets the MLIR sparse compiler with default setting."""
|
||||
return mlir_sparse_compiler.SparseCompiler(
|
||||
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
|
||||
|
||||
|
||||
def _record_support_funcs(
|
||||
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
|
||||
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
|
||||
|
@ -184,10 +192,7 @@ def compile_and_build_engine(
|
|||
A JIT execution engine for the MLIR module.
|
||||
|
||||
"""
|
||||
pipeline = f"sparse-compiler"
|
||||
PassManager.parse(pipeline).run(module)
|
||||
return execution_engine.ExecutionEngine(
|
||||
module, opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
|
||||
return _get_sparse_compiler().compile_and_jit(module)
|
||||
|
||||
|
||||
class _SparseTensorDescriptor(ctypes.Structure):
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# This file contains the sparse compiler class. It is copied from
|
||||
# test/Integration/Dialect/SparseTensor/python/ until we have a better
|
||||
# solution.
|
||||
|
||||
from mlir import all_passes_registration
|
||||
from mlir import execution_engine
|
||||
from mlir import ir
|
||||
from mlir import passmanager
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
class SparseCompiler:
|
||||
"""Sparse compiler class for compiling and building MLIR modules."""
|
||||
|
||||
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
|
||||
pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
|
||||
self.pipeline = pipeline
|
||||
self.opt_level = opt_level
|
||||
self.shared_libs = shared_libs
|
||||
|
||||
def __call__(self, module: ir.Module):
|
||||
"""Convenience application method."""
|
||||
self.compile(module)
|
||||
|
||||
def compile(self, module: ir.Module):
|
||||
"""Compiles the module by invoking the sparse copmiler pipeline."""
|
||||
passmanager.PassManager.parse(self.pipeline).run(module)
|
||||
|
||||
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
|
||||
"""Wraps the module in a JIT execution engine."""
|
||||
return execution_engine.ExecutionEngine(
|
||||
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
|
||||
|
||||
def compile_and_jit(self,
|
||||
module: ir.Module) -> execution_engine.ExecutionEngine:
|
||||
"""Compiles and jits the module."""
|
||||
self.compile(module)
|
||||
return self.jit(module)
|
Loading…
Reference in New Issue