[mlir][sparse][taco] Reorder a class.

Define IndexExpr before IndexVar. This is to prepare for the next change
to support the use of index values in tensor expressions.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D121649
This commit is contained in:
Bixia Zheng 2022-03-14 15:43:12 -07:00
parent 1bf4bbc492
commit 3a4229696d
1 changed files with 359 additions and 358 deletions

View File

@ -365,6 +365,365 @@ def _make_format(formats: List[ModeFormat],
return Format(ModeFormatPack(formats), ModeOrdering(ordering)) return Format(ModeFormatPack(formats), ModeOrdering(ordering))
class IndexExpr(abc.ABC):
"""The index notation base class.
We support the TACO API index_expression class with an alias of this class.
"""
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
"""Verifies the RHS operand and returns a binary expression.
Args:
rhs: The RHS of the binary operation, which could be any Python object
from user inputs.
op: A _BinaryOp object representing the binary operator.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
if not isinstance(rhs, IndexExpr):
raise ValueError(f"Expected IndexExpr: {rhs}")
return _BinaryExpr(op, self, rhs)
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
"""Build a unary expression.
Args:
op: A _UnaryOp object representing the unary operation.
"""
return _UnaryExpr(op, self)
def __add__(self, rhs) -> "_BinaryExpr":
"""Defines the operator +.
Args:
rhs: The value being added, which could be any Python object from user
inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.add)
def __mul__(self, rhs) -> "_BinaryExpr":
"""Defines the operator *.
Args:
rhs: The value being multiplied, which could be any Python object from
user inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.mul)
def __abs__(self) -> "_UnaryExpr":
"""Defines the operator abs.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.abs)
def __neg__(self) -> "_UnaryExpr":
"""Defines the operator neg.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.neg)
def __sub__(self, rhs) -> "_BinaryExpr":
"""Defines the operator -.
Args:
rhs: The value being subtracted, which could be any Python object from
user inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.sub)
@abc.abstractmethod
def _visit(self,
func: _ExprVisitor,
args,
*,
leaf_checker: _SubtreeLeafChecker = None) -> None:
"""A post-order visitor.
Args:
func: A callable applied to each node in the expression tree.
args: The variable-length arguments passed to the callable. These
arguments are grouped as an iterable and will be unpacked before passing
to the callable. This is to enable the keyword argument only syntax
after this argument.
leaf_checker: A callable object to identify nodes that should be treated
as leaf nodes to support partial tree visiting.
"""
pass
@abc.abstractmethod
def _emit_expression(
self,
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
expr_to_info: _ExprInfoDict,
) -> lang.ScalarExpression:
"""Emits MLIR for the expression tree.
Args:
expr_to_opnd: A dictionary for looking up structured op input operands for
the input nodes of the structured op.
expr_to_info: A dictionary for looking up code generation information for
expressions.
Returns:
A linalg dialect ScalarExpression for the expression.
"""
pass
@abc.abstractmethod
def dtype(self) -> DType:
"""Returns the data type for the result of the expression."""
pass
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
"""Emits a structured op in the linalg dialect for the expression tree.
We define a DefineOpcallable in the domain specific language for the linalg
dialect and execute the callable to generate the structured op. Self is the
root of the expression tree for the structured op.
Args:
expr_to_info: A dictionary for looking up code generation information for
expressions.
"""
op_info = expr_to_info[self].structop_info
op_name = op_info.dst_name
op_def = lang.LinalgOpDef(name=op_name)
op_callable = lang.DefinedOpCallable(op_name, op_def)
# Collect the input expression nodes for the structured op.
expr_inputs = []
self._visit(
_gather_structured_op_input,
(self, expr_to_info, expr_inputs),
leaf_checker=_is_structured_op_leaf,
)
# Create a linalg structured op operand for each input expression node and
# build a dictionary for looking up the information.
expr_to_input_opnd = {
e: _emit_structured_op_input(e, expr_to_info, op_def)
for e in expr_inputs
}
# Emit the expression tree, which produces the value assigned to the
# destination tensor.
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
# Emit the structured op representation for the destination tensor.
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
lang.OperandKind.OUTPUT_TENSOR)
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
expr_info = expr_to_info[self]
# If the structured op reduces some indices, explicitly represent the
# reduction. This is done by generating a ReduceFn for the dimensions being
# reduced in the linalg dialect and calling the function with the value
# being reduced. We only support add reduction currently.
if expr_info.reduce_indices:
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
value = lang.ReduceFn.add[reduce_dims](value)
# Emit the assignment as a comprehension in the linalg dialect.
comp = lang.Comprehension((dst_use, value))
op_def.comprehensions.append(comp)
# The structured op in the linalg dialect requires an explicit
# initialization for the destination tensor. Emit MLIR to initialize the
# destination tensor.
init = op_info.emit_tensor_init()
# Collect MLIR values for the linalg input operands, with the assumption
# that dictionary preserves the insertion order.
args = [
expr_to_info[expr].mlir_value
for expr, opnd in expr_to_input_opnd.items()
]
# Execute the DefineOpcallable object for the linalg dialect operation to
# emit MLIR for the linalg structured op.
expr_info.mlir_value = op_callable(*args, outs=[init])
def _identify_structured_ops(
self,
expr_to_info: _ExprInfoDict,
dst: "Tensor",
dst_indices: Tuple["IndexVar", ...],
) -> List["IndexExpr"]:
"""Returns expression nodes for the roots of the identified structured ops.
A structured op in the linalg dialect only supports reduction performed on
the whole expression. If the expression tree contains reduction that are
performed on part of the expression tree, the expression tree needs to be
implemented with multiple structured ops. This routine identifies all the
expression nodes that contain reduction as the root of structured ops in the
linalg dialect.
Args:
expr_to_info: A dictionary for looking up code generation information for
expressions.
dst: A destination Tensor that accepts the value of the expression tree.
dst_indices: The indices used by the destination index expression.
Returns:
An ordered list of IndexExpr for the root expressions of the structured
ops, where child expressions go before parent expressions that use their
results.
"""
reduce_indices = tuple(
set(expr_to_info[self].src_indices) - set(dst_indices))
for reduce_index in reduce_indices:
_mark_structured_op_root(self, reduce_index, expr_to_info)
self._visit(_accumulate_reduce_indices, (expr_to_info,))
structop_roots = []
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
# Handle the root of the top level expression.
if not structop_roots or structop_roots[-1] != self:
# The top level expression is not a reduction. Add the top level
# expression as a structured op root.
structop_roots.append(self)
# Use user specified information for the destination tensor to build an
# _StructOpInfo for the top level expression.
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
tuple(dst.shape),
self.dtype(), dst.name,
dst.format)
return structop_roots
def _validate_and_collect_expr_info(
self,
dst: "Tensor",
dst_indices: Tuple["IndexVar", ...],
) -> _ExprInfoDict:
"""Propagates expression information for validation.
Propagates the indices used by child expression nodes to parent expression
nodes. Also collects and validates the sizes for the dimensions
corresponding to the indices.
Args:
dst: A destination Tensor that accepts the value of the expression tree.
dst_indices: The indices used by the destination index expression.
Raises:
ValueError if there is any inconsistency in indices or dimensional
values.
Returns:
A dictionary of (IndexExpr, _ExprInfo).
"""
expr_to_info = {}
# Validate the expression tree and construct expression information.
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
# Validate the destination dimension information.
info = expr_to_info[self]
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
for i, d, in zip(dst_indices, dst.shape):
if i not in index_to_dim_info:
raise ValueError("Destination IndexVar not used in the "
f"source expression: {i}")
else:
if d != index_to_dim_info[i].dim:
raise ValueError(f"Inconsistent destination dimension for {i}: "
f"{d} vs {index_to_dim_info[i].dim}")
return expr_to_info
def _emit_assignment(
self,
module: ir.Module,
dst: "Tensor",
dst_indices: Tuple["IndexVar", ...],
expr_to_info: _ExprInfoDict,
input_accesses: List["Access"],
) -> None:
"""Emits an MLIR function for assigning the expression to a tensor."""
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
# Build the kernel for the operations.
with ir.InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
def linalg_funcop(*args):
# Set up the mapping from the Access nodes to their MLIR values.
for e, mlir in zip(input_accesses, args):
expr_to_info[e].mlir_value = mlir
# Emit structured ops in the linalg dialect to implement the assignment.
for structop_root in self._identify_structured_ops(
expr_to_info, dst, dst_indices):
structop_root._emit_structured_op(expr_to_info)
dst._record_stats(expr_to_info[structop_root].structop_info)
# The function returns the MLIR value of the root expression.
return expr_to_info[self].mlir_value
linalg_funcop.func_op.attributes[
"llvm.emit_c_interface"] = ir.UnitAttr.get()
def get_input_accesses(self) -> List["Access"]:
"""Compute the list of input accesses for the expression."""
input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
return input_accesses
def compile(
self,
dst: "Tensor",
dst_indices: Tuple["IndexVar", ...],
) -> execution_engine.ExecutionEngine:
"""Compiles the tensor assignment dst[dst_indices] = expression.
Args:
dst: The destination tensor.
dst_indices: The tuple of IndexVar used to access the destination tensor.
Returns:
The execution engine for the tensor assignment.
Raises:
ValueError: If the expression is not proper or not supported.
"""
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
input_accesses = self.get_input_accesses()
# Build and compile the module to produce the execution engine.
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
self._emit_assignment(module, dst, dst_indices, expr_to_info,
input_accesses)
engine = utils.compile_and_build_engine(module)
return engine
class _AtomicCounter: class _AtomicCounter:
"""An atomic counter.""" """An atomic counter."""
@ -1203,364 +1562,6 @@ class _ExprInfo:
self.acc_reduce_indices = self.acc_reduce_indices or set() self.acc_reduce_indices = self.acc_reduce_indices or set()
class IndexExpr(abc.ABC):
"""The index notation base class.
We support the TACO API index_expression class with an alias of this class.
"""
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
"""Verifies the RHS operand and returns a binary expression.
Args:
rhs: The RHS of the binary operation, which could be any Python object
from user inputs.
op: A _BinaryOp object representing the binary operator.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
if not isinstance(rhs, IndexExpr):
raise ValueError(f"Expected IndexExpr: {rhs}")
return _BinaryExpr(op, self, rhs)
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
"""Build a unary expression.
Args:
op: A _UnaryOp object representing the unary operation.
"""
return _UnaryExpr(op, self)
def __add__(self, rhs) -> "_BinaryExpr":
"""Defines the operator +.
Args:
rhs: The value being added, which could be any Python object from user
inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.add)
def __mul__(self, rhs) -> "_BinaryExpr":
"""Defines the operator *.
Args:
rhs: The value being multiplied, which could be any Python object from
user inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.mul)
def __abs__(self) -> "_UnaryExpr":
"""Defines the operator abs.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.abs)
def __neg__(self) -> "_UnaryExpr":
"""Defines the operator neg.
Returns:
A _UnaryExpr object representing the operation.
"""
return self._build_unary_expr(operator.neg)
def __sub__(self, rhs) -> "_BinaryExpr":
"""Defines the operator -.
Args:
rhs: The value being subtracted, which could be any Python object from
user inputs.
Returns:
A _BinaryExpr object representing the operation.
Raises:
ValueError: If rhs is not an IndexExpr.
"""
return self._verify_operand_and_build_expr(rhs, operator.sub)
@abc.abstractmethod
def _visit(self,
func: _ExprVisitor,
args,
*,
leaf_checker: _SubtreeLeafChecker = None) -> None:
"""A post-order visitor.
Args:
func: A callable applied to each node in the expression tree.
args: The variable-length arguments passed to the callable. These
arguments are grouped as an iterable and will be unpacked before passing
to the callable. This is to enable the keyword argument only syntax
after this argument.
leaf_checker: A callable object to identify nodes that should be treated
as leaf nodes to support partial tree visiting.
"""
pass
@abc.abstractmethod
def _emit_expression(
self,
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
expr_to_info: _ExprInfoDict,
) -> lang.ScalarExpression:
"""Emits MLIR for the expression tree.
Args:
expr_to_opnd: A dictionary for looking up structured op input operands for
the input nodes of the structured op.
expr_to_info: A dictionary for looking up code generation information for
expressions.
Returns:
A linalg dialect ScalarExpression for the expression.
"""
pass
@abc.abstractmethod
def dtype(self) -> DType:
"""Returns the data type for the result of the expression."""
pass
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
"""Emits a structured op in the linalg dialect for the expression tree.
We define a DefineOpcallable in the domain specific language for the linalg
dialect and execute the callable to generate the structured op. Self is the
root of the expression tree for the structured op.
Args:
expr_to_info: A dictionary for looking up code generation information for
expressions.
"""
op_info = expr_to_info[self].structop_info
op_name = op_info.dst_name
op_def = lang.LinalgOpDef(name=op_name)
op_callable = lang.DefinedOpCallable(op_name, op_def)
# Collect the input expression nodes for the structured op.
expr_inputs = []
self._visit(
_gather_structured_op_input,
(self, expr_to_info, expr_inputs),
leaf_checker=_is_structured_op_leaf,
)
# Create a linalg structured op operand for each input expression node and
# build a dictionary for looking up the information.
expr_to_input_opnd = {
e: _emit_structured_op_input(e, expr_to_info, op_def)
for e in expr_inputs
}
# Emit the expression tree, which produces the value assigned to the
# destination tensor.
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
# Emit the structured op representation for the destination tensor.
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
lang.OperandKind.OUTPUT_TENSOR)
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
expr_info = expr_to_info[self]
# If the structured op reduces some indices, explicitly represent the
# reduction. This is done by generating a ReduceFn for the dimensions being
# reduced in the linalg dialect and calling the function with the value
# being reduced. We only support add reduction currently.
if expr_info.reduce_indices:
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
value = lang.ReduceFn.add[reduce_dims](value)
# Emit the assignment as a comprehension in the linalg dialect.
comp = lang.Comprehension((dst_use, value))
op_def.comprehensions.append(comp)
# The structured op in the linalg dialect requires an explicit
# initialization for the destination tensor. Emit MLIR to initialize the
# destination tensor.
init = op_info.emit_tensor_init()
# Collect MLIR values for the linalg input operands, with the assumption
# that dictionary preserves the insertion order.
args = [
expr_to_info[expr].mlir_value
for expr, opnd in expr_to_input_opnd.items()
]
# Execute the DefineOpcallable object for the linalg dialect operation to
# emit MLIR for the linalg structured op.
expr_info.mlir_value = op_callable(*args, outs=[init])
def _identify_structured_ops(
self,
expr_to_info: _ExprInfoDict,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
) -> List["IndexExpr"]:
"""Returns expression nodes for the roots of the identified structured ops.
A structured op in the linalg dialect only supports reduction performed on
the whole expression. If the expression tree contains reduction that are
performed on part of the expression tree, the expression tree needs to be
implemented with multiple structured ops. This routine identifies all the
expression nodes that contain reduction as the root of structured ops in the
linalg dialect.
Args:
expr_to_info: A dictionary for looking up code generation information for
expressions.
dst: A destination Tensor that accepts the value of the expression tree.
dst_indices: The indices used by the destination index expression.
Returns:
An ordered list of IndexExpr for the root expressions of the structured
ops, where child expressions go before parent expressions that use their
results.
"""
reduce_indices = tuple(
set(expr_to_info[self].src_indices) - set(dst_indices))
for reduce_index in reduce_indices:
_mark_structured_op_root(self, reduce_index, expr_to_info)
self._visit(_accumulate_reduce_indices, (expr_to_info,))
structop_roots = []
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
# Handle the root of the top level expression.
if not structop_roots or structop_roots[-1] != self:
# The top level expression is not a reduction. Add the top level
# expression as a structured op root.
structop_roots.append(self)
# Use user specified information for the destination tensor to build an
# _StructOpInfo for the top level expression.
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
tuple(dst.shape),
self.dtype(), dst.name,
dst.format)
return structop_roots
def _validate_and_collect_expr_info(
self,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
) -> _ExprInfoDict:
"""Propagates expression information for validation.
Propagates the indices used by child expression nodes to parent expression
nodes. Also collects and validates the sizes for the dimensions
corresponding to the indices.
Args:
dst: A destination Tensor that accepts the value of the expression tree.
dst_indices: The indices used by the destination index expression.
Raises:
ValueError if there is any inconsistency in indices or dimensional
values.
Returns:
A dictionary of (IndexExpr, _ExprInfo).
"""
expr_to_info = {}
# Validate the expression tree and construct expression information.
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
# Validate the destination dimension information.
info = expr_to_info[self]
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
for i, d, in zip(dst_indices, dst.shape):
if i not in index_to_dim_info:
raise ValueError("Destination IndexVar not used in the "
f"source expression: {i}")
else:
if d != index_to_dim_info[i].dim:
raise ValueError(f"Inconsistent destination dimension for {i}: "
f"{d} vs {index_to_dim_info[i].dim}")
return expr_to_info
def _emit_assignment(
self,
module: ir.Module,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
expr_to_info: _ExprInfoDict,
input_accesses: List["Access"],
) -> None:
"""Emits an MLIR function for assigning the expression to a tensor."""
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
# Build the kernel for the operations.
with ir.InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
def linalg_funcop(*args):
# Set up the mapping from the Access nodes to their MLIR values.
for e, mlir in zip(input_accesses, args):
expr_to_info[e].mlir_value = mlir
# Emit structured ops in the linalg dialect to implement the assignment.
for structop_root in self._identify_structured_ops(
expr_to_info, dst, dst_indices):
structop_root._emit_structured_op(expr_to_info)
dst._record_stats(expr_to_info[structop_root].structop_info)
# The function returns the MLIR value of the root expression.
return expr_to_info[self].mlir_value
linalg_funcop.func_op.attributes[
"llvm.emit_c_interface"] = ir.UnitAttr.get()
def get_input_accesses(self) -> List["Access"]:
"""Compute the list of input accesses for the expression."""
input_accesses = []
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
return input_accesses
def compile(
self,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
) -> execution_engine.ExecutionEngine:
"""Compiles the tensor assignment dst[dst_indices] = expression.
Args:
dst: The destination tensor.
dst_indices: The tuple of IndexVar used to access the destination tensor.
Returns:
The execution engine for the tensor assignment.
Raises:
ValueError: If the expression is not proper or not supported.
"""
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
input_accesses = self.get_input_accesses()
# Build and compile the module to produce the execution engine.
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
self._emit_assignment(module, dst, dst_indices, expr_to_info,
input_accesses)
engine = utils.compile_and_build_engine(module)
return engine
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Access(IndexExpr): class Access(IndexExpr):
"""The tensor access class. """The tensor access class.