[mlir][taco] Add a utility to create an MLIR sparse tensor from a file.

Move the functions that retrieve the supporting C library, compile an MLIR
module and build a JIT execution engine to mlir_pytaco_utils.

Add a function to create an MLIR sparse tensor from a file and return a pointer
to the MLIR sparse tensor as well as the shape of the sparse tensor.

Add unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D118496
This commit is contained in:
Bixia Zheng 2022-01-28 10:56:50 -08:00
parent 46add4901f
commit ae7ee655a9
3 changed files with 289 additions and 62 deletions

View File

@ -30,8 +30,6 @@ import os
import threading import threading
# Import MLIR related modules. # Import MLIR related modules.
from mlir import all_passes_registration # Register MLIR compiler passes.
from mlir import execution_engine
from mlir import ir from mlir import ir
from mlir import runtime from mlir import runtime
from mlir.dialects import arith from mlir.dialects import arith
@ -40,7 +38,6 @@ from mlir.dialects import linalg
from mlir.dialects import std from mlir.dialects import std
from mlir.dialects import sparse_tensor from mlir.dialects import sparse_tensor
from mlir.dialects.linalg.opdsl import lang from mlir.dialects.linalg.opdsl import lang
from mlir.passmanager import PassManager
from . import mlir_pytaco_utils as utils from . import mlir_pytaco_utils as utils
@ -51,13 +48,6 @@ _TACO_TENSOR_PREFIX = "A"
# Bitwidths for pointers and indices. # Bitwidths for pointers and indices.
_POINTER_BIT_WIDTH = 0 _POINTER_BIT_WIDTH = 0
_INDEX_BIT_WIDTH = 0 _INDEX_BIT_WIDTH = 0
# The name for the environment variable that provides the full path for the
# supporting library. Read by _get_support_lib_name().
_SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
# The default supporting library if the environment variable is not provided.
# This is a bare file name, not a full path.
_DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
# The JIT compiler optimization level, passed as `opt_level` when the
# execution engine is constructed.
_OPT_LEVEL = 2
# The entry point to the JIT compiled program. # The entry point to the JIT compiled program.
_ENTRY_NAME = "main" _ENTRY_NAME = "main"
@ -134,33 +124,6 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
return dtype_to_irtype[dtype.kind] return dtype_to_irtype[dtype.kind]
def _compile_mlir(module: ir.Module) -> ir.Module:
    """Runs the sparse compiler pass pipeline on an MLIR module.

    Args:
      module: The MLIR module to compile.

    Returns:
      The same module object that was passed in, after the pass pipeline has
      been run on it.
    """
    # TODO: Replace this with a pipeline implemented for
    # https://github.com/llvm/llvm-project/issues/51751.
    #
    # The passes are listed one per entry and joined with commas, which is the
    # separator the textual pipeline syntax uses between passes.
    passes = [
        "sparsification",
        "sparse-tensor-conversion",
        "builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)",
        "convert-scf-to-std",
        "func-bufferize",
        "arith-bufferize",
        "builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)",
        "convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1}",
        "lower-affine",
        "convert-memref-to-llvm",
        "convert-std-to-llvm",
        "reconcile-unrealized-casts",
    ]
    PassManager.parse(",".join(passes)).run(module)
    return module
@functools.lru_cache()
def _get_support_lib_name() -> str:
    """Returns the name of the supporting C shared library.

    The value of the environment variable named by _SUPPORTLIB_ENV_VAR is used
    when it is set; otherwise _DEFAULT_SUPPORTLIB is returned. The result is
    cached, so the environment is only consulted on the first call.
    """
    lib_name = os.environ.get(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
    return lib_name
def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer: def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
"""Returns the ctype pointer for the given numpy array.""" """Returns the ctype pointer for the given numpy array."""
return ctypes.pointer( return ctypes.pointer(
@ -900,8 +863,7 @@ class Tensor:
shape = np.array(self._shape, np.int64) shape = np.array(self._shape, np.int64)
indices = np.array(self._coords, np.int64) indices = np.array(self._coords, np.int64)
values = np.array(self._values, self._dtype.value) values = np.array(self._values, self._dtype.value)
ptr = utils.coo_tensor_to_sparse_tensor(_get_support_lib_name(), shape, ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices)
values, indices)
return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p)) return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
def get_coordinates_and_values( def get_coordinates_and_values(
@ -1316,18 +1278,12 @@ class IndexExpr(abc.ABC):
input_accesses = [] input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,)) self._visit(_gather_input_accesses_index_vars, (input_accesses,))
support_lib = _get_support_lib_name()
# Build and compile the module to produce the execution engine. # Build and compile the module to produce the execution engine.
with ir.Context(), ir.Location.unknown(): with ir.Context(), ir.Location.unknown():
module = ir.Module.create() module = ir.Module.create()
self._emit_assignment(module, dst, dst_indices, expr_to_info, self._emit_assignment(module, dst, dst_indices, expr_to_info,
input_accesses) input_accesses)
compiled_module = _compile_mlir(module) engine = utils.compile_and_build_engine(module)
# We currently rely on an environment to pass in the full path of a
# supporting library for the execution engine.
engine = execution_engine.ExecutionEngine(
compiled_module, opt_level=_OPT_LEVEL, shared_libs=[support_lib])
# Gather the pointers for the input buffers. # Gather the pointers for the input buffers.
input_pointers = [a.tensor.ctype_pointer() for a in input_accesses] input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
@ -1351,7 +1307,6 @@ class IndexExpr(abc.ABC):
# Check and return the sparse tensor output. # Check and return the sparse tensor output.
rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor( rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
support_lib,
ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p), ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p),
np.float64, np.float64,
) )

View File

@ -4,21 +4,47 @@
# This file contains the utilities to process sparse tensor outputs. # This file contains the utilities to process sparse tensor outputs.
from typing import Tuple from typing import Sequence, Tuple
import ctypes import ctypes
import functools import functools
import numpy as np import numpy as np
import os
# Import MLIR related modules.
from mlir import all_passes_registration # Register MLIR compiler passes.
from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import sparse_tensor
from mlir.passmanager import PassManager
# The name for the environment variable that provides the full path for the
# supporting library. Read by _get_support_lib_name().
_SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
# The default supporting library if the environment variable is not provided.
# This is a bare file name, not a full path.
_DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
# The JIT compiler optimization level, passed as `opt_level` when the
# execution engine is constructed in compile_and_build_engine().
_OPT_LEVEL = 2
# The entry point to the JIT compiled program. Used both as the function name
# in the generated kernel text and as the symbol invoked on the engine.
_ENTRY_NAME = "main"
@functools.lru_cache() @functools.lru_cache()
def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL: def _get_support_lib_name() -> str:
"""Loads and returns the requested C shared library. """Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
Args:
lib_name: A string representing the C shared library. @functools.lru_cache()
def _get_c_shared_lib() -> ctypes.CDLL:
"""Loads the supporting C shared library with the needed routines.
The name of the supporting C shared library is either provided by an
environment variable or a default value.
Returns: Returns:
The C shared library. The supporting C shared library.
Raises: Raises:
OSError: If there is any problem in loading the shared library. OSError: If there is any problem in loading the shared library.
@ -26,7 +52,7 @@ def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
""" """
# This raises OSError exception if there is any problem in loading the shared # This raises OSError exception if there is any problem in loading the shared
# library. # library.
c_lib = ctypes.CDLL(lib_name) c_lib = ctypes.CDLL(_get_support_lib_name())
try: try:
c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p
@ -44,14 +70,12 @@ def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
def sparse_tensor_to_coo_tensor( def sparse_tensor_to_coo_tensor(
lib_name: str,
sparse_tensor: ctypes.c_void_p, sparse_tensor: ctypes.c_void_p,
dtype: np.dtype, dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
"""Converts an MLIR sparse tensor to a COO-flavored format tensor. """Converts an MLIR sparse tensor to a COO-flavored format tensor.
Args: Args:
lib_name: A string for the supporting C shared library.
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor. sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements. dtype: The numpy data type for the tensor elements.
@ -69,7 +93,7 @@ def sparse_tensor_to_coo_tensor(
OSError: If there is any problem in loading the shared library. OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines. ValueError: If the shared library doesn't contain the needed routines.
""" """
c_lib = _get_c_shared_lib(lib_name) c_lib = _get_c_shared_lib()
rank = ctypes.c_ulonglong(0) rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0) nse = ctypes.c_ulonglong(0)
@ -84,16 +108,14 @@ def sparse_tensor_to_coo_tensor(
shape = np.ctypeslib.as_array(shape, shape=[rank.value]) shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value]) values = np.ctypeslib.as_array(values, shape=[nse.value])
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank, nse, shape, values, indices return rank.value, nse.value, shape, values, indices
def coo_tensor_to_sparse_tensor(lib_name: str, np_shape: np.ndarray, def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
np_values: np.ndarray,
np_indices: np.ndarray) -> int: np_indices: np.ndarray) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor. """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
Args: Args:
lib_name: A string for the supporting C shared library.
np_shape: A 1D numpy array of integers, for the shape of the tensor. np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor. np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the np_indices: A 2D numpy array of integers, representing the indices for the
@ -115,7 +137,136 @@ def coo_tensor_to_sparse_tensor(lib_name: str, np_shape: np.ndarray,
ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype))) ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
c_lib = _get_c_shared_lib(lib_name) c_lib = _get_c_shared_lib()
ptr = c_lib.convertToMLIRSparseTensor(rank, nse, shape, values, indices) ptr = c_lib.convertToMLIRSparseTensor(rank, nse, shape, values, indices)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensor" assert ptr is not None, "Problem with calling convertToMLIRSparseTensor"
return ptr return ptr
def compile_and_build_engine(
        module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles an MLIR module and builds a JIT execution engine.

    The pass pipeline is run on `module` itself, and that same module object
    is then handed to the execution engine.

    Args:
      module: The MLIR module.

    Returns:
      A JIT execution engine for the MLIR module.
    """
    # One entry per pass; the textual pipeline syntax separates passes with
    # commas.
    passes = [
        "sparsification",
        "sparse-tensor-conversion",
        "builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)",
        "convert-scf-to-std",
        "func-bufferize",
        "arith-bufferize",
        "builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)",
        "convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1}",
        "lower-affine",
        "convert-memref-to-llvm",
        "convert-std-to-llvm",
        "reconcile-unrealized-casts",
    ]
    PassManager.parse(",".join(passes)).run(module)

    # The supporting library is made available to the JIT-compiled code via
    # the engine's shared library list.
    support_libs = [_get_support_lib_name()]
    return execution_engine.ExecutionEngine(
        module, opt_level=_OPT_LEVEL, shared_libs=support_libs)
class _SparseTensorDescriptor(ctypes.Structure):
    """A C structure for an MLIR sparse tensor.

    create_sparse_tensor() passes an instance of this structure to the JIT
    compiled kernel as the buffer that receives the kernel's two results. The
    field order mirrors the order of the kernel's return values (the tensor
    first, then the rank-1 buffer of dimension sizes), so the fields must not
    be reordered.
    """
    _fields_ = [
        # A pointer for the MLIR sparse tensor storage.
        ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
        # An MLIR MemRef descriptor for the shape of the sparse tensor: a
        # rank-1 memref whose elements are read as unsigned 64-bit integers.
        ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
    ]
def _output_one_dim(dim: int, rank: int, shape: str) -> str:
    """Produces the MLIR text that stores one dimension size in the buffer.

    The generated snippet refers to a tensor value named %t, an output buffer
    named %b and an encoding alias named #enc, all of which must be defined by
    the surrounding kernel text.

    Args:
      dim: The dimension whose size is written out.
      rank: The rank of the tensor, used as the length of the output buffer.
      shape: The shape string of the tensor type, without the element type.

    Returns:
      The MLIR text, starting and ending with a newline.
    """
    tensor_type = f"tensor<{shape}xf64, #enc>"
    buffer_type = f"memref<{rank}xindex>"
    # The empty first and last entries give the leading and trailing newline.
    lines = (
        "",
        f"%c{dim} = arith.constant {dim} : index",
        f"%d{dim} = tensor.dim %t, %c{dim} : {tensor_type}",
        f"memref.store %d{dim}, %b[%c{dim}] : {buffer_type}",
        "",
    )
    return "\n".join(lines)
# TODO: With better support from MLIR, we may improve the current implementation
# by doing the following:
# (1) Use Python code to generate the kernel instead of doing MLIR text code
# stitching.
# (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
# when tensor.dim supports non-constant dimension value.
def _get_create_sparse_tensor_kernel(
        sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> str:
    """Creates an MLIR text kernel to construct a sparse tensor from a file.

    The kernel takes a pointer to the filename, and returns the sparse tensor
    together with a buffer holding the size of every dimension. The caller
    receives these two results in a _SparseTensorDescriptor structure.

    Args:
      sparsity_codes: One code per tensor dimension. A code whose `value` is
        truthy is emitted as "compressed", any other code as "dense".

    Returns:
      The MLIR text for the kernel.
    """
    rank = len(sparsity_codes)

    # Every dimension is dynamic, so the shape string is one "?" per dimension.
    shape = "x".join(["?"] * rank)

    # Convert the encoded sparsity values to a string representation.
    level_names = [
        '"compressed"' if code.value else '"dense"' for code in sparsity_codes
    ]
    sparsity = ", ".join(level_names)

    # The dimension sizes are written with one unrolled snippet per dimension.
    output_dims = "\n".join(
        _output_one_dim(dim, rank, shape) for dim in range(rank))

    # Return the MLIR text kernel.
    return f"""
!Ptr = type !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
dimLevelType = [ {sparsity} ]
}}>
func @{_ENTRY_NAME}(%filename: !Ptr) -> (tensor<{shape}xf64, #enc>, memref<{rank}xindex>)
attributes {{ llvm.emit_c_interface }} {{
%t = sparse_tensor.new %filename : !Ptr to tensor<{shape}xf64, #enc>
%b = memref.alloc() : memref<{rank}xindex>
{output_dims}
return %t, %b : tensor<{shape}xf64, #enc>, memref<{rank}xindex>
}}"""
def create_sparse_tensor(
    filename: str, sparsity: Sequence[sparse_tensor.DimLevelType]
) -> Tuple[ctypes.c_void_p, np.ndarray]:
    """Creates an MLIR sparse tensor from the input file.

    A small kernel is generated as MLIR text, JIT compiled, and invoked with
    the filename; the kernel produces the tensor and its dimension sizes.

    Args:
      filename: A string for the name of the file that contains the tensor data
        in a COO-flavored format.
      sparsity: A sequence of DimLevelType values, one for each dimension of
        the tensor.

    Returns:
      A Tuple containing the following values:
      storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
      shape: A 1D numpy array of integers, for the shape of the tensor.

    Raises:
      OSError: If there is any problem in loading the supporting C shared
        library.
      ValueError: If the shared library doesn't contain the needed routine.
    """
    # Parsing and compiling both happen under an active context and location.
    with ir.Context(), ir.Location.unknown():
        kernel_text = _get_create_sparse_tensor_kernel(sparsity)
        kernel_module = ir.Module.parse(kernel_text)
        engine = compile_and_build_engine(kernel_module)

    # A sparse tensor descriptor to receive the kernel result.
    descriptor = _SparseTensorDescriptor()
    # Convert the filename to a byte stream.
    c_filename = ctypes.c_char_p(filename.encode("utf-8"))

    # The result buffer comes first, followed by the kernel's only input.
    result_arg = ctypes.byref(ctypes.pointer(descriptor))
    filename_arg = ctypes.byref(c_filename)

    # Invoke the execution engine to run the module and return the result.
    engine.invoke(_ENTRY_NAME, result_arg, filename_arg)
    shape = runtime.ranked_memref_to_numpy(ctypes.pointer(descriptor.shape))
    return descriptor.storage, shape

View File

@ -0,0 +1,121 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
from typing import Sequence
import dataclasses
import numpy as np
import os
import sys
import tempfile
from mlir.dialects import sparse_tensor
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco
from tools import mlir_pytaco_utils as pytaco_utils
# Define the aliases to shorten the code. These are the two
# mlir_pytaco.ModeFormat members used to build the per-dimension sparsity
# lists (`_s`) for the test cases at the bottom of this file.
_COMPRESSED = mlir_pytaco.ModeFormat.COMPRESSED
_DENSE = mlir_pytaco.ModeFormat.DENSE
def _to_string(s: Sequence[int]) -> str:
    """Converts a sequence of integers to a space-separated string.

    An empty sequence yields an empty string.
    """
    return " ".join(str(element) for element in s)
def _add_one(s: Sequence[int]) -> Sequence[int]:
    """Returns a new list holding each element of the sequence plus one.

    The input sequence is not modified.
    """
    return [value + 1 for value in s]
@dataclasses.dataclass(frozen=True)
class _SparseTensorCOO:
    """An immutable record of a sparse tensor in a COO-flavored format.

    Entry `i` of `indices` is the coordinate of entry `i` of `values`.

    Attributes:
      rank: An integer rank for the tensor.
      nse: An integer for the number of non-zero values.
      shape: A sequence of integers, one size per dimension.
      values: A sequence of floats for the non-zero values of the tensor.
      indices: A sequence of coordinates; each coordinate is a sequence of
        integers.
    """
    # Scalar description of the tensor.
    rank: int
    nse: int
    # Per-dimension sizes.
    shape: Sequence[int]
    # Parallel sequences: one value and one coordinate per stored entry.
    values: Sequence[float]
    indices: Sequence[Sequence[int]]
def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
    """Converts COO-flavored sparse tensor values to the TNS text format.

    The text starts with a "rank nse" line and a line of dimension sizes,
    followed by one "coordinate value" line for each of the first `t.nse`
    entries, and ends with a newline.
    """
    # Indices are 1-based in TNS text format but 0-based in MLIR.
    entry_lines = [
        _to_string(_add_one(t.indices[k])) + " " + str(t.values[k])
        for k in range(t.nse)
    ]
    coo_value_str = "\n".join(entry_lines)

    # Returns the TNS text format representation for the tensor.
    return f"""{t.rank} {t.nse}
{_to_string(t.shape)}
{coo_value_str}
"""
def _implement_read_tns_test(
        t: _SparseTensorCOO,
        sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
    """Round-trips a tensor through a TNS file and counts the passing checks.

    The tensor is written to a temporary TNS file, read back as an MLIR sparse
    tensor, and then converted to COO-flavored values again.

    NOTE(review): `sparsity_codes` is annotated as DimLevelType values, but the
    calls in this file pass mlir_pytaco.ModeFormat members; confirm which type
    is intended.

    Args:
      t: The expected tensor values.
      sparsity_codes: The per-dimension sparsity passed to the reader.

    Returns:
      The number of checks that passed: one for the shape reported by the
      reader, and one for the COO-flavored values retrieved from the tensor.
    """
    tns_data = _coo_values_to_tns_format(t)

    # Write sparse tensor data to a file.
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        with open(file_name, "w") as tns_file:
            tns_file.write(tns_data)

        # Read the data from the file and construct an MLIR sparse tensor.
        # This happens after the file is closed but before the temporary
        # directory is removed.
        tensor_ptr, o_shape = pytaco_utils.create_sparse_tensor(
            file_name, sparsity_codes)

    passed = 0

    # Verify the output shape for the tensor.
    shape_matches = np.allclose(o_shape, t.shape)
    if shape_matches:
        passed += 1

    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
    # values and verify the values.
    coo = pytaco_utils.sparse_tensor_to_coo_tensor(tensor_ptr, np.float64)
    o_rank, o_nse, o_shape, o_values, o_indices = coo
    coo_matches = (
        o_rank == t.rank and o_nse == t.nse and
        np.allclose(o_shape, t.shape) and
        np.allclose(o_values, t.values) and
        np.allclose(o_indices, t.indices))
    if coo_matches:
        passed += 1

    return passed
# A 2D sparse tensor data in COO-flavored format.
# Three stored entries in a 4x5 tensor; `_indices` holds one 0-based
# coordinate per entry of `_values`.
_rank = 2
_nse = 3
_shape = [4, 5]
_values = [3.0, 2.0, 4.0]
_indices = [[0, 4], [1, 0], [3, 1]]
_t = _SparseTensorCOO(_rank, _nse, _shape, _values, _indices)
# Both dimensions use the compressed format.
_s = [_COMPRESSED, _COMPRESSED]
# The helper returns how many of its two checks passed, so 2 means success.
# CHECK: PASSED 2D: 2
print("PASSED 2D: ", _implement_read_tns_test(_t, _s))
# A 3D sparse tensor data in COO-flavored format.
# The module-level names from the 2D case are rebound here: three stored
# entries in a 2x5x4 tensor.
_rank = 3
_nse = 3
_shape = [2, 5, 4]
_values = [3.0, 2.0, 4.0]
_indices = [[0, 4, 3], [1, 3, 0], [1, 3, 1]]
_t = _SparseTensorCOO(_rank, _nse, _shape, _values, _indices)
# The first dimension is dense; the remaining two are compressed.
_s = [_DENSE, _COMPRESSED, _COMPRESSED]
# CHECK: PASSED 3D: 2
print("PASSED 3D: ", _implement_read_tns_test(_t, _s))