From 3a4229696df39665c260d0d94f10ef2ef23d3370 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 14 Mar 2022 15:43:12 -0700 Subject: [PATCH] [mlir][sparse][taco] Reorder a class. Define IndexExpr before IndexVar. This is to prepare for the next change to support the use of index values in tensor expressions. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D121649 --- .../SparseTensor/taco/tools/mlir_pytaco.py | 717 +++++++++--------- 1 file changed, 359 insertions(+), 358 deletions(-) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py index 2dce759bfca0..b0cb21694483 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -365,6 +365,365 @@ def _make_format(formats: List[ModeFormat], return Format(ModeFormatPack(formats), ModeOrdering(ordering)) +class IndexExpr(abc.ABC): + """The index notation base class. + + We support the TACO API index_expression class with an alias of this class. + """ + + def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr": + """Verifies the RHS operand and returns a binary expression. + + Args: + rhs: The RHS of the binary operation, which could be any Python object + from user inputs. + op: A _BinaryOp object representing the binary operator. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + if not isinstance(rhs, IndexExpr): + raise ValueError(f"Expected IndexExpr: {rhs}") + return _BinaryExpr(op, self, rhs) + + def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": + """Build a unary expression. + + Args: + op: A _UnaryOp object representing the unary operation. + """ + return _UnaryExpr(op, self) + + def __add__(self, rhs) -> "_BinaryExpr": + """Defines the operator +. + + Args: + rhs: The value being added, which could be any Python object from user + inputs. 
+ + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + return self._verify_operand_and_build_expr(rhs, operator.add) + + def __mul__(self, rhs) -> "_BinaryExpr": + """Defines the operator *. + + Args: + rhs: The value being multiplied, which could be any Python object from + user inputs. + + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + return self._verify_operand_and_build_expr(rhs, operator.mul) + + def __abs__(self) -> "_UnaryExpr": + """Defines the operator abs. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.abs) + + def __neg__(self) -> "_UnaryExpr": + """Defines the operator neg. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.neg) + + def __sub__(self, rhs) -> "_BinaryExpr": + """Defines the operator -. + + Args: + rhs: The value being subtracted, which could be any Python object from + user inputs. + + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + return self._verify_operand_and_build_expr(rhs, operator.sub) + + @abc.abstractmethod + def _visit(self, + func: _ExprVisitor, + args, + *, + leaf_checker: _SubtreeLeafChecker = None) -> None: + """A post-order visitor. + + Args: + func: A callable applied to each node in the expression tree. + args: The variable-length arguments passed to the callable. These + arguments are grouped as an iterable and will be unpacked before passing + to the callable. This is to enable the keyword argument only syntax + after this argument. + leaf_checker: A callable object to identify nodes that should be treated + as leaf nodes to support partial tree visiting. 
+ """ + pass + + @abc.abstractmethod + def _emit_expression( + self, + expr_to_opnd: Dict["IndexExpr", lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits MLIR for the expression tree. + + Args: + expr_to_opnd: A dictionary for looking up structured op input operands for + the input nodes of the structured op. + expr_to_info: A dictionary for looking up code generation information for + expressions. + + Returns: + A linalg dialect ScalarExpression for the expression. + """ + pass + + @abc.abstractmethod + def dtype(self) -> DType: + """Returns the data type for the result of the expression.""" + pass + + def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None: + """Emits a structured op in the linalg dialect for the expression tree. + + We define a DefinedOpCallable in the domain specific language for the linalg + dialect and execute the callable to generate the structured op. Self is the + root of the expression tree for the structured op. + + Args: + expr_to_info: A dictionary for looking up code generation information for + expressions. + """ + op_info = expr_to_info[self].structop_info + op_name = op_info.dst_name + op_def = lang.LinalgOpDef(name=op_name) + op_callable = lang.DefinedOpCallable(op_name, op_def) + + # Collect the input expression nodes for the structured op. + expr_inputs = [] + self._visit( + _gather_structured_op_input, + (self, expr_to_info, expr_inputs), + leaf_checker=_is_structured_op_leaf, + ) + + # Create a linalg structured op operand for each input expression node and + # build a dictionary for looking up the information. + expr_to_input_opnd = { + e: _emit_structured_op_input(e, expr_to_info, op_def) + for e in expr_inputs + } + + # Emit the expression tree, which produces the value assigned to the + # destination tensor. + value = self._emit_expression(expr_to_input_opnd, expr_to_info) + # Emit the structured op representation for the destination tensor.
+ dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name, + lang.OperandKind.OUTPUT_TENSOR) + dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) + dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) + + expr_info = expr_to_info[self] + # If the structured op reduces some indices, explicitly represent the + # reduction. This is done by generating a ReduceFn for the dimensions being + # reduced in the linalg dialect and calling the function with the value + # being reduced. We only support add reduction currently. + if expr_info.reduce_indices: + reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices) + value = lang.ReduceFn.add[reduce_dims](value) + + # Emit the assignment as a comprehension in the linalg dialect. + comp = lang.Comprehension((dst_use, value)) + op_def.comprehensions.append(comp) + + # The structured op in the linalg dialect requires an explicit + # initialization for the destination tensor. Emit MLIR to initialize the + # destination tensor. + init = op_info.emit_tensor_init() + + # Collect MLIR values for the linalg input operands, with the assumption + # that dictionary preserves the insertion order. + args = [ + expr_to_info[expr].mlir_value + for expr, opnd in expr_to_input_opnd.items() + ] + # Execute the DefinedOpCallable object for the linalg dialect operation to + # emit MLIR for the linalg structured op. + expr_info.mlir_value = op_callable(*args, outs=[init]) + + def _identify_structured_ops( + self, + expr_to_info: _ExprInfoDict, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> List["IndexExpr"]: + """Returns expression nodes for the roots of the identified structured ops. + + A structured op in the linalg dialect only supports reduction performed on + the whole expression. If the expression tree contains reductions that are + performed on part of the expression tree, the expression tree needs to be + implemented with multiple structured ops.
This routine identifies all the + expression nodes that contain reduction as the root of structured ops in the + linalg dialect. + + Args: + expr_to_info: A dictionary for looking up code generation information for + expressions. + dst: A destination Tensor that accepts the value of the expression tree. + dst_indices: The indices used by the destination index expression. + + Returns: + An ordered list of IndexExpr for the root expressions of the structured + ops, where child expressions go before parent expressions that use their + results. + """ + reduce_indices = tuple( + set(expr_to_info[self].src_indices) - set(dst_indices)) + for reduce_index in reduce_indices: + _mark_structured_op_root(self, reduce_index, expr_to_info) + + self._visit(_accumulate_reduce_indices, (expr_to_info,)) + structop_roots = [] + self._visit(_gather_structured_op, (expr_to_info, structop_roots)) + + # Handle the root of the top level expression. + if not structop_roots or structop_roots[-1] != self: + # The top level expression is not a reduction. Add the top level + # expression as a structured op root. + structop_roots.append(self) + + # Use user specified information for the destination tensor to build an + # _StructOpInfo for the top level expression. + expr_to_info[self].structop_info = _StructOpInfo(dst_indices, + tuple(dst.shape), + self.dtype(), dst.name, + dst.format) + + return structop_roots + + def _validate_and_collect_expr_info( + self, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> _ExprInfoDict: + """Propagates expression information for validation. + + Propagates the indices used by child expression nodes to parent expression + nodes. Also collects and validates the sizes for the dimensions + corresponding to the indices. + + Args: + dst: A destination Tensor that accepts the value of the expression tree. + dst_indices: The indices used by the destination index expression. 
+ + Raises: + ValueError if there is any inconsistency in indices or dimensional + values. + + Returns: + A dictionary of (IndexExpr, _ExprInfo). + """ + expr_to_info = {} + # Validate the expression tree and construct expression information. + self._visit(_validate_and_collect_expr_info, (expr_to_info,)) + + # Validate the destination dimension information. + info = expr_to_info[self] + index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)} + for i, d, in zip(dst_indices, dst.shape): + if i not in index_to_dim_info: + raise ValueError("Destination IndexVar not used in the " + f"source expression: {i}") + else: + if d != index_to_dim_info[i].dim: + raise ValueError(f"Inconsistent destination dimension for {i}: " + f"{d} vs {index_to_dim_info[i].dim}") + + return expr_to_info + + def _emit_assignment( + self, + module: ir.Module, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + expr_to_info: _ExprInfoDict, + input_accesses: List["Access"], + ) -> None: + """Emits an MLIR function for assigning the expression to a tensor.""" + input_types = [a.tensor.mlir_tensor_type() for a in input_accesses] + + # Build the kernel for the operations. + with ir.InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME) + def linalg_funcop(*args): + # Set up the mapping from the Access nodes to their MLIR values. + for e, mlir in zip(input_accesses, args): + expr_to_info[e].mlir_value = mlir + + # Emit structured ops in the linalg dialect to implement the assignment. + for structop_root in self._identify_structured_ops( + expr_to_info, dst, dst_indices): + structop_root._emit_structured_op(expr_to_info) + dst._record_stats(expr_to_info[structop_root].structop_info) + + # The function returns the MLIR value of the root expression. 
+ return expr_to_info[self].mlir_value + + linalg_funcop.func_op.attributes[ + "llvm.emit_c_interface"] = ir.UnitAttr.get() + + def get_input_accesses(self) -> List["Access"]: + """Compute the list of input accesses for the expression.""" + input_accesses = [] + self._visit(_gather_input_accesses_index_vars, (input_accesses,)) + return input_accesses + + def compile( + self, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> execution_engine.ExecutionEngine: + """Compiles the tensor assignment dst[dst_indices] = expression. + + Args: + dst: The destination tensor. + dst_indices: The tuple of IndexVar used to access the destination tensor. + + Returns: + The execution engine for the tensor assignment. + + Raises: + ValueError: If the expression is not proper or not supported. + """ + expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices) + input_accesses = self.get_input_accesses() + + # Build and compile the module to produce the execution engine. + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + self._emit_assignment(module, dst, dst_indices, expr_to_info, + input_accesses) + engine = utils.compile_and_build_engine(module) + + return engine + + class _AtomicCounter: """An atomic counter.""" @@ -1203,364 +1562,6 @@ class _ExprInfo: self.acc_reduce_indices = self.acc_reduce_indices or set() -class IndexExpr(abc.ABC): - """The index notation base class. - - We support the TACO API index_expression class with an alias of this class. - """ - - def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr": - """Verifies the RHS operand and returns a binary expression. - - Args: - rhs: The RHS of the binary operation, which could be any Python object - from user inputs. - op: A _BinaryOp object representing the binary operator. - - Raises: - ValueError: If rhs is not an IndexExpr. 
- """ - if not isinstance(rhs, IndexExpr): - raise ValueError(f"Expected IndexExpr: {rhs}") - return _BinaryExpr(op, self, rhs) - - def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": - """Build a unary expression. - - Args: - op: A _UnaryOp object representing the unary operation. - """ - return _UnaryExpr(op, self) - - def __add__(self, rhs) -> "_BinaryExpr": - """Defines the operator +. - - Args: - rhs: The value being added, which could be any Python object from user - inputs. - - Returns: - A _BinaryExpr object representing the operation. - - Raises: - ValueError: If rhs is not an IndexExpr. - """ - return self._verify_operand_and_build_expr(rhs, operator.add) - - def __mul__(self, rhs) -> "_BinaryExpr": - """Defines the operator *. - - Args: - rhs: The value being multiplied, which could be any Python object from - user inputs. - - Returns: - A _BinaryExpr object representing the operation. - - Raises: - ValueError: If rhs is not an IndexExpr. - """ - return self._verify_operand_and_build_expr(rhs, operator.mul) - - def __abs__(self) -> "_UnaryExpr": - """Defines the operator abs. - - Returns: - A _UnaryExpr object representing the operation. - """ - return self._build_unary_expr(operator.abs) - - def __neg__(self) -> "_UnaryExpr": - """Defines the operator neg. - - Returns: - A _UnaryExpr object representing the operation. - """ - return self._build_unary_expr(operator.neg) - - def __sub__(self, rhs) -> "_BinaryExpr": - """Defines the operator -. - - Args: - rhs: The value being subtracted, which could be any Python object from - user inputs. - - Returns: - A _BinaryExpr object representing the operation. - - Raises: - ValueError: If rhs is not an IndexExpr. - """ - return self._verify_operand_and_build_expr(rhs, operator.sub) - - @abc.abstractmethod - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - """A post-order visitor. 
- - Args: - func: A callable applied to each node in the expression tree. - args: The variable-length arguments passed to the callable. These - arguments are grouped as an iterable and will be unpacked before passing - to the callable. This is to enable the keyword argument only syntax - after this argument. - leaf_checker: A callable object to identify nodes that should be treated - as leaf nodes to support partial tree visiting. - """ - pass - - @abc.abstractmethod - def _emit_expression( - self, - expr_to_opnd: Dict["IndexExpr", lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits MLIR for the expression tree. - - Args: - expr_to_opnd: A dictionary for looking up structured op input operands for - the input nodes of the structured op. - expr_to_info: A dictionary for looking up code generation information for - expressions. - - Returns: - A linalg dialect ScalarExpression for the expression. - """ - pass - - @abc.abstractmethod - def dtype(self) -> DType: - """Returns the data type for the result of the expression.""" - pass - - def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None: - """Emits a structured op in the linalg dialect for the expression tree. - - We define a DefineOpcallable in the domain specific language for the linalg - dialect and execute the callable to generate the structured op. Self is the - root of the expression tree for the structured op. - - Args: - expr_to_info: A dictionary for looking up code generation information for - expressions. - """ - op_info = expr_to_info[self].structop_info - op_name = op_info.dst_name - op_def = lang.LinalgOpDef(name=op_name) - op_callable = lang.DefinedOpCallable(op_name, op_def) - - # Collect the input expression nodes for the structured op. 
- expr_inputs = [] - self._visit( - _gather_structured_op_input, - (self, expr_to_info, expr_inputs), - leaf_checker=_is_structured_op_leaf, - ) - - # Create a linalg structured op operand for each input expression node and - # build a dictionary for looking up the information. - expr_to_input_opnd = { - e: _emit_structured_op_input(e, expr_to_info, op_def) - for e in expr_inputs - } - - # Emit the expression tree, which produces the value assigned to the - # destination tensor. - value = self._emit_expression(expr_to_input_opnd, expr_to_info) - # Emit the structured op representation for the destination tensor. - dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name, - lang.OperandKind.OUTPUT_TENSOR) - dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) - dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) - - expr_info = expr_to_info[self] - # If the structured op reduces some indices, explicitly represent the - # reduction. This is done by generating a ReduceFn for the dimensions being - # reduced in the linalg dialect and calling the function with the value - # being reduced. We only support add reduction currently. - if expr_info.reduce_indices: - reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices) - value = lang.ReduceFn.add[reduce_dims](value) - - # Emit the assignment as a comprehension in the linalg dialect. - comp = lang.Comprehension((dst_use, value)) - op_def.comprehensions.append(comp) - - # The structured op in the linalg dialect requires an explicit - # initialization for the destination tensor. Emit MLIR to initialize the - # destination tensor. - init = op_info.emit_tensor_init() - - # Collect MLIR values for the linalg input operands, with the assumption - # that dictionary preserves the insertion order. 
- args = [ - expr_to_info[expr].mlir_value - for expr, opnd in expr_to_input_opnd.items() - ] - # Execute the DefineOpcallable object for the linalg dialect operation to - # emit MLIR for the linalg structured op. - expr_info.mlir_value = op_callable(*args, outs=[init]) - - def _identify_structured_ops( - self, - expr_to_info: _ExprInfoDict, - dst: Tensor, - dst_indices: Tuple[IndexVar, ...], - ) -> List["IndexExpr"]: - """Returns expression nodes for the roots of the identified structured ops. - - A structured op in the linalg dialect only supports reduction performed on - the whole expression. If the expression tree contains reduction that are - performed on part of the expression tree, the expression tree needs to be - implemented with multiple structured ops. This routine identifies all the - expression nodes that contain reduction as the root of structured ops in the - linalg dialect. - - Args: - expr_to_info: A dictionary for looking up code generation information for - expressions. - dst: A destination Tensor that accepts the value of the expression tree. - dst_indices: The indices used by the destination index expression. - - Returns: - An ordered list of IndexExpr for the root expressions of the structured - ops, where child expressions go before parent expressions that use their - results. - """ - reduce_indices = tuple( - set(expr_to_info[self].src_indices) - set(dst_indices)) - for reduce_index in reduce_indices: - _mark_structured_op_root(self, reduce_index, expr_to_info) - - self._visit(_accumulate_reduce_indices, (expr_to_info,)) - structop_roots = [] - self._visit(_gather_structured_op, (expr_to_info, structop_roots)) - - # Handle the root of the top level expression. - if not structop_roots or structop_roots[-1] != self: - # The top level expression is not a reduction. Add the top level - # expression as a structured op root. 
- structop_roots.append(self) - - # Use user specified information for the destination tensor to build an - # _StructOpInfo for the top level expression. - expr_to_info[self].structop_info = _StructOpInfo(dst_indices, - tuple(dst.shape), - self.dtype(), dst.name, - dst.format) - - return structop_roots - - def _validate_and_collect_expr_info( - self, - dst: Tensor, - dst_indices: Tuple[IndexVar, ...], - ) -> _ExprInfoDict: - """Propagates expression information for validation. - - Propagates the indices used by child expression nodes to parent expression - nodes. Also collects and validates the sizes for the dimensions - corresponding to the indices. - - Args: - dst: A destination Tensor that accepts the value of the expression tree. - dst_indices: The indices used by the destination index expression. - - Raises: - ValueError if there is any inconsistency in indices or dimensional - values. - - Returns: - A dictionary of (IndexExpr, _ExprInfo). - """ - expr_to_info = {} - # Validate the expression tree and construct expression information. - self._visit(_validate_and_collect_expr_info, (expr_to_info,)) - - # Validate the destination dimension information. - info = expr_to_info[self] - index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)} - for i, d, in zip(dst_indices, dst.shape): - if i not in index_to_dim_info: - raise ValueError("Destination IndexVar not used in the " - f"source expression: {i}") - else: - if d != index_to_dim_info[i].dim: - raise ValueError(f"Inconsistent destination dimension for {i}: " - f"{d} vs {index_to_dim_info[i].dim}") - - return expr_to_info - - def _emit_assignment( - self, - module: ir.Module, - dst: Tensor, - dst_indices: Tuple[IndexVar, ...], - expr_to_info: _ExprInfoDict, - input_accesses: List["Access"], - ) -> None: - """Emits an MLIR function for assigning the expression to a tensor.""" - input_types = [a.tensor.mlir_tensor_type() for a in input_accesses] - - # Build the kernel for the operations. 
- with ir.InsertionPoint(module.body): - - @builtin.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME) - def linalg_funcop(*args): - # Set up the mapping from the Access nodes to their MLIR values. - for e, mlir in zip(input_accesses, args): - expr_to_info[e].mlir_value = mlir - - # Emit structured ops in the linalg dialect to implement the assignment. - for structop_root in self._identify_structured_ops( - expr_to_info, dst, dst_indices): - structop_root._emit_structured_op(expr_to_info) - dst._record_stats(expr_to_info[structop_root].structop_info) - - # The function returns the MLIR value of the root expression. - return expr_to_info[self].mlir_value - - linalg_funcop.func_op.attributes[ - "llvm.emit_c_interface"] = ir.UnitAttr.get() - - def get_input_accesses(self) -> List["Access"]: - """Compute the list of input accesses for the expression.""" - input_accesses = [] - self._visit(_gather_input_accesses_index_vars, (input_accesses,)) - return input_accesses - - def compile( - self, - dst: Tensor, - dst_indices: Tuple[IndexVar, ...], - ) -> execution_engine.ExecutionEngine: - """Compiles the tensor assignment dst[dst_indices] = expression. - - Args: - dst: The destination tensor. - dst_indices: The tuple of IndexVar used to access the destination tensor. - - Returns: - The execution engine for the tensor assignment. - - Raises: - ValueError: If the expression is not proper or not supported. - """ - expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices) - input_accesses = self.get_input_accesses() - - # Build and compile the module to produce the execution engine. - with ir.Context(), ir.Location.unknown(): - module = ir.Module.create() - self._emit_assignment(module, dst, dst_indices, expr_to_info, - input_accesses) - engine = utils.compile_and_build_engine(module) - - return engine - @dataclasses.dataclass(frozen=True) class Access(IndexExpr): """The tensor access class.