forked from OSchip/llvm-project
[mlir] provide Python bindings for the Transform dialect
Python bindings for extensions of the Transform dialect are defined in separate Python source files that can be imported on-demand, i.e., that are not imported with the "main" transform dialect. This requires a minor addition to the ODS-based bindings generator. This approach is consistent with the current model for downstream projects that are expected to bundle MLIR Python bindings: such projects can include their custom extensions into the bundle similarly to how they include their dialects. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D126208
This commit is contained in:
parent
cc6c159203
commit
3f71765a71
|
@ -355,6 +355,61 @@ function(declare_mlir_dialect_python_bindings)
|
||||||
endif()
|
endif()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
# Function: declare_mlir_dialect_extension_python_bindings
|
||||||
|
# Helper to generate source groups for dialect extensions, including both
|
||||||
|
# static source files and a TD_FILE to generate wrappers.
|
||||||
|
#
|
||||||
|
# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
|
||||||
|
#
|
||||||
|
# Arguments:
|
||||||
|
# ROOT_DIR: Same as for declare_mlir_python_sources().
|
||||||
|
# ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
|
||||||
|
# for the subordinate source groups are derived from this.
|
||||||
|
# TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
|
||||||
|
# DIALECT_NAME: Python name of the dialect.
|
||||||
|
# EXTENSION_NAME: Python name of the dialect extension.
|
||||||
|
# SOURCES: Same as declare_mlir_python_sources().
|
||||||
|
# SOURCES_GLOB: Same as declare_mlir_python_sources().
|
||||||
|
# DEPENDS: Additional dependency targets.
|
||||||
|
function(declare_mlir_dialect_extension_python_bindings)
|
||||||
|
cmake_parse_arguments(ARG
|
||||||
|
""
|
||||||
|
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
|
||||||
|
"SOURCES;SOURCES_GLOB;DEPENDS"
|
||||||
|
${ARGN})
|
||||||
|
# Source files.
|
||||||
|
set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
|
||||||
|
declare_mlir_python_sources(${_extension_target}
|
||||||
|
ROOT_DIR "${ARG_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
|
||||||
|
SOURCES "${ARG_SOURCES}"
|
||||||
|
SOURCES_GLOB "${ARG_SOURCES_GLOB}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tablegen
|
||||||
|
if(ARG_TD_FILE)
|
||||||
|
set(tblgen_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}.tablegen")
|
||||||
|
set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
|
||||||
|
get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
|
||||||
|
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
|
||||||
|
set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||||
|
mlir_tablegen("${output_filename}" -gen-python-op-bindings
|
||||||
|
-bind-dialect=${ARG_DIALECT_NAME}
|
||||||
|
-dialect-extension=${ARG_EXTENSION_NAME})
|
||||||
|
add_public_tablegen_target(${tblgen_target})
|
||||||
|
if(ARG_DEPENDS)
|
||||||
|
add_dependencies(${tblgen_target} ${ARG_DEPENDS})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
declare_mlir_python_sources("${_extension_target}.ops_gen"
|
||||||
|
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
ADD_TO_PARENT "${_extension_target}"
|
||||||
|
SOURCES "${output_filename}"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
# Function: mlir_python_setup_extension_rpath
|
# Function: mlir_python_setup_extension_rpath
|
||||||
# Sets RPATH properties on a target, assuming that it is being output to
|
# Sets RPATH properties on a target, assuming that it is being output to
|
||||||
# an _mlir_libs directory with all other libraries. For static linkage,
|
# an _mlir_libs directory with all other libraries. For static linkage,
|
||||||
|
|
|
@ -116,6 +116,25 @@ declare_mlir_dialect_python_bindings(
|
||||||
DIALECT_NAME linalg
|
DIALECT_NAME linalg
|
||||||
DEPENDS LinalgOdsGen)
|
DEPENDS LinalgOdsGen)
|
||||||
|
|
||||||
|
declare_mlir_dialect_python_bindings(
|
||||||
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
|
TD_FILE dialects/TransformOps.td
|
||||||
|
SOURCES
|
||||||
|
dialects/_transform_ops_ext.py
|
||||||
|
dialects/transform/__init__.py
|
||||||
|
DIALECT_NAME transform)
|
||||||
|
|
||||||
|
declare_mlir_dialect_extension_python_bindings(
|
||||||
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
|
TD_FILE dialects/LinalgStructuredTransformOps.td
|
||||||
|
SOURCES
|
||||||
|
dialects/_structured_transform_ops_ext.py
|
||||||
|
dialects/transform/structured.py
|
||||||
|
DIALECT_NAME transform
|
||||||
|
EXTENSION_NAME structured_transform)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Entry point of the Python bindings generator for the structured transform ops
|
||||||
|
// provided by Linalg (and other dialects).
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
||||||
|
#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
||||||
|
|
||||||
|
include "mlir/Bindings/Python/Attributes.td"
|
||||||
|
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"
|
||||||
|
|
||||||
|
#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
|
@ -0,0 +1,15 @@
|
||||||
|
//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
|
||||||
|
#define PYTHON_BINDINGS_TRANSFORM_OPS
|
||||||
|
|
||||||
|
include "mlir/Bindings/Python/Attributes.td"
|
||||||
|
include "mlir/Dialect/Transform/IR/TransformOps.td"
|
||||||
|
|
||||||
|
#endif // PYTHON_BINDINGS_TRANSFORM_OPS
|
|
@ -0,0 +1,178 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..ir import *
|
||||||
|
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||||
|
from ..dialects import pdl
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError("Error loading imports from extension module") from e
|
||||||
|
|
||||||
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||||
|
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_array_attr(
|
||||||
|
values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
|
||||||
|
"""Creates an array attribute from its operand."""
|
||||||
|
if values is None:
|
||||||
|
return ArrayAttr.get([])
|
||||||
|
if isinstance(values, ArrayAttr):
|
||||||
|
return values
|
||||||
|
|
||||||
|
return ArrayAttr.get(values)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_int_array_attr(
|
||||||
|
values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
|
||||||
|
) -> ArrayAttr:
|
||||||
|
"""Creates an integer array attribute from its operand.
|
||||||
|
|
||||||
|
If the operand is already an array attribute, forwards it. Otherwise treats
|
||||||
|
the operand as a list of attributes or integers, possibly intersperced, to
|
||||||
|
create a new array attribute containing integer attributes. Expects the
|
||||||
|
thread-local MLIR context to have been set by the context manager.
|
||||||
|
"""
|
||||||
|
if values is None:
|
||||||
|
return ArrayAttr.get([])
|
||||||
|
if isinstance(values, ArrayAttr):
|
||||||
|
return values
|
||||||
|
|
||||||
|
attributes = []
|
||||||
|
for value in values:
|
||||||
|
if isinstance(value, IntegerAttr):
|
||||||
|
attributes.append(value)
|
||||||
|
else:
|
||||||
|
attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
|
||||||
|
return ArrayAttr.get(attributes)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_int_int_array_attr(
|
||||||
|
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
|
||||||
|
IntOrAttrList]]]]
|
||||||
|
) -> ArrayAttr:
|
||||||
|
"""Creates an array attribute containing array attributes of integers.
|
||||||
|
|
||||||
|
If the operand is already an array attribute, forwards it. Otherwise treats
|
||||||
|
the operand as a list of attributes or integers, potentially interpserced, to
|
||||||
|
create a new array-of-array attribute. Expects the thread-local MLIR context
|
||||||
|
to have been set by the context manager.
|
||||||
|
"""
|
||||||
|
if values is None:
|
||||||
|
return ArrayAttr.get([])
|
||||||
|
if isinstance(values, ArrayAttr):
|
||||||
|
return values
|
||||||
|
|
||||||
|
return ArrayAttr.get([_get_int_array_attr(value) for value in values])
|
||||||
|
|
||||||
|
|
||||||
|
class InterchangeOp:
|
||||||
|
"""Specialization for InterchangeOp class."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Union[Operation, Value],
|
||||||
|
*,
|
||||||
|
iterator_interchange: OptionalIntList = None,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
pdl_operation_type = pdl.OperationType.get()
|
||||||
|
interchange_attr = _get_int_array_attr(iterator_interchange)
|
||||||
|
super().__init__(
|
||||||
|
pdl_operation_type,
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
iterator_interchange=interchange_attr,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
class PadOp:
|
||||||
|
"""Specialization for PadOp class."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Union[Operation, Value],
|
||||||
|
*,
|
||||||
|
padding_values: Optional[Union[ArrayAttr,
|
||||||
|
Sequence[Attribute]]] = None,
|
||||||
|
padding_dimensions: OptionalIntList = None,
|
||||||
|
pack_paddings: OptionalIntList = None,
|
||||||
|
hoist_paddings: OptionalIntList = None,
|
||||||
|
transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
|
||||||
|
ArrayAttr, IntOrAttrList]]]] = None,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
pdl_operation_type = pdl.OperationType.get()
|
||||||
|
padding_values_attr = _get_array_attr(padding_values)
|
||||||
|
padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
|
||||||
|
pack_paddings_attr = _get_int_array_attr(pack_paddings)
|
||||||
|
hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
|
||||||
|
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
|
||||||
|
super().__init__(
|
||||||
|
pdl_operation_type,
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
padding_values=padding_values_attr,
|
||||||
|
padding_dimensions=padding_dimensions_attr,
|
||||||
|
pack_paddings=pack_paddings_attr,
|
||||||
|
hoist_paddings=hoist_paddings_attr,
|
||||||
|
transpose_paddings=transpose_paddings_attr,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
class ScalarizeOp:
|
||||||
|
"""Specialization for ScalarizeOp class."""
|
||||||
|
|
||||||
|
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||||
|
pdl_operation_type = pdl.OperationType.get()
|
||||||
|
super().__init__(
|
||||||
|
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
class TileOp:
|
||||||
|
"""Specialization for TileOp class."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Union[Operation, Value],
|
||||||
|
*,
|
||||||
|
sizes: OptionalIntList = None,
|
||||||
|
interchange: OptionalIntList = None,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
pdl_operation_type = pdl.OperationType.get()
|
||||||
|
sizes_attr = _get_int_array_attr(sizes)
|
||||||
|
num_loops = sum(
|
||||||
|
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
|
||||||
|
super().__init__(
|
||||||
|
pdl_operation_type, [pdl_operation_type] * num_loops,
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
sizes=sizes_attr,
|
||||||
|
interchange=_get_int_array_attr(interchange) if interchange else None,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
|
||||||
|
if not attr:
|
||||||
|
return []
|
||||||
|
return [IntegerAttr(element).value for element in attr]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorizeOp:
|
||||||
|
"""Specialization for VectorizeOp class."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Union[Operation, Value],
|
||||||
|
*,
|
||||||
|
vectorize_padding: Union[bool, BoolAttr] = False,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
pdl_operation_type = pdl.OperationType.get()
|
||||||
|
if isinstance(vectorize_padding, bool):
|
||||||
|
vectorize_padding = BoolAttr.get(vectorize_padding)
|
||||||
|
super().__init__(
|
||||||
|
pdl_operation_type,
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
vectorize_padding=vectorize_padding,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..ir import *
|
||||||
|
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||||
|
from ..dialects import pdl
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError("Error loading imports from extension module") from e
|
||||||
|
|
||||||
|
from typing import Optional, overload, Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
def _get_symbol_ref_attr(value: Union[Attribute, str]):
|
||||||
|
if isinstance(value, Attribute):
|
||||||
|
return value
|
||||||
|
return FlatSymbolRefAttr.get(value)
|
||||||
|
|
||||||
|
|
||||||
|
class GetClosestIsolatedParentOp:
|
||||||
|
|
||||||
|
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||||
|
super().__init__(
|
||||||
|
pdl.OperationType.get(),
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
class PDLMatchOp:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Union[Operation, Value],
|
||||||
|
pattern_name: Union[Attribute, str],
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
super().__init__(
|
||||||
|
pdl.OperationType.get(),
|
||||||
|
_get_op_result_or_value(target),
|
||||||
|
_get_symbol_ref_attr(pattern_name),
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceOp:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(self, resultsOrRoot: Sequence[Type],
|
||||||
|
optionalRoot: Optional[Union[Operation, Value]]):
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]],
|
||||||
|
optionalRoot: NoneType):
|
||||||
|
...
|
||||||
|
|
||||||
|
def __init__(self, resultsOrRoot=None, optionalRoot=None):
|
||||||
|
results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else []
|
||||||
|
root = (
|
||||||
|
resultsOrRoot
|
||||||
|
if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
|
||||||
|
root = _get_op_result_or_value(root) if root else None
|
||||||
|
super().__init__(results_=results, root=root)
|
||||||
|
self.regions[0].blocks.append(pdl.OperationType.get())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self) -> Block:
|
||||||
|
return self.regions[0].blocks[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bodyTarget(self) -> Value:
|
||||||
|
return self.body.arguments[0]
|
||||||
|
|
||||||
|
|
||||||
|
class WithPDLPatternsOp:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
target: Optional[Union[Operation, Value]] = None,
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
super().__init__(
|
||||||
|
root=_get_op_result_or_value(target) if target else None,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
self.regions[0].blocks.append(pdl.OperationType.get())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self) -> Block:
|
||||||
|
return self.regions[0].blocks[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bodyTarget(self) -> Value:
|
||||||
|
return self.body.arguments[0]
|
||||||
|
|
||||||
|
|
||||||
|
class YieldOp:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
operands: Union[Operation, Sequence[Value]] = [],
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from .._transform_ops_gen import *
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from .._structured_transform_ops_gen import *
|
|
@ -0,0 +1,84 @@
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
from mlir.ir import *
|
||||||
|
from mlir.dialects import transform
|
||||||
|
from mlir.dialects import pdl
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
print("\nTEST:", f.__name__)
|
||||||
|
f()
|
||||||
|
print(module)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testSequenceOp():
|
||||||
|
sequence = transform.SequenceOp([pdl.OperationType.get()])
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
transform.YieldOp([sequence.bodyTarget])
|
||||||
|
# CHECK-LABEL: TEST: testSequenceOp
|
||||||
|
# CHECK: = transform.sequence {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||||
|
# CHECK: yield %[[ARG0]] : !pdl.operation
|
||||||
|
# CHECK: } : !pdl.operation
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testNestedSequenceOp():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
nested = transform.SequenceOp(sequence.bodyTarget)
|
||||||
|
with InsertionPoint(nested.body):
|
||||||
|
doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
|
||||||
|
nested.bodyTarget)
|
||||||
|
with InsertionPoint(doubly_nested.body):
|
||||||
|
transform.YieldOp([doubly_nested.bodyTarget])
|
||||||
|
transform.YieldOp()
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testNestedSequenceOp
|
||||||
|
# CHECK: transform.sequence {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||||
|
# CHECK: sequence %[[ARG0]] {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||||
|
# CHECK: = sequence %[[ARG1]] {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
|
||||||
|
# CHECK: yield %[[ARG2]] : !pdl.operation
|
||||||
|
# CHECK: } : !pdl.operation
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testTransformPDLOps():
|
||||||
|
withPdl = transform.WithPDLPatternsOp()
|
||||||
|
with InsertionPoint(withPdl.body):
|
||||||
|
sequence = transform.SequenceOp([pdl.OperationType.get()],
|
||||||
|
withPdl.bodyTarget)
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
|
||||||
|
transform.YieldOp(match)
|
||||||
|
# CHECK-LABEL: TEST: testTransformPDLOps
|
||||||
|
# CHECK: transform.with_pdl_patterns {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||||
|
# CHECK: = sequence %[[ARG0]] {
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||||
|
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
|
||||||
|
# CHECK: yield %[[RES]] : !pdl.operation
|
||||||
|
# CHECK: } : !pdl.operation
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testGetClosestIsolatedParentOp():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||||
|
# CHECK: = get_closest_isolated_parent %[[ARG1]]
|
|
@ -0,0 +1,118 @@
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
from mlir.ir import *
|
||||||
|
from mlir.dialects import transform
|
||||||
|
from mlir.dialects import pdl
|
||||||
|
from mlir.dialects.transform import structured
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
print("\nTEST:", f.__name__)
|
||||||
|
f()
|
||||||
|
print(module)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testInterchange():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.InterchangeOp(
|
||||||
|
sequence.bodyTarget,
|
||||||
|
iterator_interchange=[
|
||||||
|
IntegerAttr.get(IntegerType.get_signless(64), 1), 0
|
||||||
|
])
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testInterchange
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: transform.structured.interchange
|
||||||
|
# CHECK: iterator_interchange = [1, 0]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testPad():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.PadOp(
|
||||||
|
sequence.bodyTarget,
|
||||||
|
padding_values=[FloatAttr.get_f32(42.0)],
|
||||||
|
padding_dimensions=[1],
|
||||||
|
transpose_paddings=[[1, 0]])
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testPad
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: transform.structured.pad
|
||||||
|
# CHECK-DAG: padding_values = [4.200000e+01 : f32]
|
||||||
|
# CHECK-DAG: padding_dimensions = [1]
|
||||||
|
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
|
||||||
|
# CHECK-DAG: hoist_paddings = []
|
||||||
|
# CHECK-DAG: pack_paddings = []
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testScalarize():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.ScalarizeOp(sequence.bodyTarget)
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testScalarize
|
||||||
|
# CHECK: transform.structured.scalarize
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testTileCompact():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testTileCompact
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
|
||||||
|
# CHECK-DAG: interchange = [0, 1]
|
||||||
|
# CHECK-DAG: sizes = [4, 8]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testTileAttributes():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
attr = ArrayAttr.get(
|
||||||
|
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
|
||||||
|
ichange = ArrayAttr.get(
|
||||||
|
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testTileAttributes
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: structured.tile
|
||||||
|
# CHECK-DAG: interchange = [0, 1]
|
||||||
|
# CHECK-DAG: sizes = [4, 8]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testTileZero():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.TileOp(
|
||||||
|
sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testTileZero
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
|
||||||
|
# CHECK-DAG: interchange = [0, 1, 2, 3]
|
||||||
|
# CHECK-DAG: sizes = [4, 0, 2, 0]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def testVectorize():
|
||||||
|
sequence = transform.SequenceOp()
|
||||||
|
with InsertionPoint(sequence.body):
|
||||||
|
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
|
||||||
|
transform.YieldOp()
|
||||||
|
# CHECK-LABEL: TEST: testVectorize
|
||||||
|
# CHECK: transform.sequence
|
||||||
|
# CHECK: = transform.structured.vectorize
|
||||||
|
# CHECK: vectorize_padding = true
|
|
@ -50,6 +50,10 @@ class _Dialect(_ods_ir.Dialect):
|
||||||
|
|
||||||
)Py";
|
)Py";
|
||||||
|
|
||||||
|
constexpr const char *dialectExtensionTemplate = R"Py(
|
||||||
|
from ._{0}_ops_gen import _Dialect
|
||||||
|
)Py";
|
||||||
|
|
||||||
/// Template for operation class:
|
/// Template for operation class:
|
||||||
/// {0} is the Python class name;
|
/// {0} is the Python class name;
|
||||||
/// {1} is the operation name.
|
/// {1} is the operation name.
|
||||||
|
@ -270,6 +274,10 @@ static llvm::cl::opt<std::string>
|
||||||
llvm::cl::desc("The dialect to run the generator for"),
|
llvm::cl::desc("The dialect to run the generator for"),
|
||||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||||
|
|
||||||
|
static llvm::cl::opt<std::string> clDialectExtensionName(
|
||||||
|
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
|
||||||
|
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||||
|
|
||||||
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
||||||
|
|
||||||
/// Checks whether `str` is a Python keyword.
|
/// Checks whether `str` is a Python keyword.
|
||||||
|
@ -1014,8 +1022,14 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||||
AttributeClasses attributeClasses;
|
AttributeClasses attributeClasses;
|
||||||
constructAttributeMapping(records, attributeClasses);
|
constructAttributeMapping(records, attributeClasses);
|
||||||
|
|
||||||
os << llvm::formatv(fileHeader, clDialectName.getValue());
|
bool isExtension = !clDialectExtensionName.empty();
|
||||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
os << llvm::formatv(fileHeader, isExtension
|
||||||
|
? clDialectExtensionName.getValue()
|
||||||
|
: clDialectName.getValue());
|
||||||
|
if (isExtension)
|
||||||
|
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||||
|
else
|
||||||
|
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||||
|
|
||||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||||
Operator op(rec);
|
Operator op(rec);
|
||||||
|
|
|
@ -825,6 +825,74 @@ filegroup(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
##---------------------------------------------------------------------------##
|
||||||
|
# Transform dialect and extensions.
|
||||||
|
##---------------------------------------------------------------------------##
|
||||||
|
|
||||||
|
td_library(
|
||||||
|
name = "TransformOpsPyTdFiles",
|
||||||
|
srcs = [
|
||||||
|
"//mlir:include/mlir/Bindings/Python/Attributes.td",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mlir:OpBaseTdFiles",
|
||||||
|
"//mlir:TransformDialectTdFiles",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
gentbl_filegroup(
|
||||||
|
name = "TransformOpsPyGen",
|
||||||
|
tbl_outs = [
|
||||||
|
(
|
||||||
|
[
|
||||||
|
"-gen-python-op-bindings",
|
||||||
|
"-bind-dialect=transform",
|
||||||
|
],
|
||||||
|
"mlir/dialects/_transform_ops_gen.py",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tblgen = "//mlir:mlir-tblgen",
|
||||||
|
td_file = "mlir/dialects/TransformOps.td",
|
||||||
|
deps = [
|
||||||
|
":TransformOpsPyTdFiles",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
gentbl_filegroup(
|
||||||
|
name = "StructuredTransformOpsPyGen",
|
||||||
|
tbl_outs = [
|
||||||
|
(
|
||||||
|
[
|
||||||
|
"-gen-python-op-bindings",
|
||||||
|
"-bind-dialect=transform",
|
||||||
|
"-dialect-extension=structured_transform",
|
||||||
|
],
|
||||||
|
"mlir/dialects/_structured_transform_ops_gen.py",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tblgen = "//mlir:mlir-tblgen",
|
||||||
|
td_file = "mlir/dialects/LinalgStructuredTransformOps.td",
|
||||||
|
deps = [
|
||||||
|
":TransformOpsPyTdFiles",
|
||||||
|
"//mlir:LinalgTransformOpsTdFiles",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "TransformOpsPyFiles",
|
||||||
|
srcs = [
|
||||||
|
"mlir/dialects/_structured_transform_ops_ext.py",
|
||||||
|
"mlir/dialects/_transform_ops_ext.py",
|
||||||
|
":StructuredTransformOpsPyGen",
|
||||||
|
":TransformOpsPyGen",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "TransformOpsPackagePyFiles",
|
||||||
|
srcs = glob(["mlir/dialects/transform/*.py"]),
|
||||||
|
)
|
||||||
|
|
||||||
##---------------------------------------------------------------------------##
|
##---------------------------------------------------------------------------##
|
||||||
# Vector dialect.
|
# Vector dialect.
|
||||||
##---------------------------------------------------------------------------##
|
##---------------------------------------------------------------------------##
|
||||||
|
|
Loading…
Reference in New Issue