forked from OSchip/llvm-project
[mlir][sparse][taco] Reorder a class.
Define IndexExpr before IndexVar. This is to prepare for the next change to support the use of index values in tensor expressions. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D121649
This commit is contained in:
parent
1bf4bbc492
commit
3a4229696d
|
@ -365,6 +365,365 @@ def _make_format(formats: List[ModeFormat],
|
||||||
return Format(ModeFormatPack(formats), ModeOrdering(ordering))
|
return Format(ModeFormatPack(formats), ModeOrdering(ordering))
|
||||||
|
|
||||||
|
|
||||||
|
class IndexExpr(abc.ABC):
|
||||||
|
"""The index notation base class.
|
||||||
|
|
||||||
|
We support the TACO API index_expression class with an alias of this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
|
||||||
|
"""Verifies the RHS operand and returns a binary expression.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rhs: The RHS of the binary operation, which could be any Python object
|
||||||
|
from user inputs.
|
||||||
|
op: A _BinaryOp object representing the binary operator.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If rhs is not an IndexExpr.
|
||||||
|
"""
|
||||||
|
if not isinstance(rhs, IndexExpr):
|
||||||
|
raise ValueError(f"Expected IndexExpr: {rhs}")
|
||||||
|
return _BinaryExpr(op, self, rhs)
|
||||||
|
|
||||||
|
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
|
||||||
|
"""Build a unary expression.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: A _UnaryOp object representing the unary operation.
|
||||||
|
"""
|
||||||
|
return _UnaryExpr(op, self)
|
||||||
|
|
||||||
|
def __add__(self, rhs) -> "_BinaryExpr":
|
||||||
|
"""Defines the operator +.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rhs: The value being added, which could be any Python object from user
|
||||||
|
inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _BinaryExpr object representing the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If rhs is not an IndexExpr.
|
||||||
|
"""
|
||||||
|
return self._verify_operand_and_build_expr(rhs, operator.add)
|
||||||
|
|
||||||
|
def __mul__(self, rhs) -> "_BinaryExpr":
|
||||||
|
"""Defines the operator *.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rhs: The value being multiplied, which could be any Python object from
|
||||||
|
user inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _BinaryExpr object representing the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If rhs is not an IndexExpr.
|
||||||
|
"""
|
||||||
|
return self._verify_operand_and_build_expr(rhs, operator.mul)
|
||||||
|
|
||||||
|
def __abs__(self) -> "_UnaryExpr":
|
||||||
|
"""Defines the operator abs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _UnaryExpr object representing the operation.
|
||||||
|
"""
|
||||||
|
return self._build_unary_expr(operator.abs)
|
||||||
|
|
||||||
|
def __neg__(self) -> "_UnaryExpr":
|
||||||
|
"""Defines the operator neg.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _UnaryExpr object representing the operation.
|
||||||
|
"""
|
||||||
|
return self._build_unary_expr(operator.neg)
|
||||||
|
|
||||||
|
def __sub__(self, rhs) -> "_BinaryExpr":
|
||||||
|
"""Defines the operator -.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rhs: The value being subtracted, which could be any Python object from
|
||||||
|
user inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _BinaryExpr object representing the operation.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If rhs is not an IndexExpr.
|
||||||
|
"""
|
||||||
|
return self._verify_operand_and_build_expr(rhs, operator.sub)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _visit(self,
|
||||||
|
func: _ExprVisitor,
|
||||||
|
args,
|
||||||
|
*,
|
||||||
|
leaf_checker: _SubtreeLeafChecker = None) -> None:
|
||||||
|
"""A post-order visitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: A callable applied to each node in the expression tree.
|
||||||
|
args: The variable-length arguments passed to the callable. These
|
||||||
|
arguments are grouped as an iterable and will be unpacked before passing
|
||||||
|
to the callable. This is to enable the keyword argument only syntax
|
||||||
|
after this argument.
|
||||||
|
leaf_checker: A callable object to identify nodes that should be treated
|
||||||
|
as leaf nodes to support partial tree visiting.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _emit_expression(
|
||||||
|
self,
|
||||||
|
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
|
||||||
|
expr_to_info: _ExprInfoDict,
|
||||||
|
) -> lang.ScalarExpression:
|
||||||
|
"""Emits MLIR for the expression tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr_to_opnd: A dictionary for looking up structured op input operands for
|
||||||
|
the input nodes of the structured op.
|
||||||
|
expr_to_info: A dictionary for looking up code generation information for
|
||||||
|
expressions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A linalg dialect ScalarExpression for the expression.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
"""Returns the data type for the result of the expression."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
|
||||||
|
"""Emits a structured op in the linalg dialect for the expression tree.
|
||||||
|
|
||||||
|
We define a DefineOpcallable in the domain specific language for the linalg
|
||||||
|
dialect and execute the callable to generate the structured op. Self is the
|
||||||
|
root of the expression tree for the structured op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr_to_info: A dictionary for looking up code generation information for
|
||||||
|
expressions.
|
||||||
|
"""
|
||||||
|
op_info = expr_to_info[self].structop_info
|
||||||
|
op_name = op_info.dst_name
|
||||||
|
op_def = lang.LinalgOpDef(name=op_name)
|
||||||
|
op_callable = lang.DefinedOpCallable(op_name, op_def)
|
||||||
|
|
||||||
|
# Collect the input expression nodes for the structured op.
|
||||||
|
expr_inputs = []
|
||||||
|
self._visit(
|
||||||
|
_gather_structured_op_input,
|
||||||
|
(self, expr_to_info, expr_inputs),
|
||||||
|
leaf_checker=_is_structured_op_leaf,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a linalg structured op operand for each input expression node and
|
||||||
|
# build a dictionary for looking up the information.
|
||||||
|
expr_to_input_opnd = {
|
||||||
|
e: _emit_structured_op_input(e, expr_to_info, op_def)
|
||||||
|
for e in expr_inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
# Emit the expression tree, which produces the value assigned to the
|
||||||
|
# destination tensor.
|
||||||
|
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
|
||||||
|
# Emit the structured op representation for the destination tensor.
|
||||||
|
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
|
||||||
|
lang.OperandKind.OUTPUT_TENSOR)
|
||||||
|
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
||||||
|
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
|
||||||
|
|
||||||
|
expr_info = expr_to_info[self]
|
||||||
|
# If the structured op reduces some indices, explicitly represent the
|
||||||
|
# reduction. This is done by generating a ReduceFn for the dimensions being
|
||||||
|
# reduced in the linalg dialect and calling the function with the value
|
||||||
|
# being reduced. We only support add reduction currently.
|
||||||
|
if expr_info.reduce_indices:
|
||||||
|
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
|
||||||
|
value = lang.ReduceFn.add[reduce_dims](value)
|
||||||
|
|
||||||
|
# Emit the assignment as a comprehension in the linalg dialect.
|
||||||
|
comp = lang.Comprehension((dst_use, value))
|
||||||
|
op_def.comprehensions.append(comp)
|
||||||
|
|
||||||
|
# The structured op in the linalg dialect requires an explicit
|
||||||
|
# initialization for the destination tensor. Emit MLIR to initialize the
|
||||||
|
# destination tensor.
|
||||||
|
init = op_info.emit_tensor_init()
|
||||||
|
|
||||||
|
# Collect MLIR values for the linalg input operands, with the assumption
|
||||||
|
# that dictionary preserves the insertion order.
|
||||||
|
args = [
|
||||||
|
expr_to_info[expr].mlir_value
|
||||||
|
for expr, opnd in expr_to_input_opnd.items()
|
||||||
|
]
|
||||||
|
# Execute the DefineOpcallable object for the linalg dialect operation to
|
||||||
|
# emit MLIR for the linalg structured op.
|
||||||
|
expr_info.mlir_value = op_callable(*args, outs=[init])
|
||||||
|
|
||||||
|
def _identify_structured_ops(
|
||||||
|
self,
|
||||||
|
expr_to_info: _ExprInfoDict,
|
||||||
|
dst: "Tensor",
|
||||||
|
dst_indices: Tuple["IndexVar", ...],
|
||||||
|
) -> List["IndexExpr"]:
|
||||||
|
"""Returns expression nodes for the roots of the identified structured ops.
|
||||||
|
|
||||||
|
A structured op in the linalg dialect only supports reduction performed on
|
||||||
|
the whole expression. If the expression tree contains reduction that are
|
||||||
|
performed on part of the expression tree, the expression tree needs to be
|
||||||
|
implemented with multiple structured ops. This routine identifies all the
|
||||||
|
expression nodes that contain reduction as the root of structured ops in the
|
||||||
|
linalg dialect.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr_to_info: A dictionary for looking up code generation information for
|
||||||
|
expressions.
|
||||||
|
dst: A destination Tensor that accepts the value of the expression tree.
|
||||||
|
dst_indices: The indices used by the destination index expression.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An ordered list of IndexExpr for the root expressions of the structured
|
||||||
|
ops, where child expressions go before parent expressions that use their
|
||||||
|
results.
|
||||||
|
"""
|
||||||
|
reduce_indices = tuple(
|
||||||
|
set(expr_to_info[self].src_indices) - set(dst_indices))
|
||||||
|
for reduce_index in reduce_indices:
|
||||||
|
_mark_structured_op_root(self, reduce_index, expr_to_info)
|
||||||
|
|
||||||
|
self._visit(_accumulate_reduce_indices, (expr_to_info,))
|
||||||
|
structop_roots = []
|
||||||
|
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
|
||||||
|
|
||||||
|
# Handle the root of the top level expression.
|
||||||
|
if not structop_roots or structop_roots[-1] != self:
|
||||||
|
# The top level expression is not a reduction. Add the top level
|
||||||
|
# expression as a structured op root.
|
||||||
|
structop_roots.append(self)
|
||||||
|
|
||||||
|
# Use user specified information for the destination tensor to build an
|
||||||
|
# _StructOpInfo for the top level expression.
|
||||||
|
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
|
||||||
|
tuple(dst.shape),
|
||||||
|
self.dtype(), dst.name,
|
||||||
|
dst.format)
|
||||||
|
|
||||||
|
return structop_roots
|
||||||
|
|
||||||
|
def _validate_and_collect_expr_info(
|
||||||
|
self,
|
||||||
|
dst: "Tensor",
|
||||||
|
dst_indices: Tuple["IndexVar", ...],
|
||||||
|
) -> _ExprInfoDict:
|
||||||
|
"""Propagates expression information for validation.
|
||||||
|
|
||||||
|
Propagates the indices used by child expression nodes to parent expression
|
||||||
|
nodes. Also collects and validates the sizes for the dimensions
|
||||||
|
corresponding to the indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst: A destination Tensor that accepts the value of the expression tree.
|
||||||
|
dst_indices: The indices used by the destination index expression.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if there is any inconsistency in indices or dimensional
|
||||||
|
values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of (IndexExpr, _ExprInfo).
|
||||||
|
"""
|
||||||
|
expr_to_info = {}
|
||||||
|
# Validate the expression tree and construct expression information.
|
||||||
|
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
|
||||||
|
|
||||||
|
# Validate the destination dimension information.
|
||||||
|
info = expr_to_info[self]
|
||||||
|
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
|
||||||
|
for i, d, in zip(dst_indices, dst.shape):
|
||||||
|
if i not in index_to_dim_info:
|
||||||
|
raise ValueError("Destination IndexVar not used in the "
|
||||||
|
f"source expression: {i}")
|
||||||
|
else:
|
||||||
|
if d != index_to_dim_info[i].dim:
|
||||||
|
raise ValueError(f"Inconsistent destination dimension for {i}: "
|
||||||
|
f"{d} vs {index_to_dim_info[i].dim}")
|
||||||
|
|
||||||
|
return expr_to_info
|
||||||
|
|
||||||
|
def _emit_assignment(
|
||||||
|
self,
|
||||||
|
module: ir.Module,
|
||||||
|
dst: "Tensor",
|
||||||
|
dst_indices: Tuple["IndexVar", ...],
|
||||||
|
expr_to_info: _ExprInfoDict,
|
||||||
|
input_accesses: List["Access"],
|
||||||
|
) -> None:
|
||||||
|
"""Emits an MLIR function for assigning the expression to a tensor."""
|
||||||
|
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
|
||||||
|
|
||||||
|
# Build the kernel for the operations.
|
||||||
|
with ir.InsertionPoint(module.body):
|
||||||
|
|
||||||
|
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
|
||||||
|
def linalg_funcop(*args):
|
||||||
|
# Set up the mapping from the Access nodes to their MLIR values.
|
||||||
|
for e, mlir in zip(input_accesses, args):
|
||||||
|
expr_to_info[e].mlir_value = mlir
|
||||||
|
|
||||||
|
# Emit structured ops in the linalg dialect to implement the assignment.
|
||||||
|
for structop_root in self._identify_structured_ops(
|
||||||
|
expr_to_info, dst, dst_indices):
|
||||||
|
structop_root._emit_structured_op(expr_to_info)
|
||||||
|
dst._record_stats(expr_to_info[structop_root].structop_info)
|
||||||
|
|
||||||
|
# The function returns the MLIR value of the root expression.
|
||||||
|
return expr_to_info[self].mlir_value
|
||||||
|
|
||||||
|
linalg_funcop.func_op.attributes[
|
||||||
|
"llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||||
|
|
||||||
|
def get_input_accesses(self) -> List["Access"]:
|
||||||
|
"""Compute the list of input accesses for the expression."""
|
||||||
|
input_accesses = []
|
||||||
|
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
|
||||||
|
return input_accesses
|
||||||
|
|
||||||
|
def compile(
|
||||||
|
self,
|
||||||
|
dst: "Tensor",
|
||||||
|
dst_indices: Tuple["IndexVar", ...],
|
||||||
|
) -> execution_engine.ExecutionEngine:
|
||||||
|
"""Compiles the tensor assignment dst[dst_indices] = expression.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst: The destination tensor.
|
||||||
|
dst_indices: The tuple of IndexVar used to access the destination tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The execution engine for the tensor assignment.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the expression is not proper or not supported.
|
||||||
|
"""
|
||||||
|
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
|
||||||
|
input_accesses = self.get_input_accesses()
|
||||||
|
|
||||||
|
# Build and compile the module to produce the execution engine.
|
||||||
|
with ir.Context(), ir.Location.unknown():
|
||||||
|
module = ir.Module.create()
|
||||||
|
self._emit_assignment(module, dst, dst_indices, expr_to_info,
|
||||||
|
input_accesses)
|
||||||
|
engine = utils.compile_and_build_engine(module)
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
class _AtomicCounter:
|
class _AtomicCounter:
|
||||||
"""An atomic counter."""
|
"""An atomic counter."""
|
||||||
|
|
||||||
|
@ -1203,364 +1562,6 @@ class _ExprInfo:
|
||||||
self.acc_reduce_indices = self.acc_reduce_indices or set()
|
self.acc_reduce_indices = self.acc_reduce_indices or set()
|
||||||
|
|
||||||
|
|
||||||
class IndexExpr(abc.ABC):
|
|
||||||
"""The index notation base class.
|
|
||||||
|
|
||||||
We support the TACO API index_expression class with an alias of this class.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
|
|
||||||
"""Verifies the RHS operand and returns a binary expression.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rhs: The RHS of the binary operation, which could be any Python object
|
|
||||||
from user inputs.
|
|
||||||
op: A _BinaryOp object representing the binary operator.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If rhs is not an IndexExpr.
|
|
||||||
"""
|
|
||||||
if not isinstance(rhs, IndexExpr):
|
|
||||||
raise ValueError(f"Expected IndexExpr: {rhs}")
|
|
||||||
return _BinaryExpr(op, self, rhs)
|
|
||||||
|
|
||||||
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
|
|
||||||
"""Build a unary expression.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: A _UnaryOp object representing the unary operation.
|
|
||||||
"""
|
|
||||||
return _UnaryExpr(op, self)
|
|
||||||
|
|
||||||
def __add__(self, rhs) -> "_BinaryExpr":
|
|
||||||
"""Defines the operator +.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rhs: The value being added, which could be any Python object from user
|
|
||||||
inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A _BinaryExpr object representing the operation.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If rhs is not an IndexExpr.
|
|
||||||
"""
|
|
||||||
return self._verify_operand_and_build_expr(rhs, operator.add)
|
|
||||||
|
|
||||||
def __mul__(self, rhs) -> "_BinaryExpr":
|
|
||||||
"""Defines the operator *.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rhs: The value being multiplied, which could be any Python object from
|
|
||||||
user inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A _BinaryExpr object representing the operation.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If rhs is not an IndexExpr.
|
|
||||||
"""
|
|
||||||
return self._verify_operand_and_build_expr(rhs, operator.mul)
|
|
||||||
|
|
||||||
def __abs__(self) -> "_UnaryExpr":
|
|
||||||
"""Defines the operator abs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A _UnaryExpr object representing the operation.
|
|
||||||
"""
|
|
||||||
return self._build_unary_expr(operator.abs)
|
|
||||||
|
|
||||||
def __neg__(self) -> "_UnaryExpr":
|
|
||||||
"""Defines the operator neg.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A _UnaryExpr object representing the operation.
|
|
||||||
"""
|
|
||||||
return self._build_unary_expr(operator.neg)
|
|
||||||
|
|
||||||
def __sub__(self, rhs) -> "_BinaryExpr":
|
|
||||||
"""Defines the operator -.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rhs: The value being subtracted, which could be any Python object from
|
|
||||||
user inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A _BinaryExpr object representing the operation.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If rhs is not an IndexExpr.
|
|
||||||
"""
|
|
||||||
return self._verify_operand_and_build_expr(rhs, operator.sub)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _visit(self,
|
|
||||||
func: _ExprVisitor,
|
|
||||||
args,
|
|
||||||
*,
|
|
||||||
leaf_checker: _SubtreeLeafChecker = None) -> None:
|
|
||||||
"""A post-order visitor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: A callable applied to each node in the expression tree.
|
|
||||||
args: The variable-length arguments passed to the callable. These
|
|
||||||
arguments are grouped as an iterable and will be unpacked before passing
|
|
||||||
to the callable. This is to enable the keyword argument only syntax
|
|
||||||
after this argument.
|
|
||||||
leaf_checker: A callable object to identify nodes that should be treated
|
|
||||||
as leaf nodes to support partial tree visiting.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _emit_expression(
|
|
||||||
self,
|
|
||||||
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
|
|
||||||
expr_to_info: _ExprInfoDict,
|
|
||||||
) -> lang.ScalarExpression:
|
|
||||||
"""Emits MLIR for the expression tree.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr_to_opnd: A dictionary for looking up structured op input operands for
|
|
||||||
the input nodes of the structured op.
|
|
||||||
expr_to_info: A dictionary for looking up code generation information for
|
|
||||||
expressions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A linalg dialect ScalarExpression for the expression.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def dtype(self) -> DType:
|
|
||||||
"""Returns the data type for the result of the expression."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
|
|
||||||
"""Emits a structured op in the linalg dialect for the expression tree.
|
|
||||||
|
|
||||||
We define a DefineOpcallable in the domain specific language for the linalg
|
|
||||||
dialect and execute the callable to generate the structured op. Self is the
|
|
||||||
root of the expression tree for the structured op.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr_to_info: A dictionary for looking up code generation information for
|
|
||||||
expressions.
|
|
||||||
"""
|
|
||||||
op_info = expr_to_info[self].structop_info
|
|
||||||
op_name = op_info.dst_name
|
|
||||||
op_def = lang.LinalgOpDef(name=op_name)
|
|
||||||
op_callable = lang.DefinedOpCallable(op_name, op_def)
|
|
||||||
|
|
||||||
# Collect the input expression nodes for the structured op.
|
|
||||||
expr_inputs = []
|
|
||||||
self._visit(
|
|
||||||
_gather_structured_op_input,
|
|
||||||
(self, expr_to_info, expr_inputs),
|
|
||||||
leaf_checker=_is_structured_op_leaf,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a linalg structured op operand for each input expression node and
|
|
||||||
# build a dictionary for looking up the information.
|
|
||||||
expr_to_input_opnd = {
|
|
||||||
e: _emit_structured_op_input(e, expr_to_info, op_def)
|
|
||||||
for e in expr_inputs
|
|
||||||
}
|
|
||||||
|
|
||||||
# Emit the expression tree, which produces the value assigned to the
|
|
||||||
# destination tensor.
|
|
||||||
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
|
|
||||||
# Emit the structured op representation for the destination tensor.
|
|
||||||
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
|
|
||||||
lang.OperandKind.OUTPUT_TENSOR)
|
|
||||||
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
|
||||||
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
|
|
||||||
|
|
||||||
expr_info = expr_to_info[self]
|
|
||||||
# If the structured op reduces some indices, explicitly represent the
|
|
||||||
# reduction. This is done by generating a ReduceFn for the dimensions being
|
|
||||||
# reduced in the linalg dialect and calling the function with the value
|
|
||||||
# being reduced. We only support add reduction currently.
|
|
||||||
if expr_info.reduce_indices:
|
|
||||||
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
|
|
||||||
value = lang.ReduceFn.add[reduce_dims](value)
|
|
||||||
|
|
||||||
# Emit the assignment as a comprehension in the linalg dialect.
|
|
||||||
comp = lang.Comprehension((dst_use, value))
|
|
||||||
op_def.comprehensions.append(comp)
|
|
||||||
|
|
||||||
# The structured op in the linalg dialect requires an explicit
|
|
||||||
# initialization for the destination tensor. Emit MLIR to initialize the
|
|
||||||
# destination tensor.
|
|
||||||
init = op_info.emit_tensor_init()
|
|
||||||
|
|
||||||
# Collect MLIR values for the linalg input operands, with the assumption
|
|
||||||
# that dictionary preserves the insertion order.
|
|
||||||
args = [
|
|
||||||
expr_to_info[expr].mlir_value
|
|
||||||
for expr, opnd in expr_to_input_opnd.items()
|
|
||||||
]
|
|
||||||
# Execute the DefineOpcallable object for the linalg dialect operation to
|
|
||||||
# emit MLIR for the linalg structured op.
|
|
||||||
expr_info.mlir_value = op_callable(*args, outs=[init])
|
|
||||||
|
|
||||||
def _identify_structured_ops(
|
|
||||||
self,
|
|
||||||
expr_to_info: _ExprInfoDict,
|
|
||||||
dst: Tensor,
|
|
||||||
dst_indices: Tuple[IndexVar, ...],
|
|
||||||
) -> List["IndexExpr"]:
|
|
||||||
"""Returns expression nodes for the roots of the identified structured ops.
|
|
||||||
|
|
||||||
A structured op in the linalg dialect only supports reduction performed on
|
|
||||||
the whole expression. If the expression tree contains reduction that are
|
|
||||||
performed on part of the expression tree, the expression tree needs to be
|
|
||||||
implemented with multiple structured ops. This routine identifies all the
|
|
||||||
expression nodes that contain reduction as the root of structured ops in the
|
|
||||||
linalg dialect.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expr_to_info: A dictionary for looking up code generation information for
|
|
||||||
expressions.
|
|
||||||
dst: A destination Tensor that accepts the value of the expression tree.
|
|
||||||
dst_indices: The indices used by the destination index expression.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An ordered list of IndexExpr for the root expressions of the structured
|
|
||||||
ops, where child expressions go before parent expressions that use their
|
|
||||||
results.
|
|
||||||
"""
|
|
||||||
reduce_indices = tuple(
|
|
||||||
set(expr_to_info[self].src_indices) - set(dst_indices))
|
|
||||||
for reduce_index in reduce_indices:
|
|
||||||
_mark_structured_op_root(self, reduce_index, expr_to_info)
|
|
||||||
|
|
||||||
self._visit(_accumulate_reduce_indices, (expr_to_info,))
|
|
||||||
structop_roots = []
|
|
||||||
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
|
|
||||||
|
|
||||||
# Handle the root of the top level expression.
|
|
||||||
if not structop_roots or structop_roots[-1] != self:
|
|
||||||
# The top level expression is not a reduction. Add the top level
|
|
||||||
# expression as a structured op root.
|
|
||||||
structop_roots.append(self)
|
|
||||||
|
|
||||||
# Use user specified information for the destination tensor to build an
|
|
||||||
# _StructOpInfo for the top level expression.
|
|
||||||
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
|
|
||||||
tuple(dst.shape),
|
|
||||||
self.dtype(), dst.name,
|
|
||||||
dst.format)
|
|
||||||
|
|
||||||
return structop_roots
|
|
||||||
|
|
||||||
def _validate_and_collect_expr_info(
|
|
||||||
self,
|
|
||||||
dst: Tensor,
|
|
||||||
dst_indices: Tuple[IndexVar, ...],
|
|
||||||
) -> _ExprInfoDict:
|
|
||||||
"""Propagates expression information for validation.
|
|
||||||
|
|
||||||
Propagates the indices used by child expression nodes to parent expression
|
|
||||||
nodes. Also collects and validates the sizes for the dimensions
|
|
||||||
corresponding to the indices.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dst: A destination Tensor that accepts the value of the expression tree.
|
|
||||||
dst_indices: The indices used by the destination index expression.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError if there is any inconsistency in indices or dimensional
|
|
||||||
values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of (IndexExpr, _ExprInfo).
|
|
||||||
"""
|
|
||||||
expr_to_info = {}
|
|
||||||
# Validate the expression tree and construct expression information.
|
|
||||||
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
|
|
||||||
|
|
||||||
# Validate the destination dimension information.
|
|
||||||
info = expr_to_info[self]
|
|
||||||
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
|
|
||||||
for i, d, in zip(dst_indices, dst.shape):
|
|
||||||
if i not in index_to_dim_info:
|
|
||||||
raise ValueError("Destination IndexVar not used in the "
|
|
||||||
f"source expression: {i}")
|
|
||||||
else:
|
|
||||||
if d != index_to_dim_info[i].dim:
|
|
||||||
raise ValueError(f"Inconsistent destination dimension for {i}: "
|
|
||||||
f"{d} vs {index_to_dim_info[i].dim}")
|
|
||||||
|
|
||||||
return expr_to_info
|
|
||||||
|
|
||||||
def _emit_assignment(
|
|
||||||
self,
|
|
||||||
module: ir.Module,
|
|
||||||
dst: Tensor,
|
|
||||||
dst_indices: Tuple[IndexVar, ...],
|
|
||||||
expr_to_info: _ExprInfoDict,
|
|
||||||
input_accesses: List["Access"],
|
|
||||||
) -> None:
|
|
||||||
"""Emits an MLIR function for assigning the expression to a tensor."""
|
|
||||||
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
|
|
||||||
|
|
||||||
# Build the kernel for the operations.
|
|
||||||
with ir.InsertionPoint(module.body):
|
|
||||||
|
|
||||||
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
|
|
||||||
def linalg_funcop(*args):
|
|
||||||
# Set up the mapping from the Access nodes to their MLIR values.
|
|
||||||
for e, mlir in zip(input_accesses, args):
|
|
||||||
expr_to_info[e].mlir_value = mlir
|
|
||||||
|
|
||||||
# Emit structured ops in the linalg dialect to implement the assignment.
|
|
||||||
for structop_root in self._identify_structured_ops(
|
|
||||||
expr_to_info, dst, dst_indices):
|
|
||||||
structop_root._emit_structured_op(expr_to_info)
|
|
||||||
dst._record_stats(expr_to_info[structop_root].structop_info)
|
|
||||||
|
|
||||||
# The function returns the MLIR value of the root expression.
|
|
||||||
return expr_to_info[self].mlir_value
|
|
||||||
|
|
||||||
linalg_funcop.func_op.attributes[
|
|
||||||
"llvm.emit_c_interface"] = ir.UnitAttr.get()
|
|
||||||
|
|
||||||
def get_input_accesses(self) -> List["Access"]:
|
|
||||||
"""Compute the list of input accesses for the expression."""
|
|
||||||
input_accesses = []
|
|
||||||
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
|
|
||||||
return input_accesses
|
|
||||||
|
|
||||||
def compile(
|
|
||||||
self,
|
|
||||||
dst: Tensor,
|
|
||||||
dst_indices: Tuple[IndexVar, ...],
|
|
||||||
) -> execution_engine.ExecutionEngine:
|
|
||||||
"""Compiles the tensor assignment dst[dst_indices] = expression.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dst: The destination tensor.
|
|
||||||
dst_indices: The tuple of IndexVar used to access the destination tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The execution engine for the tensor assignment.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the expression is not proper or not supported.
|
|
||||||
"""
|
|
||||||
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
|
|
||||||
input_accesses = self.get_input_accesses()
|
|
||||||
|
|
||||||
# Build and compile the module to produce the execution engine.
|
|
||||||
with ir.Context(), ir.Location.unknown():
|
|
||||||
module = ir.Module.create()
|
|
||||||
self._emit_assignment(module, dst, dst_indices, expr_to_info,
|
|
||||||
input_accesses)
|
|
||||||
engine = utils.compile_and_build_engine(module)
|
|
||||||
|
|
||||||
return engine
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Access(IndexExpr):
|
class Access(IndexExpr):
|
||||||
"""The tensor access class.
|
"""The tensor access class.
|
||||||
|
|
Loading…
Reference in New Issue