[mlir] provide Python bindings for the Transform dialect

Python bindings for extensions of the Transform dialect are defined in separate
Python source files that can be imported on-demand, i.e., that are not imported
with the "main" transform dialect. This requires a minor addition to the
ODS-based bindings generator. This approach is consistent with the current
model for downstream projects that are expected to bundle MLIR Python bindings:
such projects can include their custom extensions into the bundle similarly to
how they include their dialects.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D126208
This commit is contained in:
Alex Zinenko 2022-05-30 15:14:02 +02:00
parent cc6c159203
commit 3f71765a71
12 changed files with 690 additions and 2 deletions

View File

@ -355,6 +355,61 @@ function(declare_mlir_dialect_python_bindings)
endif()
endfunction()
# Function: declare_mlir_dialect_extension_python_bindings
# Helper to generate source groups for dialect extensions, including both
# static source files and a TD_FILE to generate wrappers.
#
# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
#
# Arguments:
# ROOT_DIR: Same as for declare_mlir_python_sources().
# ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
# for the subordinate source groups are derived from this.
# TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
# DIALECT_NAME: Python name of the dialect.
# EXTENSION_NAME: Python name of the dialect extension.
# SOURCES: Same as declare_mlir_python_sources().
# SOURCES_GLOB: Same as declare_mlir_python_sources().
# DEPENDS: Additional dependency targets.
function(declare_mlir_dialect_extension_python_bindings)
cmake_parse_arguments(ARG
""
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
"SOURCES;SOURCES_GLOB;DEPENDS"
${ARGN})
# Source files.
set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
declare_mlir_python_sources(${_extension_target}
ROOT_DIR "${ARG_ROOT_DIR}"
ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
SOURCES "${ARG_SOURCES}"
SOURCES_GLOB "${ARG_SOURCES_GLOB}"
)
# Tablegen
if(ARG_TD_FILE)
set(tblgen_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}.tablegen")
set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
set(LLVM_TARGET_DEFINITIONS ${td_file})
mlir_tablegen("${output_filename}" -gen-python-op-bindings
-bind-dialect=${ARG_DIALECT_NAME}
-dialect-extension=${ARG_EXTENSION_NAME})
add_public_tablegen_target(${tblgen_target})
if(ARG_DEPENDS)
add_dependencies(${tblgen_target} ${ARG_DEPENDS})
endif()
declare_mlir_python_sources("${_extension_target}.ops_gen"
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT "${_extension_target}"
SOURCES "${output_filename}"
)
endif()
endfunction()
# Function: mlir_python_setup_extension_rpath
# Sets RPATH properties on a target, assuming that it is being output to
# an _mlir_libs directory with all other libraries. For static linkage,

View File

@ -116,6 +116,25 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME linalg
DEPENDS LinalgOdsGen)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
DIALECT_NAME transform)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

View File

@ -0,0 +1,21 @@
//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the Python bindings generator for the structured transform ops
// provided by Linalg (and other dialects).
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"
#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS

View File

@ -0,0 +1,15 @@
//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
#define PYTHON_BINDINGS_TRANSFORM_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Transform/IR/TransformOps.td"
#endif // PYTHON_BINDINGS_TRANSFORM_OPS

View File

@ -0,0 +1,178 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Union
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
def _get_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
"""Creates an array attribute from its operand."""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
return ArrayAttr.get(values)
def _get_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
) -> ArrayAttr:
"""Creates an integer array attribute from its operand.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, possibly intersperced, to
create a new array attribute containing integer attributes. Expects the
thread-local MLIR context to have been set by the context manager.
"""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
attributes = []
for value in values:
if isinstance(value, IntegerAttr):
attributes.append(value)
else:
attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
return ArrayAttr.get(attributes)
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an array attribute containing array attributes of integers.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, potentially interpserced, to
create a new array-of-array attribute. Expects the thread-local MLIR context
to have been set by the context manager.
"""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
return ArrayAttr.get([_get_int_array_attr(value) for value in values])
class InterchangeOp:
"""Specialization for InterchangeOp class."""
def __init__(self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
interchange_attr = _get_int_array_attr(iterator_interchange)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
iterator_interchange=interchange_attr,
loc=loc,
ip=ip)
class PadOp:
"""Specialization for PadOp class."""
def __init__(self,
target: Union[Operation, Value],
*,
padding_values: Optional[Union[ArrayAttr,
Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
hoist_paddings: OptionalIntList = None,
transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
ArrayAttr, IntOrAttrList]]]] = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
padding_values_attr = _get_array_attr(padding_values)
padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
pack_paddings_attr = _get_int_array_attr(pack_paddings)
hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
padding_values=padding_values_attr,
padding_dimensions=padding_dimensions_attr,
pack_paddings=pack_paddings_attr,
hoist_paddings=hoist_paddings_attr,
transpose_paddings=transpose_paddings_attr,
loc=loc,
ip=ip)
class ScalarizeOp:
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class TileOp:
"""Specialization for TileOp class."""
def __init__(self,
target: Union[Operation, Value],
*,
sizes: OptionalIntList = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
sizes_attr = _get_int_array_attr(sizes)
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
super().__init__(
pdl_operation_type, [pdl_operation_type] * num_loops,
_get_op_result_or_value(target),
sizes=sizes_attr,
interchange=_get_int_array_attr(interchange) if interchange else None,
loc=loc,
ip=ip)
def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
if not attr:
return []
return [IntegerAttr(element).value for element in attr]
class VectorizeOp:
"""Specialization for VectorizeOp class."""
def __init__(self,
target: Union[Operation, Value],
*,
vectorize_padding: Union[bool, BoolAttr] = False,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = BoolAttr.get(vectorize_padding)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip)

View File

@ -0,0 +1,106 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Sequence, Union
def _get_symbol_ref_attr(value: Union[Attribute, str]):
if isinstance(value, Attribute):
return value
return FlatSymbolRefAttr.get(value)
class GetClosestIsolatedParentOp:
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(),
_get_op_result_or_value(target),
loc=loc,
ip=ip)
class PDLMatchOp:
def __init__(self,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None):
super().__init__(
pdl.OperationType.get(),
_get_op_result_or_value(target),
_get_symbol_ref_attr(pattern_name),
loc=loc,
ip=ip)
class SequenceOp:
@overload
def __init__(self, resultsOrRoot: Sequence[Type],
optionalRoot: Optional[Union[Operation, Value]]):
...
@overload
def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]],
optionalRoot: NoneType):
...
def __init__(self, resultsOrRoot=None, optionalRoot=None):
results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else []
root = (
resultsOrRoot
if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
root = _get_op_result_or_value(root) if root else None
super().__init__(results_=results, root=root)
self.regions[0].blocks.append(pdl.OperationType.get())
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
class WithPDLPatternsOp:
def __init__(self,
target: Optional[Union[Operation, Value]] = None,
*,
loc=None,
ip=None):
super().__init__(
root=_get_op_result_or_value(target) if target else None,
loc=loc,
ip=ip)
self.regions[0].blocks.append(pdl.OperationType.get())
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
class YieldOp:
def __init__(self,
operands: Union[Operation, Sequence[Value]] = [],
*,
loc=None,
ip=None):
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._transform_ops_gen import *

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._structured_transform_ops_gen import *

View File

@ -0,0 +1,84 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f()
print(module)
return f
@run
def testSequenceOp():
sequence = transform.SequenceOp([pdl.OperationType.get()])
with InsertionPoint(sequence.body):
transform.YieldOp([sequence.bodyTarget])
# CHECK-LABEL: TEST: testSequenceOp
# CHECK: = transform.sequence {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
# CHECK: yield %[[ARG0]] : !pdl.operation
# CHECK: } : !pdl.operation
@run
def testNestedSequenceOp():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
nested = transform.SequenceOp(sequence.bodyTarget)
with InsertionPoint(nested.body):
doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
nested.bodyTarget)
with InsertionPoint(doubly_nested.body):
transform.YieldOp([doubly_nested.bodyTarget])
transform.YieldOp()
transform.YieldOp()
# CHECK-LABEL: TEST: testNestedSequenceOp
# CHECK: transform.sequence {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
# CHECK: sequence %[[ARG0]] {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
# CHECK: = sequence %[[ARG1]] {
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
# CHECK: yield %[[ARG2]] : !pdl.operation
# CHECK: } : !pdl.operation
# CHECK: }
# CHECK: }
@run
def testTransformPDLOps():
withPdl = transform.WithPDLPatternsOp()
with InsertionPoint(withPdl.body):
sequence = transform.SequenceOp([pdl.OperationType.get()],
withPdl.bodyTarget)
with InsertionPoint(sequence.body):
match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
transform.YieldOp(match)
# CHECK-LABEL: TEST: testTransformPDLOps
# CHECK: transform.with_pdl_patterns {
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
# CHECK: = sequence %[[ARG0]] {
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
# CHECK: yield %[[RES]] : !pdl.operation
# CHECK: } : !pdl.operation
# CHECK: }
@run
def testGetClosestIsolatedParentOp():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
# CHECK: transform.sequence
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
# CHECK: = get_closest_isolated_parent %[[ARG1]]

View File

@ -0,0 +1,118 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f()
print(module)
return f
@run
def testInterchange():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.InterchangeOp(
sequence.bodyTarget,
iterator_interchange=[
IntegerAttr.get(IntegerType.get_signless(64), 1), 0
])
transform.YieldOp()
# CHECK-LABEL: TEST: testInterchange
# CHECK: transform.sequence
# CHECK: transform.structured.interchange
# CHECK: iterator_interchange = [1, 0]
@run
def testPad():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.PadOp(
sequence.bodyTarget,
padding_values=[FloatAttr.get_f32(42.0)],
padding_dimensions=[1],
transpose_paddings=[[1, 0]])
transform.YieldOp()
# CHECK-LABEL: TEST: testPad
# CHECK: transform.sequence
# CHECK: transform.structured.pad
# CHECK-DAG: padding_values = [4.200000e+01 : f32]
# CHECK-DAG: padding_dimensions = [1]
# CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
# CHECK-DAG: hoist_paddings = []
# CHECK-DAG: pack_paddings = []
@run
def testScalarize():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.ScalarizeOp(sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testScalarize
# CHECK: transform.structured.scalarize
@run
def testTileCompact():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
# CHECK-DAG: interchange = [0, 1]
# CHECK-DAG: sizes = [4, 8]
@run
def testTileAttributes():
sequence = transform.SequenceOp()
attr = ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
ichange = ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
transform.YieldOp()
# CHECK-LABEL: TEST: testTileAttributes
# CHECK: transform.sequence
# CHECK: structured.tile
# CHECK-DAG: interchange = [0, 1]
# CHECK-DAG: sizes = [4, 8]
@run
def testTileZero():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.TileOp(
sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileZero
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
# CHECK-DAG: interchange = [0, 1, 2, 3]
# CHECK-DAG: sizes = [4, 0, 2, 0]
@run
def testVectorize():
sequence = transform.SequenceOp()
with InsertionPoint(sequence.body):
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
transform.YieldOp()
# CHECK-LABEL: TEST: testVectorize
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK: vectorize_padding = true

View File

@ -50,6 +50,10 @@ class _Dialect(_ods_ir.Dialect):
)Py";
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
/// Template for operation class:
/// {0} is the Python class name;
/// {1} is the operation name.
@ -270,6 +274,10 @@ static llvm::cl::opt<std::string>
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Checks whether `str` is a Python keyword.
@ -1014,8 +1022,14 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
AttributeClasses attributeClasses;
constructAttributeMapping(records, attributeClasses);
os << llvm::formatv(fileHeader, clDialectName.getValue());
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
bool isExtension = !clDialectExtensionName.empty();
os << llvm::formatv(fileHeader, isExtension
? clDialectExtensionName.getValue()
: clDialectName.getValue());
if (isExtension)
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);

View File

@ -825,6 +825,74 @@ filegroup(
],
)
##---------------------------------------------------------------------------##
# Transform dialect and extensions.
##---------------------------------------------------------------------------##
td_library(
name = "TransformOpsPyTdFiles",
srcs = [
"//mlir:include/mlir/Bindings/Python/Attributes.td",
],
deps = [
"//mlir:OpBaseTdFiles",
"//mlir:TransformDialectTdFiles",
],
)
gentbl_filegroup(
name = "TransformOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=transform",
],
"mlir/dialects/_transform_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/TransformOps.td",
deps = [
":TransformOpsPyTdFiles",
],
)
gentbl_filegroup(
name = "StructuredTransformOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=transform",
"-dialect-extension=structured_transform",
],
"mlir/dialects/_structured_transform_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/LinalgStructuredTransformOps.td",
deps = [
":TransformOpsPyTdFiles",
"//mlir:LinalgTransformOpsTdFiles",
],
)
filegroup(
name = "TransformOpsPyFiles",
srcs = [
"mlir/dialects/_structured_transform_ops_ext.py",
"mlir/dialects/_transform_ops_ext.py",
":StructuredTransformOpsPyGen",
":TransformOpsPyGen",
],
)
filegroup(
name = "TransformOpsPackagePyFiles",
srcs = glob(["mlir/dialects/transform/*.py"]),
)
##---------------------------------------------------------------------------##
# Vector dialect.
##---------------------------------------------------------------------------##