Upstream MLIR PyTACO implementation.

Add TACO tests to test/Integration/Dialect/SparseTensor/taco. Add the MLIR
PyTACO implementation as tools under the directory.

Reviewed By: aartbik, mehdi_amini

Differential Revision: https://reviews.llvm.org/D117260
This commit is contained in:
Bixia Zheng 2022-01-13 16:27:28 -08:00
parent cab9616938
commit b7fd91c84b
14 changed files with 2379 additions and 0 deletions

View File

@ -1,3 +1,4 @@
numpy
pybind11>=2.8.0
PyYAML
dataclasses

View File

@ -0,0 +1,27 @@
# MLIR-PyTACO: Implementing PyTACO with MLIR
TACO (http://tensor-compiler.org/) is a tensor algebra compiler. TACO defines
PyTACO, a domain specific language in Python, for writing tensor algebra
applications.
This directory contains the implementation of PyTACO using MLIR. In particular,
we implement a Python layer that accepts the PyTACO language, generates MLIR
linalg.generic OPs with sparse tensor annotation to represent the tensor
computation, and invokes the MLIR sparse tensor code generator
(https://mlir.llvm.org/docs/Dialects/SparseTensorOps/) as well as other MLIR
compilation passes to generate an executable. Then, we invoke the MLIR execution
engine to execute the program and pass the result back to the Python layer.
As can be seen from the tests in this directory, in order to port a PyTACO
program to MLIR-PyTACO, we basically only need to replace this line that imports
PyTACO:
```python
import pytaco as pt
```
with this line to import MLIR-PyTACO:
```python
from tools import mlir_pytaco_api as pt
```

View File

@ -0,0 +1,50 @@
1 1 12
1 2 12
1 3 12
1 4 12
1 5 12
1 6 12
1 7 12
1 8 12
1 9 12
1 10 12
1 11 12
1 12 12
1 13 12
1 14 12
1 15 12
1 16 12
1 17 12
1 18 12
1 19 12
1 20 12
1 21 12
1 22 12
1 23 12
1 24 12
1 25 12
2 1 6
2 2 6
2 3 6
2 4 6
2 5 6
2 6 6
2 7 6
2 8 6
2 9 6
2 10 6
2 11 6
2 12 6
2 13 6
2 14 6
2 15 6
2 16 6
2 17 6
2 18 6
2 19 6
2 20 6
2 21 6
2 22 6
2 23 6
2 24 6
2 25 6

View File

@ -0,0 +1,4 @@
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
1 37102
2 -20.4138
3 804927

View File

@ -0,0 +1,5 @@
1 1 1 1.0
1 2 2 2.0
1 3 4 3.0
2 1 1 1.0
2 4 3 2.0

View File

@ -0,0 +1,11 @@
%%MatrixMarket matrix coordinate real symmetric
%-------------------------------------------------------------------------------
% To download a matrix for a real world application
% https://math.nist.gov/MatrixMarket/
%-------------------------------------------------------------------------------
3 3 5
1 1 37423.0879671
2 1 -22.4050781162
3 1 -300.654980157
3 2 -.00869762944058
3 3 805225.750212

View File

@ -0,0 +1,53 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
import numpy as np
import os
import sys
import tempfile
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt
###### This PyTACO part is taken from the TACO open-source project. ######
# See http://tensor-compiler.org/docs/data_analytics/index.html.
compressed = pt.compressed
dense = pt.dense
# Define formats for storing the sparse tensor and dense matrices.
csf = pt.format([compressed, compressed, compressed])
rm = pt.format([dense, dense])
# Load a sparse three-dimensional tensor from file (stored in the FROSTT
# format) and store it as a compressed sparse fiber tensor. We use a small
# tensor for the purpose of testing. To run the program using the data from
# the real application, please download the data from:
# http://frostt.io/tensors/nell-2/
B = pt.read(os.path.join(_SCRIPT_PATH, "data/nell-2.tns"), csf)
# These two lines have been modified from the original program to use static
# data to support result comparison.
C = pt.from_array(np.full((B.shape[1], 25), 1, dtype=np.float64))
D = pt.from_array(np.full((B.shape[2], 25), 2, dtype=np.float64))
# Declare the result to be a dense matrix.
A = pt.tensor([B.shape[0], 25], rm)
# Declare index vars.
i, j, k, l = pt.get_index_vars(4)
# Define the MTTKRP computation.
A[i, j] = B[i, k, l] * D[l, j] * C[k, j]
##########################################################################
# CHECK: Compare result True
# Perform the MTTKRP computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
actual_file = os.path.join(test_dir, "A.tns")
pt.write(actual_file, A)
actual = np.loadtxt(actual_file, np.float64)
expected = np.loadtxt(
os.path.join(_SCRIPT_PATH, "data/gold_A.tns"), np.float64)
print(f"Compare result {np.allclose(actual, expected, rtol=0.01)}")

View File

@ -0,0 +1,54 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
import numpy as np
import os
import sys
import tempfile
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt
###### This PyTACO part is taken from the TACO open-source project. ######
# See http://tensor-compiler.org/docs/scientific_computing/index.html.
compressed = pt.compressed
dense = pt.dense
# Define formats for storing the sparse matrix and dense vectors.
csr = pt.format([dense, compressed])
dv = pt.format([dense])
# Load a sparse matrix stored in the matrix market format) and store it
# as a CSR matrix. The matrix in this test is a reduced version of the data
# downloaded from here:
# https://www.cise.ufl.edu/research/sparse/MM/Boeing/pwtk.tar.gz
# In order to run the program using the matrix above, you can download the
# matrix and replace this path to the actual path to the file.
A = pt.read(os.path.join(_SCRIPT_PATH, "data/pwtk.mtx"), csr)
# These two lines have been modified from the original program to use static
# data to support result comparison.
x = pt.from_array(np.full((A.shape[1],), 1, dtype=np.float64))
z = pt.from_array(np.full((A.shape[0],), 2, dtype=np.float64))
# Declare the result to be a dense vector
y = pt.tensor([A.shape[0]], dv)
# Declare index vars
i, j = pt.get_index_vars(2)
# Define the SpMV computation
y[i] = A[i, j] * x[j] + z[i]
##########################################################################
# CHECK: Compare result True
# Perform the SpMV computation and write the result to file
with tempfile.TemporaryDirectory() as test_dir:
actual_file = os.path.join(test_dir, "y.tns")
pt.write(actual_file, y)
actual = np.loadtxt(actual_file, np.float64)
expected = np.loadtxt(
os.path.join(_SCRIPT_PATH, "data/gold_y.tns"), np.float64)
print(f"Compare result {np.allclose(actual, expected, rtol=0.01)}")

View File

@ -0,0 +1,30 @@
# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
import os
import sys
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt
compressed = pt.compressed
dense = pt.dense
# Ensure that we can run an unmodified PyTACO program with a simple tensor
# algebra expression using tensor index notation, and produce the expected
# result.
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3])
B = pt.tensor([2, 3])
C = pt.tensor([2, 3])
D = pt.tensor([2, 3], dense)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C.insert([0, 1], 5)
C.insert([1, 2], 7)
D[i, j] = A[i, j] + B[i, j] - C[i, j]
# CHECK: [20. 5. 0. 0. 0. 63.]
print(D.to_array().reshape(6))

View File

@ -0,0 +1,2 @@
# Files in this directory are tools, not tests.
config.unsupported = True

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,47 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Supports the PyTACO API with the MLIR-PyTACO implementation.
See http://tensor-compiler.org/ for TACO tensor compiler.
This module exports the MLIR-PyTACO implementation through the language defined
by PyTACO. In particular, it defines the function and type aliases and constants
needed for the PyTACO API to support the execution of PyTACO programs using the
MLIR-PyTACO implementation.
"""
from . import mlir_pytaco
from . import mlir_pytaco_io
# Functions defined by PyTACO API.
get_index_vars = mlir_pytaco.get_index_vars
from_array = mlir_pytaco.Tensor.from_array
read = mlir_pytaco_io.read
write = mlir_pytaco_io.write
# Classes defined by PyTACO API.
dtype = mlir_pytaco.DType
mode_format = mlir_pytaco.ModeFormat
mode_ordering = mlir_pytaco.ModeOrdering
mode_format_pack = mlir_pytaco.ModeFormatPack
format = mlir_pytaco.Format
index_var = mlir_pytaco.IndexVar
tensor = mlir_pytaco.Tensor
index_expression = mlir_pytaco.IndexExpr
access = mlir_pytaco.Access
# Data type constants defined by PyTACO API.
int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32)
float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64)
# Storage format constants defined by the PyTACO API. In PyTACO, each storage
# format constant has two aliasing names.
compressed = mlir_pytaco.ModeFormat.COMPRESSED
Compressed = mlir_pytaco.ModeFormat.COMPRESSED
dense = mlir_pytaco.ModeFormat.DENSE
Dense = mlir_pytaco.ModeFormat.DENSE

View File

@ -0,0 +1,206 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Experimental MLIR-PyTACO with sparse tensor support.
See http://tensor-compiler.org/ for TACO tensor compiler.
This module implements the PyTACO API for writing a tensor to a file or reading
a tensor from a file.
See the following links for Matrix Market Exchange (.mtx) format and FROSTT
(.tns) format:
https://math.nist.gov/MatrixMarket/formats.html
http://frostt.io/tensors/file-formats.html
"""
from typing import List, TextIO
from . import mlir_pytaco
# Define the type aliases so that we can write the implementation here as if
# it were part of mlir_pytaco.py.
Tensor = mlir_pytaco.Tensor
Format = mlir_pytaco.Format
DType = mlir_pytaco.DType
Type = mlir_pytaco.Type
# Constants used in the implementation.
_MTX_FILENAME_SUFFIX = ".mtx"
_TNS_FILENAME_SUFFIX = ".tns"
_MTX_HEAD = "%%MatrixMarket"
_MTX_MATRIX = "matrix"
_MTX_COORDINATE = "coordinate"
_MTX_REAL = "real"
_MTX_SYMMETRY = "symmetric"
_MTX_GENERAL = "general"
_SYMMETRY_FIELD_ID = 4
# The TACO supported header for .mtx has the following five fields:
# . %%MatrixMarket
# . matrix | tensor
# . coordinate | array
# . real
# . symmetric | general
#
# This is what we support currently.
_SUPPORTED_HEADER_FIELDS = ((_MTX_HEAD,), (_MTX_MATRIX,), (_MTX_COORDINATE,),
(_MTX_REAL,), (_MTX_GENERAL, _MTX_SYMMETRY))
_A_SPACE = " "
_MTX_COMMENT = "%"
_TNS_COMMENT = "#"
def _coordinate_from_strings(strings: List[str]) -> List[int]:
""""Return the coordinate represented by the input strings."""
# Coordinates are 1-based in the text file and 0-based in memory.
return [int(s) - 1 for s in strings]
def _read_coordinate_format(file: TextIO, tensor: Tensor,
is_symmetric: bool) -> None:
"""Reads tensor values in coordinate format."""
rank = tensor.order
# Process the data for the tensor.
for line in file:
if not line:
continue
fields = line.split(_A_SPACE)
if rank != len(fields) - 1:
raise ValueError("The format and data have mismatched ranks: "
f"{rank} vs {len(fields)-1}.")
coordinate = _coordinate_from_strings(fields[:-1])
value = float(fields[-1])
tensor.insert(coordinate, value)
if is_symmetric and coordinate[0] != coordinate[-1]:
coordinate.reverse()
tensor.insert(coordinate, value)
def _read_mtx(file: TextIO, fmt: Format) -> Tensor:
"""Inputs tensor from a text file with .mtx format."""
# The first line should have this five fields:
# head tensor-kind format data-type symmetry
fields = file.readline().rstrip("\n").split(_A_SPACE)
tuple_to_str = lambda x: "|".join(x)
if len(fields) != len(_SUPPORTED_HEADER_FIELDS):
raise ValueError(
"Expected first line with theses fields "
f"{' '.join(map(tuple_to_str, _SUPPORTED_HEADER_FIELDS))}: "
f"{' '.join(fields)}")
for i, values in enumerate(_SUPPORTED_HEADER_FIELDS):
if fields[i] not in values:
raise ValueError(f"The {i}th field can only be one of these values "
f"{tuple_to_str(values)}: {fields[i]}")
is_symmetric = (fields[_SYMMETRY_FIELD_ID] == _MTX_SYMMETRY)
# Skip leading empty lines or comment lines.
line = file.readline()
while not line or line[0] == _MTX_COMMENT:
line = file.readline()
# Process the first data line with dimensions and number of non-zero values.
fields = line.split(_A_SPACE)
rank = fmt.rank()
if rank != len(fields) - 1:
raise ValueError("The format and data have mismatched ranks: "
f"{rank} vs {len(fields)-1}.")
shape = fields[:-1]
shape = [int(s) for s in shape]
num_non_zero = float(fields[-1])
# Read the tensor values in coordinate format.
tensor = Tensor(shape, fmt)
_read_coordinate_format(file, tensor, is_symmetric)
return tensor
def _read_tns(file: TextIO, fmt: Format) -> Tensor:
"""Inputs tensor from a text file with .tns format."""
rank = fmt.rank()
coordinates = []
values = []
dtype = DType(Type.FLOAT64)
for line in file:
# Skip empty lines and comment lines.
if not line or line[0] == _TNS_COMMENT:
continue
# Process each line with a coordinate and the value at the coordinate.
fields = line.split(_A_SPACE)
if rank != len(fields) - 1:
raise ValueError("The format and data have mismatched ranks: "
f"{rank} vs {len(fields)-1}.")
coordinates.append(tuple(_coordinate_from_strings(fields[:-1])))
values.append(dtype.value(fields[-1]))
return Tensor.from_coo(coordinates, values, fmt, dtype)
def _write_tns(file: TextIO, tensor: Tensor) -> None:
"""Outputs a tensor to a file using .tns format."""
coords, non_zeros = tensor.get_coordinates_and_values()
assert len(coords) == len(non_zeros)
# Output a coordinate and the corresponding value in a line.
for c, v in zip(coords, non_zeros):
# The coordinates are 1-based in the text file and 0-based in memory.
plus_one_to_str = lambda x: str(x + 1)
file.write(f"{' '.join(map(plus_one_to_str,c))} {v}\n")
def read(filename: str, fmt: Format) -> Tensor:
"""Inputs a tensor from a given file.
The name suffix of the file specifies the format of the input tensor. We
currently only support .mtx format for support sparse tensors.
Args:
filename: A string input filename.
fmt: The storage format of the tensor.
Raises:
ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
instance of Format or fmt is not a sparse tensor.
"""
if (not isinstance(filename, str) or
(not filename.endswith(_MTX_FILENAME_SUFFIX) and
not filename.endswith(_TNS_FILENAME_SUFFIX))):
raise ValueError("Expected string filename ends with "
f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
f"{filename}.")
if not isinstance(fmt, Format) or fmt.is_dense():
raise ValueError(f"Expected a sparse Format object: {fmt}.")
with open(filename, "r") as file:
return (_read_mtx(file, fmt) if filename.endswith(_MTX_FILENAME_SUFFIX) else
_read_tns(file, fmt))
def write(filename: str, tensor: Tensor) -> None:
"""Outputs a tensor to a given file.
The name suffix of the file specifies the format of the output. We currently
only support .tns format.
Args:
filename: A string output filename.
tensor: The tensor to output.
Raises:
ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
"""
if (not isinstance(filename, str) or
not filename.endswith(_TNS_FILENAME_SUFFIX)):
raise ValueError("Expected string filename ends with"
f" {_TNS_FILENAME_SUFFIX}: {filename}.")
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")
with open(filename, "w") as file:
return _write_tns(file, tensor)

View File

@ -0,0 +1,121 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This file contains the utilities to process sparse tensor outputs.
from typing import Tuple
import ctypes
import functools
import numpy as np
@functools.lru_cache()
def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
"""Loads and returns the requested C shared library.
Args:
lib_name: A string representing the C shared library.
Returns:
The C shared library.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(lib_name)
try:
c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p
except Exception as e:
raise ValueError("Missing function convertToMLIRSparseTensor from "
f"the supporting C shared library: {e} ") from e
try:
c_lib.convertFromMLIRSparseTensor.restype = ctypes.c_void_p
except Exception as e:
raise ValueError("Missing function convertFromMLIRSparseTensor from "
f"the C shared library: {e} ") from e
return c_lib
def sparse_tensor_to_coo_tensor(
lib_name: str,
sparse_tensor: ctypes.c_void_p,
dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.
Args:
lib_name: A string for the supporting C shared library.
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Returns:
A tuple that contains the following values for the COO-flavored format
tensor:
rank: An integer for the rank of the tensor.
nse: An interger for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
c_lib = _get_c_shared_lib(lib_name)
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
c_lib.convertFromMLIRSparseTensor(sparse_tensor, ctypes.byref(rank),
ctypes.byref(nse), ctypes.byref(shape),
ctypes.byref(values), ctypes.byref(indices))
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value])
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank, nse, shape, values, indices
def coo_tensor_to_sparse_tensor(lib_name: str, np_shape: np.ndarray,
np_values: np.ndarray,
np_indices: np.ndarray) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
Args:
lib_name: A string for the supporting C shared library.
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Returns:
An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
descriptor.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
rank = ctypes.c_ulonglong(len(np_shape))
nse = ctypes.c_ulonglong(len(np_values))
shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
values = np_values.ctypes.data_as(
ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
c_lib = _get_c_shared_lib(lib_name)
ptr = c_lib.convertToMLIRSparseTensor(rank, nse, shape, values, indices)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensor"
return ptr