forked from OSchip/llvm-project
[mlir][sparse][taco] Reorder a class.
Define IndexExpr before IndexVar. This is to prepare for the next change to support the use of index values in tensor expressions. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D121649
This commit is contained in:
parent
1bf4bbc492
commit
3a4229696d
|
@ -365,6 +365,365 @@ def _make_format(formats: List[ModeFormat],
|
|||
return Format(ModeFormatPack(formats), ModeOrdering(ordering))
|
||||
|
||||
|
||||
class IndexExpr(abc.ABC):
|
||||
"""The index notation base class.
|
||||
|
||||
We support the TACO API index_expression class with an alias of this class.
|
||||
"""
|
||||
|
||||
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
|
||||
"""Verifies the RHS operand and returns a binary expression.
|
||||
|
||||
Args:
|
||||
rhs: The RHS of the binary operation, which could be any Python object
|
||||
from user inputs.
|
||||
op: A _BinaryOp object representing the binary operator.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
if not isinstance(rhs, IndexExpr):
|
||||
raise ValueError(f"Expected IndexExpr: {rhs}")
|
||||
return _BinaryExpr(op, self, rhs)
|
||||
|
||||
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
|
||||
"""Build a unary expression.
|
||||
|
||||
Args:
|
||||
op: A _UnaryOp object representing the unary operation.
|
||||
"""
|
||||
return _UnaryExpr(op, self)
|
||||
|
||||
def __add__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator +.
|
||||
|
||||
Args:
|
||||
rhs: The value being added, which could be any Python object from user
|
||||
inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.add)
|
||||
|
||||
def __mul__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator *.
|
||||
|
||||
Args:
|
||||
rhs: The value being multiplied, which could be any Python object from
|
||||
user inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.mul)
|
||||
|
||||
def __abs__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator abs.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.abs)
|
||||
|
||||
def __neg__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator neg.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.neg)
|
||||
|
||||
def __sub__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator -.
|
||||
|
||||
Args:
|
||||
rhs: The value being subtracted, which could be any Python object from
|
||||
user inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.sub)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _visit(self,
|
||||
func: _ExprVisitor,
|
||||
args,
|
||||
*,
|
||||
leaf_checker: _SubtreeLeafChecker = None) -> None:
|
||||
"""A post-order visitor.
|
||||
|
||||
Args:
|
||||
func: A callable applied to each node in the expression tree.
|
||||
args: The variable-length arguments passed to the callable. These
|
||||
arguments are grouped as an iterable and will be unpacked before passing
|
||||
to the callable. This is to enable the keyword argument only syntax
|
||||
after this argument.
|
||||
leaf_checker: A callable object to identify nodes that should be treated
|
||||
as leaf nodes to support partial tree visiting.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _emit_expression(
|
||||
self,
|
||||
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
|
||||
expr_to_info: _ExprInfoDict,
|
||||
) -> lang.ScalarExpression:
|
||||
"""Emits MLIR for the expression tree.
|
||||
|
||||
Args:
|
||||
expr_to_opnd: A dictionary for looking up structured op input operands for
|
||||
the input nodes of the structured op.
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
|
||||
Returns:
|
||||
A linalg dialect ScalarExpression for the expression.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def dtype(self) -> DType:
|
||||
"""Returns the data type for the result of the expression."""
|
||||
pass
|
||||
|
||||
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
|
||||
"""Emits a structured op in the linalg dialect for the expression tree.
|
||||
|
||||
We define a DefineOpcallable in the domain specific language for the linalg
|
||||
dialect and execute the callable to generate the structured op. Self is the
|
||||
root of the expression tree for the structured op.
|
||||
|
||||
Args:
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
"""
|
||||
op_info = expr_to_info[self].structop_info
|
||||
op_name = op_info.dst_name
|
||||
op_def = lang.LinalgOpDef(name=op_name)
|
||||
op_callable = lang.DefinedOpCallable(op_name, op_def)
|
||||
|
||||
# Collect the input expression nodes for the structured op.
|
||||
expr_inputs = []
|
||||
self._visit(
|
||||
_gather_structured_op_input,
|
||||
(self, expr_to_info, expr_inputs),
|
||||
leaf_checker=_is_structured_op_leaf,
|
||||
)
|
||||
|
||||
# Create a linalg structured op operand for each input expression node and
|
||||
# build a dictionary for looking up the information.
|
||||
expr_to_input_opnd = {
|
||||
e: _emit_structured_op_input(e, expr_to_info, op_def)
|
||||
for e in expr_inputs
|
||||
}
|
||||
|
||||
# Emit the expression tree, which produces the value assigned to the
|
||||
# destination tensor.
|
||||
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
|
||||
# Emit the structured op representation for the destination tensor.
|
||||
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
|
||||
lang.OperandKind.OUTPUT_TENSOR)
|
||||
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
||||
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
|
||||
|
||||
expr_info = expr_to_info[self]
|
||||
# If the structured op reduces some indices, explicitly represent the
|
||||
# reduction. This is done by generating a ReduceFn for the dimensions being
|
||||
# reduced in the linalg dialect and calling the function with the value
|
||||
# being reduced. We only support add reduction currently.
|
||||
if expr_info.reduce_indices:
|
||||
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
|
||||
value = lang.ReduceFn.add[reduce_dims](value)
|
||||
|
||||
# Emit the assignment as a comprehension in the linalg dialect.
|
||||
comp = lang.Comprehension((dst_use, value))
|
||||
op_def.comprehensions.append(comp)
|
||||
|
||||
# The structured op in the linalg dialect requires an explicit
|
||||
# initialization for the destination tensor. Emit MLIR to initialize the
|
||||
# destination tensor.
|
||||
init = op_info.emit_tensor_init()
|
||||
|
||||
# Collect MLIR values for the linalg input operands, with the assumption
|
||||
# that dictionary preserves the insertion order.
|
||||
args = [
|
||||
expr_to_info[expr].mlir_value
|
||||
for expr, opnd in expr_to_input_opnd.items()
|
||||
]
|
||||
# Execute the DefineOpcallable object for the linalg dialect operation to
|
||||
# emit MLIR for the linalg structured op.
|
||||
expr_info.mlir_value = op_callable(*args, outs=[init])
|
||||
|
||||
def _identify_structured_ops(
|
||||
self,
|
||||
expr_to_info: _ExprInfoDict,
|
||||
dst: "Tensor",
|
||||
dst_indices: Tuple["IndexVar", ...],
|
||||
) -> List["IndexExpr"]:
|
||||
"""Returns expression nodes for the roots of the identified structured ops.
|
||||
|
||||
A structured op in the linalg dialect only supports reduction performed on
|
||||
the whole expression. If the expression tree contains reduction that are
|
||||
performed on part of the expression tree, the expression tree needs to be
|
||||
implemented with multiple structured ops. This routine identifies all the
|
||||
expression nodes that contain reduction as the root of structured ops in the
|
||||
linalg dialect.
|
||||
|
||||
Args:
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
dst: A destination Tensor that accepts the value of the expression tree.
|
||||
dst_indices: The indices used by the destination index expression.
|
||||
|
||||
Returns:
|
||||
An ordered list of IndexExpr for the root expressions of the structured
|
||||
ops, where child expressions go before parent expressions that use their
|
||||
results.
|
||||
"""
|
||||
reduce_indices = tuple(
|
||||
set(expr_to_info[self].src_indices) - set(dst_indices))
|
||||
for reduce_index in reduce_indices:
|
||||
_mark_structured_op_root(self, reduce_index, expr_to_info)
|
||||
|
||||
self._visit(_accumulate_reduce_indices, (expr_to_info,))
|
||||
structop_roots = []
|
||||
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
|
||||
|
||||
# Handle the root of the top level expression.
|
||||
if not structop_roots or structop_roots[-1] != self:
|
||||
# The top level expression is not a reduction. Add the top level
|
||||
# expression as a structured op root.
|
||||
structop_roots.append(self)
|
||||
|
||||
# Use user specified information for the destination tensor to build an
|
||||
# _StructOpInfo for the top level expression.
|
||||
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
|
||||
tuple(dst.shape),
|
||||
self.dtype(), dst.name,
|
||||
dst.format)
|
||||
|
||||
return structop_roots
|
||||
|
||||
def _validate_and_collect_expr_info(
|
||||
self,
|
||||
dst: "Tensor",
|
||||
dst_indices: Tuple["IndexVar", ...],
|
||||
) -> _ExprInfoDict:
|
||||
"""Propagates expression information for validation.
|
||||
|
||||
Propagates the indices used by child expression nodes to parent expression
|
||||
nodes. Also collects and validates the sizes for the dimensions
|
||||
corresponding to the indices.
|
||||
|
||||
Args:
|
||||
dst: A destination Tensor that accepts the value of the expression tree.
|
||||
dst_indices: The indices used by the destination index expression.
|
||||
|
||||
Raises:
|
||||
ValueError if there is any inconsistency in indices or dimensional
|
||||
values.
|
||||
|
||||
Returns:
|
||||
A dictionary of (IndexExpr, _ExprInfo).
|
||||
"""
|
||||
expr_to_info = {}
|
||||
# Validate the expression tree and construct expression information.
|
||||
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
|
||||
|
||||
# Validate the destination dimension information.
|
||||
info = expr_to_info[self]
|
||||
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
|
||||
for i, d, in zip(dst_indices, dst.shape):
|
||||
if i not in index_to_dim_info:
|
||||
raise ValueError("Destination IndexVar not used in the "
|
||||
f"source expression: {i}")
|
||||
else:
|
||||
if d != index_to_dim_info[i].dim:
|
||||
raise ValueError(f"Inconsistent destination dimension for {i}: "
|
||||
f"{d} vs {index_to_dim_info[i].dim}")
|
||||
|
||||
return expr_to_info
|
||||
|
||||
def _emit_assignment(
|
||||
self,
|
||||
module: ir.Module,
|
||||
dst: "Tensor",
|
||||
dst_indices: Tuple["IndexVar", ...],
|
||||
expr_to_info: _ExprInfoDict,
|
||||
input_accesses: List["Access"],
|
||||
) -> None:
|
||||
"""Emits an MLIR function for assigning the expression to a tensor."""
|
||||
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
|
||||
|
||||
# Build the kernel for the operations.
|
||||
with ir.InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
|
||||
def linalg_funcop(*args):
|
||||
# Set up the mapping from the Access nodes to their MLIR values.
|
||||
for e, mlir in zip(input_accesses, args):
|
||||
expr_to_info[e].mlir_value = mlir
|
||||
|
||||
# Emit structured ops in the linalg dialect to implement the assignment.
|
||||
for structop_root in self._identify_structured_ops(
|
||||
expr_to_info, dst, dst_indices):
|
||||
structop_root._emit_structured_op(expr_to_info)
|
||||
dst._record_stats(expr_to_info[structop_root].structop_info)
|
||||
|
||||
# The function returns the MLIR value of the root expression.
|
||||
return expr_to_info[self].mlir_value
|
||||
|
||||
linalg_funcop.func_op.attributes[
|
||||
"llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||
|
||||
def get_input_accesses(self) -> List["Access"]:
|
||||
"""Compute the list of input accesses for the expression."""
|
||||
input_accesses = []
|
||||
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
|
||||
return input_accesses
|
||||
|
||||
def compile(
|
||||
self,
|
||||
dst: "Tensor",
|
||||
dst_indices: Tuple["IndexVar", ...],
|
||||
) -> execution_engine.ExecutionEngine:
|
||||
"""Compiles the tensor assignment dst[dst_indices] = expression.
|
||||
|
||||
Args:
|
||||
dst: The destination tensor.
|
||||
dst_indices: The tuple of IndexVar used to access the destination tensor.
|
||||
|
||||
Returns:
|
||||
The execution engine for the tensor assignment.
|
||||
|
||||
Raises:
|
||||
ValueError: If the expression is not proper or not supported.
|
||||
"""
|
||||
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
|
||||
input_accesses = self.get_input_accesses()
|
||||
|
||||
# Build and compile the module to produce the execution engine.
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
self._emit_assignment(module, dst, dst_indices, expr_to_info,
|
||||
input_accesses)
|
||||
engine = utils.compile_and_build_engine(module)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
class _AtomicCounter:
|
||||
"""An atomic counter."""
|
||||
|
||||
|
@ -1203,364 +1562,6 @@ class _ExprInfo:
|
|||
self.acc_reduce_indices = self.acc_reduce_indices or set()
|
||||
|
||||
|
||||
class IndexExpr(abc.ABC):
|
||||
"""The index notation base class.
|
||||
|
||||
We support the TACO API index_expression class with an alias of this class.
|
||||
"""
|
||||
|
||||
def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
|
||||
"""Verifies the RHS operand and returns a binary expression.
|
||||
|
||||
Args:
|
||||
rhs: The RHS of the binary operation, which could be any Python object
|
||||
from user inputs.
|
||||
op: A _BinaryOp object representing the binary operator.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
if not isinstance(rhs, IndexExpr):
|
||||
raise ValueError(f"Expected IndexExpr: {rhs}")
|
||||
return _BinaryExpr(op, self, rhs)
|
||||
|
||||
def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
|
||||
"""Build a unary expression.
|
||||
|
||||
Args:
|
||||
op: A _UnaryOp object representing the unary operation.
|
||||
"""
|
||||
return _UnaryExpr(op, self)
|
||||
|
||||
def __add__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator +.
|
||||
|
||||
Args:
|
||||
rhs: The value being added, which could be any Python object from user
|
||||
inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.add)
|
||||
|
||||
def __mul__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator *.
|
||||
|
||||
Args:
|
||||
rhs: The value being multiplied, which could be any Python object from
|
||||
user inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.mul)
|
||||
|
||||
def __abs__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator abs.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.abs)
|
||||
|
||||
def __neg__(self) -> "_UnaryExpr":
|
||||
"""Defines the operator neg.
|
||||
|
||||
Returns:
|
||||
A _UnaryExpr object representing the operation.
|
||||
"""
|
||||
return self._build_unary_expr(operator.neg)
|
||||
|
||||
def __sub__(self, rhs) -> "_BinaryExpr":
|
||||
"""Defines the operator -.
|
||||
|
||||
Args:
|
||||
rhs: The value being subtracted, which could be any Python object from
|
||||
user inputs.
|
||||
|
||||
Returns:
|
||||
A _BinaryExpr object representing the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If rhs is not an IndexExpr.
|
||||
"""
|
||||
return self._verify_operand_and_build_expr(rhs, operator.sub)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _visit(self,
|
||||
func: _ExprVisitor,
|
||||
args,
|
||||
*,
|
||||
leaf_checker: _SubtreeLeafChecker = None) -> None:
|
||||
"""A post-order visitor.
|
||||
|
||||
Args:
|
||||
func: A callable applied to each node in the expression tree.
|
||||
args: The variable-length arguments passed to the callable. These
|
||||
arguments are grouped as an iterable and will be unpacked before passing
|
||||
to the callable. This is to enable the keyword argument only syntax
|
||||
after this argument.
|
||||
leaf_checker: A callable object to identify nodes that should be treated
|
||||
as leaf nodes to support partial tree visiting.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _emit_expression(
|
||||
self,
|
||||
expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
|
||||
expr_to_info: _ExprInfoDict,
|
||||
) -> lang.ScalarExpression:
|
||||
"""Emits MLIR for the expression tree.
|
||||
|
||||
Args:
|
||||
expr_to_opnd: A dictionary for looking up structured op input operands for
|
||||
the input nodes of the structured op.
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
|
||||
Returns:
|
||||
A linalg dialect ScalarExpression for the expression.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def dtype(self) -> DType:
|
||||
"""Returns the data type for the result of the expression."""
|
||||
pass
|
||||
|
||||
def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
|
||||
"""Emits a structured op in the linalg dialect for the expression tree.
|
||||
|
||||
We define a DefineOpcallable in the domain specific language for the linalg
|
||||
dialect and execute the callable to generate the structured op. Self is the
|
||||
root of the expression tree for the structured op.
|
||||
|
||||
Args:
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
"""
|
||||
op_info = expr_to_info[self].structop_info
|
||||
op_name = op_info.dst_name
|
||||
op_def = lang.LinalgOpDef(name=op_name)
|
||||
op_callable = lang.DefinedOpCallable(op_name, op_def)
|
||||
|
||||
# Collect the input expression nodes for the structured op.
|
||||
expr_inputs = []
|
||||
self._visit(
|
||||
_gather_structured_op_input,
|
||||
(self, expr_to_info, expr_inputs),
|
||||
leaf_checker=_is_structured_op_leaf,
|
||||
)
|
||||
|
||||
# Create a linalg structured op operand for each input expression node and
|
||||
# build a dictionary for looking up the information.
|
||||
expr_to_input_opnd = {
|
||||
e: _emit_structured_op_input(e, expr_to_info, op_def)
|
||||
for e in expr_inputs
|
||||
}
|
||||
|
||||
# Emit the expression tree, which produces the value assigned to the
|
||||
# destination tensor.
|
||||
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
|
||||
# Emit the structured op representation for the destination tensor.
|
||||
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
|
||||
lang.OperandKind.OUTPUT_TENSOR)
|
||||
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
||||
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
|
||||
|
||||
expr_info = expr_to_info[self]
|
||||
# If the structured op reduces some indices, explicitly represent the
|
||||
# reduction. This is done by generating a ReduceFn for the dimensions being
|
||||
# reduced in the linalg dialect and calling the function with the value
|
||||
# being reduced. We only support add reduction currently.
|
||||
if expr_info.reduce_indices:
|
||||
reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
|
||||
value = lang.ReduceFn.add[reduce_dims](value)
|
||||
|
||||
# Emit the assignment as a comprehension in the linalg dialect.
|
||||
comp = lang.Comprehension((dst_use, value))
|
||||
op_def.comprehensions.append(comp)
|
||||
|
||||
# The structured op in the linalg dialect requires an explicit
|
||||
# initialization for the destination tensor. Emit MLIR to initialize the
|
||||
# destination tensor.
|
||||
init = op_info.emit_tensor_init()
|
||||
|
||||
# Collect MLIR values for the linalg input operands, with the assumption
|
||||
# that dictionary preserves the insertion order.
|
||||
args = [
|
||||
expr_to_info[expr].mlir_value
|
||||
for expr, opnd in expr_to_input_opnd.items()
|
||||
]
|
||||
# Execute the DefineOpcallable object for the linalg dialect operation to
|
||||
# emit MLIR for the linalg structured op.
|
||||
expr_info.mlir_value = op_callable(*args, outs=[init])
|
||||
|
||||
def _identify_structured_ops(
|
||||
self,
|
||||
expr_to_info: _ExprInfoDict,
|
||||
dst: Tensor,
|
||||
dst_indices: Tuple[IndexVar, ...],
|
||||
) -> List["IndexExpr"]:
|
||||
"""Returns expression nodes for the roots of the identified structured ops.
|
||||
|
||||
A structured op in the linalg dialect only supports reduction performed on
|
||||
the whole expression. If the expression tree contains reduction that are
|
||||
performed on part of the expression tree, the expression tree needs to be
|
||||
implemented with multiple structured ops. This routine identifies all the
|
||||
expression nodes that contain reduction as the root of structured ops in the
|
||||
linalg dialect.
|
||||
|
||||
Args:
|
||||
expr_to_info: A dictionary for looking up code generation information for
|
||||
expressions.
|
||||
dst: A destination Tensor that accepts the value of the expression tree.
|
||||
dst_indices: The indices used by the destination index expression.
|
||||
|
||||
Returns:
|
||||
An ordered list of IndexExpr for the root expressions of the structured
|
||||
ops, where child expressions go before parent expressions that use their
|
||||
results.
|
||||
"""
|
||||
reduce_indices = tuple(
|
||||
set(expr_to_info[self].src_indices) - set(dst_indices))
|
||||
for reduce_index in reduce_indices:
|
||||
_mark_structured_op_root(self, reduce_index, expr_to_info)
|
||||
|
||||
self._visit(_accumulate_reduce_indices, (expr_to_info,))
|
||||
structop_roots = []
|
||||
self._visit(_gather_structured_op, (expr_to_info, structop_roots))
|
||||
|
||||
# Handle the root of the top level expression.
|
||||
if not structop_roots or structop_roots[-1] != self:
|
||||
# The top level expression is not a reduction. Add the top level
|
||||
# expression as a structured op root.
|
||||
structop_roots.append(self)
|
||||
|
||||
# Use user specified information for the destination tensor to build an
|
||||
# _StructOpInfo for the top level expression.
|
||||
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
|
||||
tuple(dst.shape),
|
||||
self.dtype(), dst.name,
|
||||
dst.format)
|
||||
|
||||
return structop_roots
|
||||
|
||||
def _validate_and_collect_expr_info(
|
||||
self,
|
||||
dst: Tensor,
|
||||
dst_indices: Tuple[IndexVar, ...],
|
||||
) -> _ExprInfoDict:
|
||||
"""Propagates expression information for validation.
|
||||
|
||||
Propagates the indices used by child expression nodes to parent expression
|
||||
nodes. Also collects and validates the sizes for the dimensions
|
||||
corresponding to the indices.
|
||||
|
||||
Args:
|
||||
dst: A destination Tensor that accepts the value of the expression tree.
|
||||
dst_indices: The indices used by the destination index expression.
|
||||
|
||||
Raises:
|
||||
ValueError if there is any inconsistency in indices or dimensional
|
||||
values.
|
||||
|
||||
Returns:
|
||||
A dictionary of (IndexExpr, _ExprInfo).
|
||||
"""
|
||||
expr_to_info = {}
|
||||
# Validate the expression tree and construct expression information.
|
||||
self._visit(_validate_and_collect_expr_info, (expr_to_info,))
|
||||
|
||||
# Validate the destination dimension information.
|
||||
info = expr_to_info[self]
|
||||
index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
|
||||
for i, d, in zip(dst_indices, dst.shape):
|
||||
if i not in index_to_dim_info:
|
||||
raise ValueError("Destination IndexVar not used in the "
|
||||
f"source expression: {i}")
|
||||
else:
|
||||
if d != index_to_dim_info[i].dim:
|
||||
raise ValueError(f"Inconsistent destination dimension for {i}: "
|
||||
f"{d} vs {index_to_dim_info[i].dim}")
|
||||
|
||||
return expr_to_info
|
||||
|
||||
def _emit_assignment(
|
||||
self,
|
||||
module: ir.Module,
|
||||
dst: Tensor,
|
||||
dst_indices: Tuple[IndexVar, ...],
|
||||
expr_to_info: _ExprInfoDict,
|
||||
input_accesses: List["Access"],
|
||||
) -> None:
|
||||
"""Emits an MLIR function for assigning the expression to a tensor."""
|
||||
input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
|
||||
|
||||
# Build the kernel for the operations.
|
||||
with ir.InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
|
||||
def linalg_funcop(*args):
|
||||
# Set up the mapping from the Access nodes to their MLIR values.
|
||||
for e, mlir in zip(input_accesses, args):
|
||||
expr_to_info[e].mlir_value = mlir
|
||||
|
||||
# Emit structured ops in the linalg dialect to implement the assignment.
|
||||
for structop_root in self._identify_structured_ops(
|
||||
expr_to_info, dst, dst_indices):
|
||||
structop_root._emit_structured_op(expr_to_info)
|
||||
dst._record_stats(expr_to_info[structop_root].structop_info)
|
||||
|
||||
# The function returns the MLIR value of the root expression.
|
||||
return expr_to_info[self].mlir_value
|
||||
|
||||
linalg_funcop.func_op.attributes[
|
||||
"llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||
|
||||
def get_input_accesses(self) -> List["Access"]:
|
||||
"""Compute the list of input accesses for the expression."""
|
||||
input_accesses = []
|
||||
self._visit(_gather_input_accesses_index_vars, (input_accesses,))
|
||||
return input_accesses
|
||||
|
||||
def compile(
|
||||
self,
|
||||
dst: Tensor,
|
||||
dst_indices: Tuple[IndexVar, ...],
|
||||
) -> execution_engine.ExecutionEngine:
|
||||
"""Compiles the tensor assignment dst[dst_indices] = expression.
|
||||
|
||||
Args:
|
||||
dst: The destination tensor.
|
||||
dst_indices: The tuple of IndexVar used to access the destination tensor.
|
||||
|
||||
Returns:
|
||||
The execution engine for the tensor assignment.
|
||||
|
||||
Raises:
|
||||
ValueError: If the expression is not proper or not supported.
|
||||
"""
|
||||
expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
|
||||
input_accesses = self.get_input_accesses()
|
||||
|
||||
# Build and compile the module to produce the execution engine.
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
self._emit_assignment(module, dst, dst_indices, expr_to_info,
|
||||
input_accesses)
|
||||
engine = utils.compile_and_build_engine(module)
|
||||
|
||||
return engine
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Access(IndexExpr):
|
||||
"""The tensor access class.
|
||||
|
|
Loading…
Reference in New Issue