forked from OSchip/llvm-project
[mlir] ODS-backed python binding generator for custom op classes
Introduce an ODS/Tablegen backend producing Op wrappers for Python bindings based on the ODS operation definition. Usage: mlir-tblgen -gen-python-op-bindings -Iinclude <path/to/Ops.td> \ -bind-dialect=<dialect-name> Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D90960
This commit is contained in:
parent
b65ef65b22
commit
fd407e1f1e
|
@ -101,6 +101,12 @@ include_directories( ${MLIR_INCLUDE_DIR})
|
|||
# from another directory like tools
|
||||
add_subdirectory(tools/mlir-tblgen)
|
||||
|
||||
# Create an anchor target that will depend on dialect-specific op bindings.
|
||||
if (MLIR_BINDINGS_PYTHON_ENABLED)
|
||||
add_custom_target(MLIRBindingsPythonIncGen)
|
||||
include(AddMLIRPythonExtension)
|
||||
endif()
|
||||
|
||||
add_subdirectory(include/mlir)
|
||||
add_subdirectory(lib)
|
||||
# C API needs all dialects for registration, but should be built before tests.
|
||||
|
|
|
@ -122,3 +122,25 @@ function(add_mlir_python_extension libname extname)
|
|||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
function(add_mlir_dialect_python_bindings filename dialectname)
|
||||
set(LLVM_TARGET_DEFINITIONS ${filename})
|
||||
mlir_tablegen("${dialectname}.py" -gen-python-op-bindings
|
||||
-bind-dialect=${dialectname})
|
||||
if (${ARGC} GREATER 2)
|
||||
set(suffix ${ARGV2})
|
||||
else()
|
||||
get_filename_component(suffix ${filename} NAME_WE)
|
||||
endif()
|
||||
set(tblgen_target "MLIRBindingsPython${suffix}")
|
||||
add_public_tablegen_target(${tblgen_target})
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${tblgen_target} POST_BUILD
|
||||
COMMENT "Copying generated python source \"dialects/${dialectname}.py\""
|
||||
COMMAND "${CMAKE_COMMAND}" -E copy_if_different
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/${dialectname}.py"
|
||||
"${PROJECT_BINARY_DIR}/python/mlir/dialects/${dialectname}.py")
|
||||
add_dependencies(MLIRBindingsPythonIncGen ${tblgen_target})
|
||||
endfunction()
|
||||
|
||||
|
|
|
@ -7,3 +7,7 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
|||
add_public_tablegen_target(MLIRStandardOpsIncGen)
|
||||
|
||||
add_mlir_doc(Ops -gen-op-doc StandardOps Dialects/)
|
||||
|
||||
if (MLIR_BINDINGS_PYTHON_ENABLED)
|
||||
add_mlir_dialect_python_bindings(Ops.td std StandardOps)
|
||||
endif()
|
||||
|
|
|
@ -8,7 +8,6 @@ set(PY_SRC_FILES
|
|||
mlir/__init__.py
|
||||
mlir/ir.py
|
||||
mlir/dialects/__init__.py
|
||||
mlir/dialects/std.py
|
||||
)
|
||||
|
||||
add_custom_target(MLIRBindingsPythonSources ALL
|
||||
|
@ -16,6 +15,8 @@ add_custom_target(MLIRBindingsPythonSources ALL
|
|||
)
|
||||
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources)
|
||||
|
||||
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonIncGen)
|
||||
|
||||
foreach(PY_SRC_FILE ${PY_SRC_FILES})
|
||||
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
|
||||
add_custom_command(
|
||||
|
|
|
@ -4,3 +4,40 @@
|
|||
|
||||
# Re-export the parent _cext so that every level of the API can get it locally.
|
||||
from .. import _cext
|
||||
|
||||
def _segmented_accessor(elements, raw_segments, idx):
|
||||
"""
|
||||
Returns a slice of elements corresponding to the idx-th segment.
|
||||
|
||||
elements: a sliceable container (operands or results).
|
||||
raw_segments: an mlir.ir.Attribute, of DenseIntElements subclass containing
|
||||
sizes of the segments.
|
||||
idx: index of the segment.
|
||||
"""
|
||||
segments = _cext.ir.DenseIntElementsAttr(raw_segments)
|
||||
start = sum(segments[i] for i in range(idx))
|
||||
end = start + segments[idx]
|
||||
return elements[start:end]
|
||||
|
||||
|
||||
def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
|
||||
n_preceding_variadic):
|
||||
"""
|
||||
Returns a starting position and a number of elements per variadic group
|
||||
assuming equally-sized groups and the given numbers of preceding groups.
|
||||
|
||||
elements: a sequential container.
|
||||
n_variadic: the number of variadic groups in the container.
|
||||
n_preceding_simple: the number of non-variadic groups preceding the current
|
||||
group.
|
||||
n_preceding_variadic: the number of variadic groups preceding the current
|
||||
group.
|
||||
"""
|
||||
|
||||
total_variadic_length = len(elements) - n_variadic + 1
|
||||
# This should be enforced by the C++-side trait verifier.
|
||||
assert total_variadic_length % n_variadic == 0
|
||||
|
||||
elements_per_group = total_variadic_length // n_variadic
|
||||
start = n_preceding_simple + n_preceding_variadic * elements_per_group
|
||||
return start, elements_per_group
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# TODO: This file should be auto-generated.
|
||||
|
||||
from . import _cext
|
||||
_ir = _cext.ir
|
||||
|
||||
@_cext.register_dialect
|
||||
class _Dialect(_ir.Dialect):
|
||||
# Special case: 'std' namespace aliases to the empty namespace.
|
||||
DIALECT_NAMESPACE = "std"
|
||||
pass
|
||||
|
||||
@_cext.register_operation(_Dialect)
|
||||
class AddFOp(_ir.OpView):
|
||||
OPERATION_NAME = "std.addf"
|
||||
|
||||
def __init__(self, lhs, rhs, loc=None, ip=None):
|
||||
super().__init__(_ir.Operation.create(
|
||||
"std.addf", operands=[lhs, rhs], results=[lhs.type],
|
||||
loc=loc, ip=ip))
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.operation.operands[0]
|
||||
|
||||
@property
|
||||
def rhs(self):
|
||||
return self.operation.operands[1]
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self.operation.results[0]
|
|
@ -63,7 +63,7 @@ def testUserDialectClass():
|
|||
run(testUserDialectClass)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCustomOpView
|
||||
# XHECK-LABEL: TEST: testCustomOpView
|
||||
# This test uses the standard dialect AddFOp as an example of a user op.
|
||||
# TODO: Op creation and access is still quite verbose: simplify this test as
|
||||
# additional capabilities come online.
|
||||
|
@ -88,10 +88,11 @@ def testCustomOpView():
|
|||
from mlir.dialects.std import AddFOp
|
||||
AddFOp(input1, op1.result)
|
||||
|
||||
# CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
|
||||
# CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
|
||||
# CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
|
||||
# CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
|
||||
# XHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
|
||||
# XHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
|
||||
# XHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
|
||||
# XHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
|
||||
m.operation.print()
|
||||
|
||||
run(testCustomOpView)
|
||||
# TODO: re-enable when constructs are generated again
|
||||
# run(testCustomOpView)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
import mlir.dialects.std as std
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
# CHECK-LABEL: TEST: testSubViewAccessors
|
||||
def testSubViewAccessors():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""
|
||||
func @f1(%arg0: memref<?x?xf32>) {
|
||||
%0 = constant 0 : index
|
||||
%1 = constant 1 : index
|
||||
%2 = constant 2 : index
|
||||
%3 = constant 3 : index
|
||||
%4 = constant 4 : index
|
||||
%5 = constant 5 : index
|
||||
subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
|
||||
return
|
||||
}
|
||||
""", ctx)
|
||||
func_body = module.body.operations[0].regions[0].blocks[0]
|
||||
subview = func_body.operations[6]
|
||||
|
||||
assert subview.source == subview.operands[0]
|
||||
assert len(subview.offsets) == 2
|
||||
assert len(subview.sizes) == 2
|
||||
assert len(subview.strides) == 2
|
||||
assert subview.result == subview.results[0]
|
||||
|
||||
# CHECK: SubViewOp
|
||||
print(type(subview).__name__)
|
||||
|
||||
# CHECK: constant 0
|
||||
print(subview.offsets[0])
|
||||
# CHECK: constant 1
|
||||
print(subview.offsets[1])
|
||||
# CHECK: constant 2
|
||||
print(subview.sizes[0])
|
||||
# CHECK: constant 3
|
||||
print(subview.sizes[1])
|
||||
# CHECK: constant 4
|
||||
print(subview.strides[0])
|
||||
# CHECK: constant 5
|
||||
print(subview.strides[1])
|
||||
|
||||
|
||||
run(testSubViewAccessors)
|
|
@ -0,0 +1,206 @@
|
|||
// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// CHECK: @_cext.register_dialect
|
||||
// CHECK: class _Dialect(_ir.Dialect):
|
||||
// CHECK: DIALECT_NAMESPACE = "test"
|
||||
// CHECK: pass
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "Test";
|
||||
}
|
||||
class TestOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Test_Dialect, mnemonic, traits>;
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedOperandsOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
|
||||
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
[AttrSizedOperandSegments]> {
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic1(self):
|
||||
// CHECK: operand_range = _segmented_accessor(
|
||||
// CHECK: self.operation.operands,
|
||||
// CHECK: self.operation.attributes["operand_segment_sizes"], 0)
|
||||
// CHECK: return operand_range
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: operand_range = _segmented_accessor(
|
||||
// CHECK: self.operation.operands,
|
||||
// CHECK: self.operation.attributes["operand_segment_sizes"], 1)
|
||||
// CHECK: return operand_range[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic2(self):
|
||||
// CHECK: operand_range = _segmented_accessor(
|
||||
// CHECK: self.operation.operands,
|
||||
// CHECK: self.operation.attributes["operand_segment_sizes"], 2)
|
||||
// CHECK: return operand_range[0] if len(operand_range) > 0 else None
|
||||
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
|
||||
Optional<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedResultsOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
|
||||
def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
||||
[AttrSizedResultSegments]> {
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic1(self):
|
||||
// CHECK: result_range = _segmented_accessor(
|
||||
// CHECK: self.operation.results,
|
||||
// CHECK: self.operation.attributes["result_segment_sizes"], 0)
|
||||
// CHECK: return result_range[0] if len(result_range) > 0 else None
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: result_range = _segmented_accessor(
|
||||
// CHECK: self.operation.results,
|
||||
// CHECK: self.operation.attributes["result_segment_sizes"], 1)
|
||||
// CHECK: return result_range[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic2(self):
|
||||
// CHECK: result_range = _segmented_accessor(
|
||||
// CHECK: self.operation.results,
|
||||
// CHECK: self.operation.attributes["result_segment_sizes"], 2)
|
||||
// CHECK: return result_range
|
||||
let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic,
|
||||
Optional<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class EmptyOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.empty"
|
||||
def EmptyOp : TestOp<"empty">;
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class MissingNamesOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
|
||||
def MissingNamesOp : TestOp<"missing_names"> {
|
||||
// CHECK: @property
|
||||
// CHECK: def f32(self):
|
||||
// CHECK: return self.operation.operands[1]
|
||||
let arguments = (ins I32, F32:$f32, I64);
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def i32(self):
|
||||
// CHECK: return self.operation.results[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def i64(self):
|
||||
// CHECK: return self.operation.results[2]
|
||||
let results = (outs I32:$i32, F32, I64:$i64);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneVariadicOperandOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
|
||||
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: return self.operation.operands[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic(self):
|
||||
// CHECK: variadic_group_length = len(self.operation.operands) - 2 + 1
|
||||
// CHECK: return self.operation.operands[1:1 + variadic_group_length]
|
||||
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneVariadicResultOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
|
||||
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic(self):
|
||||
// CHECK: variadic_group_length = len(self.operation.results) - 2 + 1
|
||||
// CHECK: return self.operation.results[0:0 + variadic_group_length]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: variadic_group_length = len(self.operation.results) - 2 + 1
|
||||
// CHECK: return self.operation.results[1 + variadic_group_length - 1]
|
||||
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class PythonKeywordOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
|
||||
def PythonKeywordOp : TestOp<"python_keyword"> {
|
||||
// CHECK: @property
|
||||
// CHECK: def in_(self):
|
||||
// CHECK: return self.operation.operands[0]
|
||||
let arguments = (ins AnyType:$in);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class SameVariadicOperandSizeOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_operand"
|
||||
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
|
||||
[SameVariadicOperandSize]> {
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic1(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 0, 0)
|
||||
// CHECK: return self.operation.operands[start:start + pg]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 0, 1)
|
||||
// CHECK: return self.operation.operands[start]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic2(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 1, 1)
|
||||
// CHECK: return self.operation.operands[start:start + pg]
|
||||
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
|
||||
Variadic<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class SameVariadicResultSizeOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
|
||||
def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
|
||||
[SameVariadicResultSize]> {
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic1(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 0, 0)
|
||||
// CHECK: return self.operation.results[start:start + pg]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def non_variadic(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 0, 1)
|
||||
// CHECK: return self.operation.results[start]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def variadic2(self):
|
||||
// CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 1, 1)
|
||||
// CHECK: return self.operation.results[start:start + pg]
|
||||
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
|
||||
Variadic<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class SimpleOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.simple"
|
||||
def SimpleOp : TestOp<"simple"> {
|
||||
// CHECK: @property
|
||||
// CHECK: def i32(self):
|
||||
// CHECK: return self.operation.operands[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def f32(self):
|
||||
// CHECK: return self.operation.operands[1]
|
||||
let arguments = (ins I32:$i32, F32:$f32);
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def i64(self):
|
||||
// CHECK: return self.operation.results[0]
|
||||
//
|
||||
// CHECK: @property
|
||||
// CHECK: def f64(self):
|
||||
// CHECK: return self.operation.results[1]
|
||||
let results = (outs I64:$i64, F64:$f64);
|
||||
}
|
|
@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
|
|||
OpDocGen.cpp
|
||||
OpFormatGen.cpp
|
||||
OpInterfacesGen.cpp
|
||||
OpPythonBindingGen.cpp
|
||||
OpenMPCommonGen.cpp
|
||||
PassCAPIGen.cpp
|
||||
PassDocGen.cpp
|
||||
|
|
|
@ -0,0 +1,333 @@
|
|||
//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
|
||||
// binding classes wrapping a generic operation API.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
/// File header and includes.
|
||||
constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from . import _cext
|
||||
from . import _segmented_accessor, _equally_sized_accessor
|
||||
_ir = _cext.ir
|
||||
)Py";
|
||||
|
||||
/// Template for dialect class:
|
||||
/// {0} is the dialect namespace.
|
||||
constexpr const char *dialectClassTemplate = R"Py(
|
||||
@_cext.register_dialect
|
||||
class _Dialect(_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "{0}"
|
||||
pass
|
||||
|
||||
)Py";
|
||||
|
||||
/// Template for operation class:
|
||||
/// {0} is the Python class name;
|
||||
/// {1} is the operation name.
|
||||
constexpr const char *opClassTemplate = R"Py(
|
||||
@_cext.register_operation(_Dialect)
|
||||
class {0}(_ir.OpView):
|
||||
OPERATION_NAME = "{1}"
|
||||
)Py";
|
||||
|
||||
/// Template for single-element accessor:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the position in the element list.
|
||||
constexpr const char *opSingleTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
return self.operation.{1}s[{2}]
|
||||
)Py";
|
||||
|
||||
/// Template for single-element accessor after a variable-length group:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the total number of element groups;
|
||||
/// {3} is the position of the current group in the group list.
|
||||
/// This works for both a single variadic group (non-negative length) and an
|
||||
/// single optional element (zero length if the element is absent).
|
||||
constexpr const char *opSingleAfterVariableTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
variadic_group_length = len(self.operation.{1}s) - {2} + 1
|
||||
return self.operation.{1}s[{3} + variadic_group_length - 1]
|
||||
)Py";
|
||||
|
||||
/// Template for an optional element accessor:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the total number of element groups;
|
||||
/// {3} is the position of the current group in the group list.
|
||||
constexpr const char *opOneOptionalTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self);
|
||||
return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2}
|
||||
else None
|
||||
)Py";
|
||||
|
||||
/// Template for the variadic group accessor in the single variadic group case:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the total number of element groups;
|
||||
/// {3} is the position of the current group in the group list.
|
||||
constexpr const char *opOneVariadicTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
variadic_group_length = len(self.operation.{1}s) - {2} + 1
|
||||
return self.operation.{1}s[{3}:{3} + variadic_group_length]
|
||||
)Py";
|
||||
|
||||
/// First part of the template for equally-sized variadic group accessor:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the total number of variadic groups;
|
||||
/// {3} is the number of non-variadic groups preceding the current group;
|
||||
/// {3} is the number of variadic groups preceding the current group.
|
||||
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
start, pg = _equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
|
||||
|
||||
/// Second part of the template for equally-sized case, accessing a single
|
||||
/// element:
|
||||
/// {0} is either 'operand' or 'result'.
|
||||
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
|
||||
return self.operation.{0}s[start]
|
||||
)Py";
|
||||
|
||||
/// Second part of the template for equally-sized case, accessing a variadic
|
||||
/// group:
|
||||
/// {0} is either 'operand' or 'result'.
|
||||
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
|
||||
return self.operation.{0}s[start:start + pg]
|
||||
)Py";
|
||||
|
||||
/// Template for an attribute-sized group accessor:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
/// {2} is the position of the group in the group list;
|
||||
/// {3} is a return suffix (expected [0] for single-element, empty for
|
||||
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
|
||||
constexpr const char *opVariadicSegmentTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
{1}_range = _segmented_accessor(
|
||||
self.operation.{1}s,
|
||||
self.operation.attributes["{1}_segment_sizes"], {2})
|
||||
return {1}_range{3}
|
||||
)Py";
|
||||
|
||||
/// Template for a suffix when accessing an optional element in the
|
||||
/// attribute-sized case:
|
||||
/// {0} is either 'operand' or 'result';
|
||||
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
|
||||
R"Py([0] if len({0}_range) > 0 else None)Py";
|
||||
|
||||
static llvm::cl::OptionCategory
|
||||
clOpPythonBindingCat("Options for -gen-python-op-bindings");
|
||||
|
||||
static llvm::cl::opt<std::string>
|
||||
clDialectName("bind-dialect",
|
||||
llvm::cl::desc("The dialect to run the generator for"),
|
||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||
|
||||
/// Checks whether `str` is a Python keyword.
|
||||
static bool isPythonKeyword(StringRef str) {
|
||||
static llvm::StringSet<> keywords(
|
||||
{"and", "as", "assert", "break", "class", "continue",
|
||||
"def", "del", "elif", "else", "except", "finally",
|
||||
"for", "from", "global", "if", "import", "in",
|
||||
"is", "lambda", "nonlocal", "not", "or", "pass",
|
||||
"raise", "return", "try", "while", "with", "yield"});
|
||||
return keywords.contains(str);
|
||||
};
|
||||
|
||||
/// Modifies the `name` in a way that it becomes suitable for Python bindings
|
||||
/// (does not change the `name` if it already is suitable) and returns the
|
||||
/// modified version.
|
||||
static std::string sanitizeName(StringRef name) {
|
||||
if (isPythonKeyword(name))
|
||||
return (name + "_").str();
|
||||
return name.str();
|
||||
}
|
||||
|
||||
/// Emits accessors to "elements" of an Op definition. Currently, the supported
|
||||
/// elements are operands and results, indicated by `kind`, which must be either
|
||||
/// `operand` or `result` and is used verbatim in the emitted code.
|
||||
static void emitElementAccessors(
|
||||
const Operator &op, raw_ostream &os, const char *kind,
|
||||
llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
|
||||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement) {
|
||||
assert(llvm::is_contained(
|
||||
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
|
||||
"unsupported kind");
|
||||
|
||||
// Traits indicating how to process variadic elements.
|
||||
std::string sameSizeTrait =
|
||||
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
|
||||
llvm::StringRef(kind).take_front().upper(),
|
||||
llvm::StringRef(kind).drop_front());
|
||||
std::string attrSizedTrait =
|
||||
llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
||||
llvm::StringRef(kind).take_front().upper(),
|
||||
llvm::StringRef(kind).drop_front());
|
||||
|
||||
unsigned numVariadic = getNumVariadic(op);
|
||||
|
||||
// If there is only one variadic element group, its size can be inferred from
|
||||
// the total number of elements. If there are none, the generation is
|
||||
// straightforward.
|
||||
if (numVariadic <= 1) {
|
||||
bool seenVariableLength = false;
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (element.isVariableLength())
|
||||
seenVariableLength = true;
|
||||
if (element.name.empty())
|
||||
continue;
|
||||
if (element.isVariableLength()) {
|
||||
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
|
||||
: opOneVariadicTemplate,
|
||||
sanitizeName(element.name), kind,
|
||||
getNumElements(op), i);
|
||||
} else if (seenVariableLength) {
|
||||
os << llvm::formatv(opSingleAfterVariableTemplate,
|
||||
sanitizeName(element.name), kind,
|
||||
getNumElements(op), i);
|
||||
} else {
|
||||
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
|
||||
i);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle the operations where variadic groups have the same size.
|
||||
if (op.getTrait(sameSizeTrait)) {
|
||||
int numPrecedingSimple = 0;
|
||||
int numPrecedingVariadic = 0;
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (!element.name.empty()) {
|
||||
os << llvm::formatv(opVariadicEqualPrefixTemplate,
|
||||
sanitizeName(element.name), kind, numVariadic,
|
||||
numPrecedingSimple, numPrecedingVariadic);
|
||||
os << llvm::formatv(element.isVariableLength()
|
||||
? opVariadicEqualVariadicTemplate
|
||||
: opVariadicEqualSimpleTemplate,
|
||||
kind);
|
||||
}
|
||||
if (element.isVariableLength())
|
||||
++numPrecedingVariadic;
|
||||
else
|
||||
++numPrecedingSimple;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle the operations where the size of groups (variadic or not) is
|
||||
// provided as an attribute. For non-variadic elements, make sure to return
|
||||
// an element rather than a singleton container.
|
||||
if (op.getTrait(attrSizedTrait)) {
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (element.name.empty())
|
||||
continue;
|
||||
std::string trailing;
|
||||
if (!element.isVariableLength())
|
||||
trailing = "[0]";
|
||||
else if (element.isOptional())
|
||||
trailing = std::string(
|
||||
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
|
||||
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
|
||||
kind, i, trailing);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
|
||||
}
|
||||
|
||||
/// Emits accessor to Op operands.
|
||||
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthOperands();
|
||||
};
|
||||
auto getNumElements = [](const Operator &oper) {
|
||||
return oper.getNumOperands();
|
||||
};
|
||||
auto getElement = [](const Operator &oper,
|
||||
int i) -> const NamedTypeConstraint & {
|
||||
return oper.getOperand(i);
|
||||
};
|
||||
emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements,
|
||||
getElement);
|
||||
}
|
||||
|
||||
/// Emits access or Op results.
|
||||
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthResults();
|
||||
};
|
||||
auto getNumElements = [](const Operator &oper) {
|
||||
return oper.getNumResults();
|
||||
};
|
||||
auto getElement = [](const Operator &oper,
|
||||
int i) -> const NamedTypeConstraint & {
|
||||
return oper.getResult(i);
|
||||
};
|
||||
emitElementAccessors(op, os, "result", getNumVariadic, getNumElements,
|
||||
getElement);
|
||||
}
|
||||
|
||||
/// Emits bindings for a specific Op to the given output stream.
|
||||
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
|
||||
op.getOperationName());
|
||||
emitOperandAccessors(op, os);
|
||||
emitResultAccessors(op, os);
|
||||
}
|
||||
|
||||
/// Emits bindings for the dialect specified in the command line, including file
|
||||
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
||||
/// registration requirements.
|
||||
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
if (clDialectName.empty())
|
||||
llvm::PrintFatalError("dialect name not provided");
|
||||
|
||||
os << fileHeader;
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
if (op.getDialectName() == clDialectName.getValue())
|
||||
emitOpBindings(op, os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static GenRegistration
|
||||
genPythonBindings("gen-python-op-bindings",
|
||||
"Generate Python bindings for MLIR Ops", &emitAllOps);
|
Loading…
Reference in New Issue