forked from OSchip/llvm-project
[mlir][sparse][taco] Handle tensor copy and trivial reduction expression.
Handle tensor copy, such as A[i, j] = B[i, j]. Also, handle trivial reduction expression, such as A[i] = B[i, j]. Add unit tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D119867
This commit is contained in:
parent
b8438a6975
commit
746c68eafd
|
@ -1683,10 +1683,16 @@ def _mark_structured_op_root(
|
|||
to perform a reduction.
|
||||
expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
|
||||
"""
|
||||
expr_info = expr_to_info[expr]
|
||||
if isinstance(expr, Access):
|
||||
# Handle simple reduction expression in the format of A[i] = B[i, j].
|
||||
if reduce_index in expr_info.src_indices:
|
||||
expr_info.reduce_indices.add(reduce_index)
|
||||
return
|
||||
|
||||
assert (isinstance(expr, _BinaryExpr))
|
||||
a_info = expr_to_info[expr.a]
|
||||
b_info = expr_to_info[expr.b]
|
||||
expr_info = expr_to_info[expr]
|
||||
|
||||
if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
|
||||
expr_info.reduce_indices.add(reduce_index)
|
||||
|
@ -1724,6 +1730,9 @@ def _accumulate_reduce_indices(
|
|||
| expr_info.reduce_indices)
|
||||
else:
|
||||
assert isinstance(expr, Access)
|
||||
# Handle simple reduction expression in the format of A[i] = B[i, j].
|
||||
expr_info.acc_reduce_indices = expr_info.reduce_indices
|
||||
|
||||
|
||||
|
||||
def _gather_structured_op(
|
||||
|
@ -1821,9 +1830,10 @@ def _gather_structured_op_input(
|
|||
structop_inputs: The resulting list of IndexExpr that provide input to the
|
||||
current structured op.
|
||||
"""
|
||||
if (expr != root and expr not in structop_inputs) and (
|
||||
isinstance(expr, Access) or
|
||||
(expr in expr_to_info and expr_to_info[expr].structop_info)):
|
||||
if ((expr != root or isinstance(expr, Access)) and
|
||||
expr not in structop_inputs) and (isinstance(expr, Access) or
|
||||
(expr in expr_to_info and
|
||||
expr_to_info[expr].structop_info)):
|
||||
structop_inputs.append(expr)
|
||||
|
||||
|
||||
|
@ -1843,7 +1853,7 @@ def _emit_structured_op_input(
|
|||
An OperandDef in the linalg dialect for the input IndexExpr.
|
||||
"""
|
||||
op_info = expr_to_info[expr].structop_info
|
||||
if op_info:
|
||||
if op_info and not isinstance(expr, Access):
|
||||
# The input is a temporary tensor produced by another structured op.
|
||||
indices = op_info.dst_indices
|
||||
name = op_info.dst_name
|
||||
|
|
|
@ -37,3 +37,42 @@ def test_tensor_true_dense():
|
|||
passed += (a.shape[0] == 5)
|
||||
# CHECK: Number of passed: 3
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK-LABEL: test_tensor_copy
|
||||
@testing_utils.run_test
|
||||
def test_tensor_copy():
|
||||
i, j = mlir_pytaco.get_index_vars(2)
|
||||
I = 2
|
||||
J = 3
|
||||
A = mlir_pytaco.Tensor([I, J])
|
||||
A.insert([0, 1], 5.0)
|
||||
A.insert([1, 2], 6.0)
|
||||
B = mlir_pytaco.Tensor([I, J])
|
||||
B[i, j] = A[i, j]
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed = np.array_equal(indices, [[0, 1], [1, 2]])
|
||||
passed += np.allclose(values, [5.0, 6.0])
|
||||
|
||||
# CHECK: Number of passed: 2
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK-LABEL: test_tensor_trivial_reduction
|
||||
@testing_utils.run_test
|
||||
def test_tensor_trivial_reduction():
|
||||
i, j = mlir_pytaco.get_index_vars(2)
|
||||
I = 2
|
||||
J = 3
|
||||
A = mlir_pytaco.Tensor([I, J])
|
||||
A.insert([0, 1], 5.0)
|
||||
A.insert([0, 2], 3.0)
|
||||
A.insert([1, 2], 6.0)
|
||||
B = mlir_pytaco.Tensor([I])
|
||||
B[i] = A[i, j]
|
||||
indices, values = B.get_coordinates_and_values()
|
||||
passed = np.array_equal(indices, [[0], [1]])
|
||||
passed += np.allclose(values, [8.0, 6.0])
|
||||
|
||||
# CHECK: Number of passed: 2
|
||||
print("Number of passed:", passed)
|
||||
|
|
Loading…
Reference in New Issue