forked from OSchip/llvm-project
[mlir] provide Python bindings for the Transform dialect
Python bindings for extensions of the Transform dialect are defined in separate Python source files that can be imported on-demand, i.e., that are not imported with the "main" transform dialect. This requires a minor addition to the ODS-based bindings generator. This approach is consistent with the current model for downstream projects that are expected to bundle MLIR Python bindings: such projects can include their custom extensions into the bundle similarly to how they include their dialects. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D126208
This commit is contained in:
parent
cc6c159203
commit
3f71765a71
|
@ -355,6 +355,61 @@ function(declare_mlir_dialect_python_bindings)
|
|||
endif()
|
||||
endfunction()
|
||||
|
||||
# Function: declare_mlir_dialect_extension_python_bindings
|
||||
# Helper to generate source groups for dialect extensions, including both
|
||||
# static source files and a TD_FILE to generate wrappers.
|
||||
#
|
||||
# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
|
||||
#
|
||||
# Arguments:
|
||||
# ROOT_DIR: Same as for declare_mlir_python_sources().
|
||||
# ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
|
||||
# for the subordinate source groups are derived from this.
|
||||
# TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
|
||||
# DIALECT_NAME: Python name of the dialect.
|
||||
# EXTENSION_NAME: Python name of the dialect extension.
|
||||
# SOURCES: Same as declare_mlir_python_sources().
|
||||
# SOURCES_GLOB: Same as declare_mlir_python_sources().
|
||||
# DEPENDS: Additional dependency targets.
|
||||
function(declare_mlir_dialect_extension_python_bindings)
|
||||
cmake_parse_arguments(ARG
|
||||
""
|
||||
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
|
||||
"SOURCES;SOURCES_GLOB;DEPENDS"
|
||||
${ARGN})
|
||||
# Source files.
|
||||
set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
|
||||
declare_mlir_python_sources(${_extension_target}
|
||||
ROOT_DIR "${ARG_ROOT_DIR}"
|
||||
ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
|
||||
SOURCES "${ARG_SOURCES}"
|
||||
SOURCES_GLOB "${ARG_SOURCES_GLOB}"
|
||||
)
|
||||
|
||||
# Tablegen
|
||||
if(ARG_TD_FILE)
|
||||
set(tblgen_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}.tablegen")
|
||||
set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
|
||||
get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
|
||||
set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
mlir_tablegen("${output_filename}" -gen-python-op-bindings
|
||||
-bind-dialect=${ARG_DIALECT_NAME}
|
||||
-dialect-extension=${ARG_EXTENSION_NAME})
|
||||
add_public_tablegen_target(${tblgen_target})
|
||||
if(ARG_DEPENDS)
|
||||
add_dependencies(${tblgen_target} ${ARG_DEPENDS})
|
||||
endif()
|
||||
|
||||
declare_mlir_python_sources("${_extension_target}.ops_gen"
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT "${_extension_target}"
|
||||
SOURCES "${output_filename}"
|
||||
)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Function: mlir_python_setup_extension_rpath
|
||||
# Sets RPATH properties on a target, assuming that it is being output to
|
||||
# an _mlir_libs directory with all other libraries. For static linkage,
|
||||
|
|
|
@ -116,6 +116,25 @@ declare_mlir_dialect_python_bindings(
|
|||
DIALECT_NAME linalg
|
||||
DEPENDS LinalgOdsGen)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformOps.td
|
||||
SOURCES
|
||||
dialects/_transform_ops_ext.py
|
||||
dialects/transform/__init__.py
|
||||
DIALECT_NAME transform)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/LinalgStructuredTransformOps.td
|
||||
SOURCES
|
||||
dialects/_structured_transform_ops_ext.py
|
||||
dialects/transform/structured.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME structured_transform)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Entry point of the Python bindings generator for the structured transform ops
|
||||
// provided by Linalg (and other dialects).
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
||||
#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
|
|
@ -0,0 +1,15 @@
|
|||
//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
|
||||
#define PYTHON_BINDINGS_TRANSFORM_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_TRANSFORM_OPS
|
|
@ -0,0 +1,178 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
|
||||
def _get_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
|
||||
"""Creates an array attribute from its operand."""
|
||||
if values is None:
|
||||
return ArrayAttr.get([])
|
||||
if isinstance(values, ArrayAttr):
|
||||
return values
|
||||
|
||||
return ArrayAttr.get(values)
|
||||
|
||||
|
||||
def _get_int_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an integer array attribute from its operand.
|
||||
|
||||
If the operand is already an array attribute, forwards it. Otherwise treats
|
||||
the operand as a list of attributes or integers, possibly intersperced, to
|
||||
create a new array attribute containing integer attributes. Expects the
|
||||
thread-local MLIR context to have been set by the context manager.
|
||||
"""
|
||||
if values is None:
|
||||
return ArrayAttr.get([])
|
||||
if isinstance(values, ArrayAttr):
|
||||
return values
|
||||
|
||||
attributes = []
|
||||
for value in values:
|
||||
if isinstance(value, IntegerAttr):
|
||||
attributes.append(value)
|
||||
else:
|
||||
attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
|
||||
return ArrayAttr.get(attributes)
|
||||
|
||||
|
||||
def _get_int_int_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
|
||||
IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an array attribute containing array attributes of integers.
|
||||
|
||||
If the operand is already an array attribute, forwards it. Otherwise treats
|
||||
the operand as a list of attributes or integers, potentially interpserced, to
|
||||
create a new array-of-array attribute. Expects the thread-local MLIR context
|
||||
to have been set by the context manager.
|
||||
"""
|
||||
if values is None:
|
||||
return ArrayAttr.get([])
|
||||
if isinstance(values, ArrayAttr):
|
||||
return values
|
||||
|
||||
return ArrayAttr.get([_get_int_array_attr(value) for value in values])
|
||||
|
||||
|
||||
class InterchangeOp:
|
||||
"""Specialization for InterchangeOp class."""
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iterator_interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
interchange_attr = _get_int_array_attr(iterator_interchange)
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
iterator_interchange=interchange_attr,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class PadOp:
|
||||
"""Specialization for PadOp class."""
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
padding_values: Optional[Union[ArrayAttr,
|
||||
Sequence[Attribute]]] = None,
|
||||
padding_dimensions: OptionalIntList = None,
|
||||
pack_paddings: OptionalIntList = None,
|
||||
hoist_paddings: OptionalIntList = None,
|
||||
transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
|
||||
ArrayAttr, IntOrAttrList]]]] = None,
|
||||
loc=None,
|
||||
ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
padding_values_attr = _get_array_attr(padding_values)
|
||||
padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
|
||||
pack_paddings_attr = _get_int_array_attr(pack_paddings)
|
||||
hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
|
||||
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
padding_values=padding_values_attr,
|
||||
padding_dimensions=padding_dimensions_attr,
|
||||
pack_paddings=pack_paddings_attr,
|
||||
hoist_paddings=hoist_paddings_attr,
|
||||
transpose_paddings=transpose_paddings_attr,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class ScalarizeOp:
|
||||
"""Specialization for ScalarizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||
|
||||
|
||||
class TileOp:
|
||||
"""Specialization for TileOp class."""
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
sizes: OptionalIntList = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
sizes_attr = _get_int_array_attr(sizes)
|
||||
num_loops = sum(
|
||||
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
|
||||
super().__init__(
|
||||
pdl_operation_type, [pdl_operation_type] * num_loops,
|
||||
_get_op_result_or_value(target),
|
||||
sizes=sizes_attr,
|
||||
interchange=_get_int_array_attr(interchange) if interchange else None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
|
||||
if not attr:
|
||||
return []
|
||||
return [IntegerAttr(element).value for element in attr]
|
||||
|
||||
|
||||
class VectorizeOp:
|
||||
"""Specialization for VectorizeOp class."""
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
vectorize_padding: Union[bool, BoolAttr] = False,
|
||||
loc=None,
|
||||
ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
if isinstance(vectorize_padding, bool):
|
||||
vectorize_padding = BoolAttr.get(vectorize_padding)
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
vectorize_padding=vectorize_padding,
|
||||
loc=loc,
|
||||
ip=ip)
|
|
@ -0,0 +1,106 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Sequence, Union
|
||||
|
||||
|
||||
def _get_symbol_ref_attr(value: Union[Attribute, str]):
|
||||
if isinstance(value, Attribute):
|
||||
return value
|
||||
return FlatSymbolRefAttr.get(value)
|
||||
|
||||
|
||||
class GetClosestIsolatedParentOp:
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
_get_op_result_or_value(target),
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class PDLMatchOp:
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
pattern_name: Union[Attribute, str],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
_get_op_result_or_value(target),
|
||||
_get_symbol_ref_attr(pattern_name),
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class SequenceOp:
|
||||
|
||||
@overload
|
||||
def __init__(self, resultsOrRoot: Sequence[Type],
|
||||
optionalRoot: Optional[Union[Operation, Value]]):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]],
|
||||
optionalRoot: NoneType):
|
||||
...
|
||||
|
||||
def __init__(self, resultsOrRoot=None, optionalRoot=None):
|
||||
results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else []
|
||||
root = (
|
||||
resultsOrRoot
|
||||
if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
|
||||
root = _get_op_result_or_value(root) if root else None
|
||||
super().__init__(results_=results, root=root)
|
||||
self.regions[0].blocks.append(pdl.OperationType.get())
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
|
||||
class WithPDLPatternsOp:
|
||||
|
||||
def __init__(self,
|
||||
target: Optional[Union[Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
super().__init__(
|
||||
root=_get_op_result_or_value(target) if target else None,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
self.regions[0].blocks.append(pdl.OperationType.get())
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
|
||||
class YieldOp:
|
||||
|
||||
def __init__(self,
|
||||
operands: Union[Operation, Sequence[Value]] = [],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
|
@ -0,0 +1,5 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._transform_ops_gen import *
|
|
@ -0,0 +1,5 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._structured_transform_ops_gen import *
|
|
@ -0,0 +1,84 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects import pdl
|
||||
|
||||
|
||||
def run(f):
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
@run
|
||||
def testSequenceOp():
|
||||
sequence = transform.SequenceOp([pdl.OperationType.get()])
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.YieldOp([sequence.bodyTarget])
|
||||
# CHECK-LABEL: TEST: testSequenceOp
|
||||
# CHECK: = transform.sequence {
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||
# CHECK: yield %[[ARG0]] : !pdl.operation
|
||||
# CHECK: } : !pdl.operation
|
||||
|
||||
|
||||
@run
|
||||
def testNestedSequenceOp():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
nested = transform.SequenceOp(sequence.bodyTarget)
|
||||
with InsertionPoint(nested.body):
|
||||
doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
|
||||
nested.bodyTarget)
|
||||
with InsertionPoint(doubly_nested.body):
|
||||
transform.YieldOp([doubly_nested.bodyTarget])
|
||||
transform.YieldOp()
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testNestedSequenceOp
|
||||
# CHECK: transform.sequence {
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||
# CHECK: sequence %[[ARG0]] {
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: = sequence %[[ARG1]] {
|
||||
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
|
||||
# CHECK: yield %[[ARG2]] : !pdl.operation
|
||||
# CHECK: } : !pdl.operation
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
|
||||
|
||||
@run
|
||||
def testTransformPDLOps():
|
||||
withPdl = transform.WithPDLPatternsOp()
|
||||
with InsertionPoint(withPdl.body):
|
||||
sequence = transform.SequenceOp([pdl.OperationType.get()],
|
||||
withPdl.bodyTarget)
|
||||
with InsertionPoint(sequence.body):
|
||||
match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
|
||||
transform.YieldOp(match)
|
||||
# CHECK-LABEL: TEST: testTransformPDLOps
|
||||
# CHECK: transform.with_pdl_patterns {
|
||||
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
|
||||
# CHECK: = sequence %[[ARG0]] {
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
|
||||
# CHECK: yield %[[RES]] : !pdl.operation
|
||||
# CHECK: } : !pdl.operation
|
||||
# CHECK: }
|
||||
|
||||
|
||||
@run
|
||||
def testGetClosestIsolatedParentOp():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: = get_closest_isolated_parent %[[ARG1]]
|
|
@ -0,0 +1,118 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects import pdl
|
||||
from mlir.dialects.transform import structured
|
||||
|
||||
|
||||
def run(f):
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
@run
|
||||
def testInterchange():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.InterchangeOp(
|
||||
sequence.bodyTarget,
|
||||
iterator_interchange=[
|
||||
IntegerAttr.get(IntegerType.get_signless(64), 1), 0
|
||||
])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testInterchange
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: transform.structured.interchange
|
||||
# CHECK: iterator_interchange = [1, 0]
|
||||
|
||||
|
||||
@run
|
||||
def testPad():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.PadOp(
|
||||
sequence.bodyTarget,
|
||||
padding_values=[FloatAttr.get_f32(42.0)],
|
||||
padding_dimensions=[1],
|
||||
transpose_paddings=[[1, 0]])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testPad
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: transform.structured.pad
|
||||
# CHECK-DAG: padding_values = [4.200000e+01 : f32]
|
||||
# CHECK-DAG: padding_dimensions = [1]
|
||||
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
|
||||
# CHECK-DAG: hoist_paddings = []
|
||||
# CHECK-DAG: pack_paddings = []
|
||||
|
||||
|
||||
@run
|
||||
def testScalarize():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.ScalarizeOp(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testScalarize
|
||||
# CHECK: transform.structured.scalarize
|
||||
|
||||
|
||||
@run
|
||||
def testTileCompact():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testTileCompact
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
|
||||
# CHECK-DAG: interchange = [0, 1]
|
||||
# CHECK-DAG: sizes = [4, 8]
|
||||
|
||||
|
||||
@run
|
||||
def testTileAttributes():
|
||||
sequence = transform.SequenceOp()
|
||||
attr = ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
|
||||
ichange = ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testTileAttributes
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: structured.tile
|
||||
# CHECK-DAG: interchange = [0, 1]
|
||||
# CHECK-DAG: sizes = [4, 8]
|
||||
|
||||
|
||||
@run
|
||||
def testTileZero():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(
|
||||
sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testTileZero
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
|
||||
# CHECK-DAG: interchange = [0, 1, 2, 3]
|
||||
# CHECK-DAG: sizes = [4, 0, 2, 0]
|
||||
|
||||
|
||||
@run
|
||||
def testVectorize():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testVectorize
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: = transform.structured.vectorize
|
||||
# CHECK: vectorize_padding = true
|
|
@ -50,6 +50,10 @@ class _Dialect(_ods_ir.Dialect):
|
|||
|
||||
)Py";
|
||||
|
||||
constexpr const char *dialectExtensionTemplate = R"Py(
|
||||
from ._{0}_ops_gen import _Dialect
|
||||
)Py";
|
||||
|
||||
/// Template for operation class:
|
||||
/// {0} is the Python class name;
|
||||
/// {1} is the operation name.
|
||||
|
@ -270,6 +274,10 @@ static llvm::cl::opt<std::string>
|
|||
llvm::cl::desc("The dialect to run the generator for"),
|
||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||
|
||||
static llvm::cl::opt<std::string> clDialectExtensionName(
|
||||
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
|
||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||
|
||||
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
||||
|
||||
/// Checks whether `str` is a Python keyword.
|
||||
|
@ -1014,8 +1022,14 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
|||
AttributeClasses attributeClasses;
|
||||
constructAttributeMapping(records, attributeClasses);
|
||||
|
||||
os << llvm::formatv(fileHeader, clDialectName.getValue());
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
bool isExtension = !clDialectExtensionName.empty();
|
||||
os << llvm::formatv(fileHeader, isExtension
|
||||
? clDialectExtensionName.getValue()
|
||||
: clDialectName.getValue());
|
||||
if (isExtension)
|
||||
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||
else
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
|
||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
|
|
|
@ -825,6 +825,74 @@ filegroup(
|
|||
],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# Transform dialect and extensions.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
||||
td_library(
|
||||
name = "TransformOpsPyTdFiles",
|
||||
srcs = [
|
||||
"//mlir:include/mlir/Bindings/Python/Attributes.td",
|
||||
],
|
||||
deps = [
|
||||
"//mlir:OpBaseTdFiles",
|
||||
"//mlir:TransformDialectTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "TransformOpsPyGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=transform",
|
||||
],
|
||||
"mlir/dialects/_transform_ops_gen.py",
|
||||
),
|
||||
],
|
||||
tblgen = "//mlir:mlir-tblgen",
|
||||
td_file = "mlir/dialects/TransformOps.td",
|
||||
deps = [
|
||||
":TransformOpsPyTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "StructuredTransformOpsPyGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=transform",
|
||||
"-dialect-extension=structured_transform",
|
||||
],
|
||||
"mlir/dialects/_structured_transform_ops_gen.py",
|
||||
),
|
||||
],
|
||||
tblgen = "//mlir:mlir-tblgen",
|
||||
td_file = "mlir/dialects/LinalgStructuredTransformOps.td",
|
||||
deps = [
|
||||
":TransformOpsPyTdFiles",
|
||||
"//mlir:LinalgTransformOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "TransformOpsPyFiles",
|
||||
srcs = [
|
||||
"mlir/dialects/_structured_transform_ops_ext.py",
|
||||
"mlir/dialects/_transform_ops_ext.py",
|
||||
":StructuredTransformOpsPyGen",
|
||||
":TransformOpsPyGen",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "TransformOpsPackagePyFiles",
|
||||
srcs = glob(["mlir/dialects/transform/*.py"]),
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# Vector dialect.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
|
Loading…
Reference in New Issue