[mlir][sparse][taco] Handle tensor copy and trivial reduction expression.

Handle tensor copy, such as A[i, j] = B[i, j]. Also, handle trivial
reduction expression, such as A[i] = B[i, j].

Add unit tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D119867
This commit is contained in:
Bixia Zheng 2022-02-15 10:17:51 -08:00
parent b8438a6975
commit 746c68eafd
2 changed files with 54 additions and 5 deletions

View File

@ -1683,10 +1683,16 @@ def _mark_structured_op_root(
to perform a reduction.
expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
"""
expr_info = expr_to_info[expr]
if isinstance(expr, Access):
# Handle simple reduction expression in the format of A[i] = B[i, j].
if reduce_index in expr_info.src_indices:
expr_info.reduce_indices.add(reduce_index)
return
assert (isinstance(expr, _BinaryExpr))
a_info = expr_to_info[expr.a]
b_info = expr_to_info[expr.b]
expr_info = expr_to_info[expr]
if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
expr_info.reduce_indices.add(reduce_index)
@ -1724,6 +1730,9 @@ def _accumulate_reduce_indices(
| expr_info.reduce_indices)
else:
assert isinstance(expr, Access)
# Handle simple reduction expression in the format of A[i] = B[i, j].
expr_info.acc_reduce_indices = expr_info.reduce_indices
def _gather_structured_op(
@ -1821,9 +1830,10 @@ def _gather_structured_op_input(
structop_inputs: The resulting list of IndexExpr that provide input to the
current structured op.
"""
if (expr != root and expr not in structop_inputs) and (
isinstance(expr, Access) or
(expr in expr_to_info and expr_to_info[expr].structop_info)):
if ((expr != root or isinstance(expr, Access)) and
expr not in structop_inputs) and (isinstance(expr, Access) or
(expr in expr_to_info and
expr_to_info[expr].structop_info)):
structop_inputs.append(expr)
@ -1843,7 +1853,7 @@ def _emit_structured_op_input(
An OperandDef in the linalg dialect for the input IndexExpr.
"""
op_info = expr_to_info[expr].structop_info
if op_info:
if op_info and not isinstance(expr, Access):
# The input is a temporary tensor produced by another structured op.
indices = op_info.dst_indices
name = op_info.dst_name

View File

@ -37,3 +37,42 @@ def test_tensor_true_dense():
passed += (a.shape[0] == 5)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_copy
@testing_utils.run_test
def test_tensor_copy():
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I, J])
B[i, j] = A[i, j]
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 1], [1, 2]])
passed += np.allclose(values, [5.0, 6.0])
# CHECK: Number of passed: 2
print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_trivial_reduction
@testing_utils.run_test
def test_tensor_trivial_reduction():
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([0, 2], 3.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I])
B[i] = A[i, j]
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.allclose(values, [8.0, 6.0])
# CHECK: Number of passed: 2
print("Number of passed:", passed)