forked from OSchip/llvm-project
[mlir][sparse][taco] Use the SparseCompiler from python/tools.
Copy the implementation of SparseCompiler from python/tools to taco/tools until we have a common place to install it. Modify TACO to use this SparseCompiler for compilation and jitting. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D123696
This commit is contained in:
parent
de026aeb8e
commit
cb6f8d77a2
|
@ -16,7 +16,8 @@ from mlir import execution_engine
|
||||||
from mlir import ir
|
from mlir import ir
|
||||||
from mlir import runtime
|
from mlir import runtime
|
||||||
from mlir.dialects import sparse_tensor
|
from mlir.dialects import sparse_tensor
|
||||||
from mlir.passmanager import PassManager
|
|
||||||
|
from . import mlir_sparse_compiler
|
||||||
|
|
||||||
# Type aliases for type annotation.
|
# Type aliases for type annotation.
|
||||||
_SupportFunc = Callable[..., None]
|
_SupportFunc = Callable[..., None]
|
||||||
|
@ -40,6 +41,13 @@ def _get_support_lib_name() -> str:
|
||||||
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
|
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
|
||||||
|
"""Gets the MLIR sparse compiler with default setting."""
|
||||||
|
return mlir_sparse_compiler.SparseCompiler(
|
||||||
|
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
|
||||||
|
|
||||||
|
|
||||||
def _record_support_funcs(
|
def _record_support_funcs(
|
||||||
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
|
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
|
||||||
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
|
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
|
||||||
|
@ -184,10 +192,7 @@ def compile_and_build_engine(
|
||||||
A JIT execution engine for the MLIR module.
|
A JIT execution engine for the MLIR module.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
pipeline = f"sparse-compiler"
|
return _get_sparse_compiler().compile_and_jit(module)
|
||||||
PassManager.parse(pipeline).run(module)
|
|
||||||
return execution_engine.ExecutionEngine(
|
|
||||||
module, opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
|
|
||||||
|
|
||||||
|
|
||||||
class _SparseTensorDescriptor(ctypes.Structure):
|
class _SparseTensorDescriptor(ctypes.Structure):
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
# This file contains the sparse compiler class. It is copied from
|
||||||
|
# test/Integration/Dialect/SparseTensor/python/ until we have a better
|
||||||
|
# solution.
|
||||||
|
|
||||||
|
from mlir import all_passes_registration
|
||||||
|
from mlir import execution_engine
|
||||||
|
from mlir import ir
|
||||||
|
from mlir import passmanager
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class SparseCompiler:
|
||||||
|
"""Sparse compiler class for compiling and building MLIR modules."""
|
||||||
|
|
||||||
|
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
|
||||||
|
pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
|
||||||
|
self.pipeline = pipeline
|
||||||
|
self.opt_level = opt_level
|
||||||
|
self.shared_libs = shared_libs
|
||||||
|
|
||||||
|
def __call__(self, module: ir.Module):
|
||||||
|
"""Convenience application method."""
|
||||||
|
self.compile(module)
|
||||||
|
|
||||||
|
def compile(self, module: ir.Module):
|
||||||
|
"""Compiles the module by invoking the sparse copmiler pipeline."""
|
||||||
|
passmanager.PassManager.parse(self.pipeline).run(module)
|
||||||
|
|
||||||
|
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
|
||||||
|
"""Wraps the module in a JIT execution engine."""
|
||||||
|
return execution_engine.ExecutionEngine(
|
||||||
|
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
|
||||||
|
|
||||||
|
def compile_and_jit(self,
|
||||||
|
module: ir.Module) -> execution_engine.ExecutionEngine:
|
||||||
|
"""Compiles and jits the module."""
|
||||||
|
self.compile(module)
|
||||||
|
return self.jit(module)
|
Loading…
Reference in New Issue