[mlir][taco] Add a utility to create an MLIR sparse tensor from a file.

Move the functions that retrieve the supporting C library, compile an MLIR
module and build a JIT execution engine to mlir_pytaco_utils.

Add a function to create an MLIR sparse tensor from a file and return a pointer
to the MLIR sparse tensor as well as the shape of the sparse tensor.

Add unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D118496
This commit is contained in:
Bixia Zheng 2022-01-28 10:56:50 -08:00
parent 46add4901f
commit ae7ee655a9
3 changed files with 289 additions and 62 deletions

View File

@ -30,8 +30,6 @@ import os
import threading
# Import MLIR related modules.
from mlir import all_passes_registration # Register MLIR compiler passes.
from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import arith
@ -40,7 +38,6 @@ from mlir.dialects import linalg
from mlir.dialects import std
from mlir.dialects import sparse_tensor
from mlir.dialects.linalg.opdsl import lang
from mlir.passmanager import PassManager
from . import mlir_pytaco_utils as utils
@ -51,13 +48,6 @@ _TACO_TENSOR_PREFIX = "A"
# Bitwidths for pointers and indices.
_POINTER_BIT_WIDTH = 0
_INDEX_BIT_WIDTH = 0
# The name for the environment variable that provides the full path for the
# supporting library.
_SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
# The default supporting library if the environment variable is not provided.
_DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
# The JIT compiler optimization level.
_OPT_LEVEL = 2
# The entry point to the JIT compiled program.
_ENTRY_NAME = "main"
@ -134,33 +124,6 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
return dtype_to_irtype[dtype.kind]
def _compile_mlir(module: ir.Module) -> ir.Module:
"""Compiles an MLIR module and returns the compiled module."""
# TODO: Replace this with a pipeline implemented for
# https://github.com/llvm/llvm-project/issues/51751.
pipeline = (
f"sparsification,"
f"sparse-tensor-conversion,"
f"builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
f"convert-scf-to-std,"
f"func-bufferize,"
f"arith-bufferize,"
f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
f"lower-affine,"
f"convert-memref-to-llvm,"
f"convert-std-to-llvm,"
f"reconcile-unrealized-casts")
PassManager.parse(pipeline).run(module)
return module
@functools.lru_cache()
def _get_support_lib_name() -> str:
"""Returns the string for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
"""Returns the ctype pointer for the given numpy array."""
return ctypes.pointer(
@ -900,8 +863,7 @@ class Tensor:
shape = np.array(self._shape, np.int64)
indices = np.array(self._coords, np.int64)
values = np.array(self._values, self._dtype.value)
ptr = utils.coo_tensor_to_sparse_tensor(_get_support_lib_name(), shape,
values, indices)
ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices)
return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
def get_coordinates_and_values(
@ -1316,18 +1278,12 @@ class IndexExpr(abc.ABC):
input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
support_lib = _get_support_lib_name()
# Build and compile the module to produce the execution engine.
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
self._emit_assignment(module, dst, dst_indices, expr_to_info,
input_accesses)
compiled_module = _compile_mlir(module)
# We currently rely on an environment to pass in the full path of a
# supporting library for the execution engine.
engine = execution_engine.ExecutionEngine(
compiled_module, opt_level=_OPT_LEVEL, shared_libs=[support_lib])
engine = utils.compile_and_build_engine(module)
# Gather the pointers for the input buffers.
input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
@ -1351,7 +1307,6 @@ class IndexExpr(abc.ABC):
# Check and return the sparse tensor output.
rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
support_lib,
ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p),
np.float64,
)

View File

@ -4,21 +4,47 @@
# This file contains the utilities to process sparse tensor outputs.
from typing import Tuple
from typing import Sequence, Tuple
import ctypes
import functools
import numpy as np
import os
# Import MLIR related modules.
from mlir import all_passes_registration # Register MLIR compiler passes.
from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import sparse_tensor
from mlir.passmanager import PassManager
# The name for the environment variable that provides the full path for the
# supporting library.
_SUPPORTLIB_ENV_VAR = "SUPPORTLIB"
# The default supporting library if the environment variable is not provided.
_DEFAULT_SUPPORTLIB = "libmlir_c_runner_utils.so"
# The JIT compiler optimization level.
_OPT_LEVEL = 2
# The entry point to the JIT compiled program.
_ENTRY_NAME = "main"
@functools.lru_cache()
def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
"""Loads and returns the requested C shared library.
def _get_support_lib_name() -> str:
"""Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
Args:
lib_name: A string representing the C shared library.
@functools.lru_cache()
def _get_c_shared_lib() -> ctypes.CDLL:
"""Loads the supporting C shared library with the needed routines.
The name of the supporting C shared library is either provided by an
an environment variable or a default value.
Returns:
The C shared library.
The supporting C shared library.
Raises:
OSError: If there is any problem in loading the shared library.
@ -26,7 +52,7 @@ def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(lib_name)
c_lib = ctypes.CDLL(_get_support_lib_name())
try:
c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p
@ -44,14 +70,12 @@ def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
def sparse_tensor_to_coo_tensor(
lib_name: str,
sparse_tensor: ctypes.c_void_p,
dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.
Args:
lib_name: A string for the supporting C shared library.
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
@ -69,7 +93,7 @@ def sparse_tensor_to_coo_tensor(
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
c_lib = _get_c_shared_lib(lib_name)
c_lib = _get_c_shared_lib()
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
@ -84,16 +108,14 @@ def sparse_tensor_to_coo_tensor(
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value])
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank, nse, shape, values, indices
return rank.value, nse.value, shape, values, indices
def coo_tensor_to_sparse_tensor(lib_name: str, np_shape: np.ndarray,
np_values: np.ndarray,
def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
np_indices: np.ndarray) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
Args:
lib_name: A string for the supporting C shared library.
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
@ -115,7 +137,136 @@ def coo_tensor_to_sparse_tensor(lib_name: str, np_shape: np.ndarray,
ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
c_lib = _get_c_shared_lib(lib_name)
c_lib = _get_c_shared_lib()
ptr = c_lib.convertToMLIRSparseTensor(rank, nse, shape, values, indices)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensor"
return ptr
def compile_and_build_engine(
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles an MLIR module and builds a JIT execution engine.
Args:
module: The MLIR module.
Returns:
A JIT execution engine for the MLIR module.
"""
pipeline = (
f"sparsification,"
f"sparse-tensor-conversion,"
f"builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
f"convert-scf-to-std,"
f"func-bufferize,"
f"arith-bufferize,"
f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
f"lower-affine,"
f"convert-memref-to-llvm,"
f"convert-std-to-llvm,"
f"reconcile-unrealized-casts")
PassManager.parse(pipeline).run(module)
return execution_engine.ExecutionEngine(
module, opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
class _SparseTensorDescriptor(ctypes.Structure):
"""A C structure for an MLIR sparse tensor."""
_fields_ = [
# A pointer for the MLIR sparse tensor storage.
("storage", ctypes.POINTER(ctypes.c_ulonglong)),
# An MLIR MemRef descriptor for the shape of the sparse tensor.
("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
]
def _output_one_dim(dim: int, rank: int, shape: str) -> str:
"""Produces the MLIR text code to output the size for the given dimension."""
return f"""
%c{dim} = arith.constant {dim} : index
%d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}xf64, #enc>
memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
"""
# TODO: With better support from MLIR, we may improve the current implementation
# by doing the following:
# (1) Use Python code to generate the kernel instead of doing MLIR text code
# stitching.
# (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
# when tensor.dim supports non-constant dimension value.
def _get_create_sparse_tensor_kernel(
sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> str:
"""Creates an MLIR text kernel to contruct a sparse tensor from a file.
The kernel returns a _SparseTensorDescriptor structure.
"""
rank = len(sparsity_codes)
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
# Get the MLIR text code to write the dimension sizes to the output buffer.
output_dims = "\n".join(
map(lambda d: _output_one_dim(d, rank, shape), range(rank)))
# Return the MLIR text kernel.
return f"""
!Ptr = type !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
dimLevelType = [ {sparsity} ]
}}>
func @{_ENTRY_NAME}(%filename: !Ptr) -> (tensor<{shape}xf64, #enc>, memref<{rank}xindex>)
attributes {{ llvm.emit_c_interface }} {{
%t = sparse_tensor.new %filename : !Ptr to tensor<{shape}xf64, #enc>
%b = memref.alloc() : memref<{rank}xindex>
{output_dims}
return %t, %b : tensor<{shape}xf64, #enc>, memref<{rank}xindex>
}}"""
def create_sparse_tensor(
filename: str, sparsity: Sequence[sparse_tensor.DimLevelType]
) -> Tuple[ctypes.c_void_p, np.ndarray]:
"""Creates an MLIR sparse tensor from the input file.
Args:
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
Returns:
A Tuple containing the following values:
storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
shape: A 1D numpy array of integers, for the shape of the tensor.
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_create_sparse_tensor_kernel(sparsity)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
# A sparse tensor descriptor to receive the kernel result.
c_tensor_desc = _SparseTensorDescriptor()
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
arg_pointers = [
ctypes.byref(ctypes.pointer(c_tensor_desc)),
ctypes.byref(c_filename)
]
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape

View File

@ -0,0 +1,121 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
from typing import Sequence
import dataclasses
import numpy as np
import os
import sys
import tempfile
from mlir.dialects import sparse_tensor
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco
from tools import mlir_pytaco_utils as pytaco_utils
# Define the aliases to shorten the code.
_COMPRESSED = mlir_pytaco.ModeFormat.COMPRESSED
_DENSE = mlir_pytaco.ModeFormat.DENSE
def _to_string(s: Sequence[int]) -> str:
"""Converts a sequence of integer to a space separated value string."""
return " ".join(map(lambda e: str(e), s))
def _add_one(s: Sequence[int]) -> Sequence[int]:
"""Adds one to each element in the sequence of integer."""
return [i + 1 for i in s]
@dataclasses.dataclass(frozen=True)
class _SparseTensorCOO:
"""Values for a COO-flavored format sparse tensor.
Attributes:
rank: An integer rank for the tensor.
nse: An integer for the number of non-zero values.
shape: A sequence of integer for the dimension size.
values: A sequence of float for the non-zero values of the tensor.
indices: A sequence of coordinate, each coordinate is a sequence of integer.
"""
rank: int
nse: int
shape: Sequence[int]
values: Sequence[float]
indices: Sequence[Sequence[int]]
def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
"""Converts a sparse tensor COO-flavored values to TNS text format."""
# The coo_value_str contains one line for each (coordinate value) pair.
# Indices are 1-based in TNS text format but 0-based in MLIR.
coo_value_str = "\n".join(
map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
range(t.nse)))
# Returns the TNS text format representation for the tensor.
return f"""{t.rank} {t.nse}
{_to_string(t.shape)}
{coo_value_str}
"""
def _implement_read_tns_test(
t: _SparseTensorCOO,
sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
tns_data = _coo_values_to_tns_format(t)
# Write sparse tensor data to a file.
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
with open(file_name, "w") as file:
file.write(tns_data)
# Read the data from the file and construct an MLIR sparse tensor.
sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
file_name, sparsity_codes)
passed = 0
# Verify the output shape for the tensor.
if np.allclose(o_shape, t.shape):
passed += 1
# Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
# values and verify the values.
o_rank, o_nse, o_shape, o_values, o_indices = (
pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
if o_rank == t.rank and o_nse == t.nse and np.allclose(
o_shape, t.shape) and np.allclose(o_values, t.values) and np.allclose(
o_indices, t.indices):
passed += 1
return passed
# A 2D sparse tensor data in COO-flavored format.
_rank = 2
_nse = 3
_shape = [4, 5]
_values = [3.0, 2.0, 4.0]
_indices = [[0, 4], [1, 0], [3, 1]]
_t = _SparseTensorCOO(_rank, _nse, _shape, _values, _indices)
_s = [_COMPRESSED, _COMPRESSED]
# CHECK: PASSED 2D: 2
print("PASSED 2D: ", _implement_read_tns_test(_t, _s))
# A 3D sparse tensor data in COO-flavored format.
_rank = 3
_nse = 3
_shape = [2, 5, 4]
_values = [3.0, 2.0, 4.0]
_indices = [[0, 4, 3], [1, 3, 0], [1, 3, 1]]
_t = _SparseTensorCOO(_rank, _nse, _shape, _values, _indices)
_s = [_DENSE, _COMPRESSED, _COMPRESSED]
# CHECK: PASSED 3D: 2
print("PASSED 3D: ", _implement_read_tns_test(_t, _s))