[mlir][sparse][taco] Use the SparseCompiler from python/tools.

Copy the implementation of SparseCompiler from python/tools to taco/tools until we have a common place to install it. Modify TACO to use this SparseCompiler for compilation and jitting.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D123696
This commit is contained in:
Bixia Zheng 2022-04-14 09:41:27 -07:00
parent de026aeb8e
commit cb6f8d77a2
2 changed files with 52 additions and 5 deletions

View File

@ -16,7 +16,8 @@ from mlir import execution_engine
from mlir import ir from mlir import ir
from mlir import runtime from mlir import runtime
from mlir.dialects import sparse_tensor from mlir.dialects import sparse_tensor
from mlir.passmanager import PassManager
from . import mlir_sparse_compiler
# Type aliases for type annotation. # Type aliases for type annotation.
_SupportFunc = Callable[..., None] _SupportFunc = Callable[..., None]
@ -40,6 +41,13 @@ def _get_support_lib_name() -> str:
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB) return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
@functools.lru_cache()
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
"""Gets the MLIR sparse compiler with default setting."""
return mlir_sparse_compiler.SparseCompiler(
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
def _record_support_funcs( def _record_support_funcs(
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc, ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None: ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
@ -184,10 +192,7 @@ def compile_and_build_engine(
A JIT execution engine for the MLIR module. A JIT execution engine for the MLIR module.
""" """
pipeline = f"sparse-compiler" return _get_sparse_compiler().compile_and_jit(module)
PassManager.parse(pipeline).run(module)
return execution_engine.ExecutionEngine(
module, opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
class _SparseTensorDescriptor(ctypes.Structure): class _SparseTensorDescriptor(ctypes.Structure):

View File

@ -0,0 +1,42 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This file contains the sparse compiler class. It is copied from
# test/Integration/Dialect/SparseTensor/python/ until we have a better
# solution.
from mlir import all_passes_registration
from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence
class SparseCompiler:
"""Sparse compiler class for compiling and building MLIR modules."""
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
def compile_and_jit(self,
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)