forked from OSchip/llvm-project
[mlir][python] Factor out standalone OpView._ods_build_default class method.
* This allows us to hoist trait level information for regions and sized-variadic to class level attributes (_ODS_REGIONS, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS). * Eliminates some splicey python generated code in favor of a native helper for it. * Makes it possible to implement custom, variadic and region based builders with one line of python, without needing to manually code access to the segment attributes. * Needs follow-on work for region based callbacks and support for SingleBlockImplicitTerminator. * A follow-up will actually add ODS support for generating custom Python builders that delegate to this new method. * Also includes the start of an e2e sample for constructing linalg ops where this limitation was discovered (working progressively through this example and cleaning up as I go). Differential Revision: https://reviews.llvm.org/D94738
This commit is contained in:
parent
cbdde495ba
commit
71b6b010e6
|
@ -365,7 +365,7 @@ for the canonical way to use this facility.
|
|||
|
||||
Each dialect with a mapping to python requires that an appropriate
|
||||
`{DIALECT_NAMESPACE}.py` wrapper module is created. This is done by invoking
|
||||
`mlir-tablegen` on a python-bindings specific tablegen wrapper that includes
|
||||
`mlir-tblgen` on a python-bindings specific tablegen wrapper that includes
|
||||
the boilerplate and actual dialect specific `td` file. An example, for the
|
||||
`StandardOps` (which is assigned the namespace `std` as a special case):
|
||||
|
||||
|
@ -383,7 +383,7 @@ In the main repository, building the wrapper is done via the CMake function
|
|||
`add_mlir_dialect_python_bindings`, which invokes:
|
||||
|
||||
```
|
||||
mlir-tablegen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
|
||||
mlir-tblgen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
|
||||
{PYTHON_BINDING_TD_FILE}
|
||||
```
|
||||
|
||||
|
@ -411,7 +411,8 @@ The wrapper module tablegen emitter outputs:
|
|||
Note: In order to avoid naming conflicts, all internal names used by the wrapper
|
||||
module are prefixed by `_ods_`.
|
||||
|
||||
Each concrete `OpView` subclass further defines several attributes:
|
||||
Each concrete `OpView` subclass further defines several public-intended
|
||||
attributes:
|
||||
|
||||
* `OPERATION_NAME` attribute with the `str` fully qualified operation name
|
||||
(i.e. `std.absf`).
|
||||
|
@ -421,6 +422,20 @@ Each concrete `OpView` subclass further defines several attributes:
|
|||
for unnamed of each).
|
||||
* `@property` getter, setter and deleter for each declared attribute.
|
||||
|
||||
It further emits additional private-intended attributes meant for subclassing
|
||||
and customization (default cases omit these attributes in favor of the
|
||||
defaults on `OpView`):
|
||||
|
||||
* `_ODS_REGIONS`: A specification on the number and types of regions.
|
||||
Currently a tuple of (min_region_count, has_no_variadic_regions). Note that
|
||||
the API does some light validation on this but the primary purpose is to
|
||||
capture sufficient information to perform other default building and region
|
||||
accessor generation.
|
||||
* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box value which
|
||||
indicates the structure of either the operand or results with respect to
|
||||
variadics. Used by `OpView._ods_build_default` to decode operand and result
|
||||
lists that contain lists.
|
||||
|
||||
#### Builders
|
||||
|
||||
Presently, only a single, default builder is mapped to the `__init__` method.
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# This is a work in progress example to do end2end build and code generation
|
||||
# of a small linalg program with configuration options. It is currently non
|
||||
# functional and is being used to elaborate the APIs.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import linalg
|
||||
from mlir.dialects import std
|
||||
|
||||
|
||||
# TODO: This should be in the core API.
|
||||
def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
|
||||
"""Creates a |func| op.
|
||||
TODO: This should really be in the MLIR API.
|
||||
Returns:
|
||||
(operation, entry_block)
|
||||
"""
|
||||
attrs = {
|
||||
"type": TypeAttr.get(func_type),
|
||||
"sym_name": StringAttr.get(name),
|
||||
}
|
||||
op = Operation.create("func", regions=1, attributes=attrs)
|
||||
body_region = op.regions[0]
|
||||
entry_block = body_region.blocks.append(*func_type.inputs)
|
||||
return op, entry_block
|
||||
|
||||
|
||||
# TODO: Generate customs builder vs patching one in.
|
||||
def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
|
||||
super(linalg.MatmulOp, self).__init__(
|
||||
self._ods_build_default(operands=[[lhs, rhs], [result]],
|
||||
results=[],
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
# TODO: Implement support for SingleBlockImplicitTerminator
|
||||
block = self.regions[0].blocks.append()
|
||||
with InsertionPoint(block):
|
||||
linalg.YieldOp(values=[])
|
||||
|
||||
linalg.MatmulOp.__init__ = PatchMatmulOpInit
|
||||
|
||||
|
||||
def build_matmul_func(func_name, m, k, n, dtype):
|
||||
lhs_type = MemRefType.get(dtype, [m, k])
|
||||
rhs_type = MemRefType.get(dtype, [k, n])
|
||||
result_type = MemRefType.get(dtype, [m, n])
|
||||
# TODO: There should be a one-liner for this.
|
||||
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
|
||||
_, entry = FuncOp(func_name, func_type)
|
||||
lhs, rhs, result = entry.arguments
|
||||
with InsertionPoint(entry):
|
||||
linalg.MatmulOp(lhs, rhs, result)
|
||||
std.ReturnOp([])
|
||||
|
||||
|
||||
def run():
|
||||
with Context() as c, Location.unknown():
|
||||
module = Module.create()
|
||||
# TODO: This at_block_terminator vs default construct distinction feels
|
||||
# wrong and is error-prone.
|
||||
with InsertionPoint.at_block_terminator(module.body):
|
||||
build_matmul_func('main', 18, 32, 96, F32Type.get())
|
||||
|
||||
print(module)
|
||||
print(module.operation.get_asm(print_generic_op_form=True))
|
||||
|
||||
|
||||
if __name__ == '__main__': run()
|
|
@ -130,6 +130,13 @@ equivalent to printing the operation that produced it.
|
|||
// Utilities.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Helper for creating an @classmethod.
|
||||
template <class Func, typename... Args>
|
||||
py::object classmethod(Func f, Args... args) {
|
||||
py::object cf = py::cpp_function(f, args...);
|
||||
return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
|
||||
}
|
||||
|
||||
/// Checks whether the given type is an integer or float type.
|
||||
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
||||
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
|
||||
|
@ -1027,6 +1034,267 @@ py::object PyOperation::createOpView() {
|
|||
return py::cast(PyOpView(getRef().getObject()));
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyOpView
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
py::object
|
||||
PyOpView::odsBuildDefault(py::object cls, py::list operandList,
|
||||
py::list resultTypeList,
|
||||
llvm::Optional<py::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors,
|
||||
llvm::Optional<int> regions,
|
||||
DefaultingPyLocation location, py::object maybeIp) {
|
||||
PyMlirContextRef context = location->getContext();
|
||||
// Class level operation construction metadata.
|
||||
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
|
||||
// Operand and result segment specs are either none, which does no
|
||||
// variadic unpacking, or a list of ints with segment sizes, where each
|
||||
// element is either a positive number (typically 1 for a scalar) or -1 to
|
||||
// indicate that it is derived from the length of the same-indexed operand
|
||||
// or result (implying that it is a list at that position).
|
||||
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
|
||||
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
|
||||
|
||||
std::vector<uint64_t> operandSegmentLengths;
|
||||
std::vector<uint64_t> resultSegmentLengths;
|
||||
|
||||
// Validate/determine region count.
|
||||
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
|
||||
int opMinRegionCount = std::get<0>(opRegionSpec);
|
||||
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
|
||||
if (!regions) {
|
||||
regions = opMinRegionCount;
|
||||
}
|
||||
if (*regions < opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
|
||||
// Unpack results.
|
||||
std::vector<PyType *> resultTypes;
|
||||
resultTypes.reserve(resultTypeList.size());
|
||||
if (resultSegmentSpecObj.is_none()) {
|
||||
// Non-variadic result unpacking.
|
||||
for (auto it : llvm::enumerate(resultTypeList)) {
|
||||
try {
|
||||
resultTypes.push_back(py::cast<PyType *>(it.value()));
|
||||
if (!resultTypes.back())
|
||||
throw py::cast_error();
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error((llvm::Twine("Result ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Type (" + err.what() + ")")
|
||||
.str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sized result unpacking.
|
||||
auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
|
||||
if (resultSegmentSpec.size() != resultTypeList.size()) {
|
||||
throw py::value_error((llvm::Twine("Operation \"") + name +
|
||||
"\" requires " +
|
||||
llvm::Twine(resultSegmentSpec.size()) +
|
||||
"result segments but was provided " +
|
||||
llvm::Twine(resultTypeList.size()))
|
||||
.str());
|
||||
}
|
||||
resultSegmentLengths.reserve(resultTypeList.size());
|
||||
for (auto it :
|
||||
llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
|
||||
int segmentSpec = std::get<1>(it.value());
|
||||
if (segmentSpec == 1 || segmentSpec == 0) {
|
||||
// Unpack unary element.
|
||||
try {
|
||||
auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
|
||||
if (resultType) {
|
||||
resultTypes.push_back(resultType);
|
||||
resultSegmentLengths.push_back(1);
|
||||
} else if (segmentSpec == 0) {
|
||||
// Allowed to be optional.
|
||||
resultSegmentLengths.push_back(0);
|
||||
} else {
|
||||
throw py::cast_error("was None and result is not optional");
|
||||
}
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error((llvm::Twine("Result ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Type (" + err.what() +
|
||||
")")
|
||||
.str());
|
||||
}
|
||||
} else if (segmentSpec == -1) {
|
||||
// Unpack sequence by appending.
|
||||
try {
|
||||
if (std::get<0>(it.value()).is_none()) {
|
||||
// Treat it as an empty list.
|
||||
resultSegmentLengths.push_back(0);
|
||||
} else {
|
||||
// Unpack the list.
|
||||
auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
|
||||
for (py::object segmentItem : segment) {
|
||||
resultTypes.push_back(py::cast<PyType *>(segmentItem));
|
||||
if (!resultTypes.back()) {
|
||||
throw py::cast_error("contained a None item");
|
||||
}
|
||||
}
|
||||
resultSegmentLengths.push_back(segment.size());
|
||||
}
|
||||
} catch (std::exception &err) {
|
||||
// NOTE: Sloppy to be using a catch-all here, but there are at least
|
||||
// three different unrelated exceptions that can be thrown in the
|
||||
// above "casts". Just keep the scope above small and catch them all.
|
||||
throw py::value_error((llvm::Twine("Result ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Sequence of Types (" +
|
||||
err.what() + ")")
|
||||
.str());
|
||||
}
|
||||
} else {
|
||||
throw py::value_error("Unexpected segment spec");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack operands.
|
||||
std::vector<PyValue *> operands;
|
||||
operands.reserve(operands.size());
|
||||
if (operandSegmentSpecObj.is_none()) {
|
||||
// Non-sized operand unpacking.
|
||||
for (auto it : llvm::enumerate(operandList)) {
|
||||
try {
|
||||
operands.push_back(py::cast<PyValue *>(it.value()));
|
||||
if (!operands.back())
|
||||
throw py::cast_error();
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Value (" + err.what() + ")")
|
||||
.str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sized operand unpacking.
|
||||
auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
|
||||
if (operandSegmentSpec.size() != operandList.size()) {
|
||||
throw py::value_error((llvm::Twine("Operation \"") + name +
|
||||
"\" requires " +
|
||||
llvm::Twine(operandSegmentSpec.size()) +
|
||||
"operand segments but was provided " +
|
||||
llvm::Twine(operandList.size()))
|
||||
.str());
|
||||
}
|
||||
operandSegmentLengths.reserve(operandList.size());
|
||||
for (auto it :
|
||||
llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
|
||||
int segmentSpec = std::get<1>(it.value());
|
||||
if (segmentSpec == 1 || segmentSpec == 0) {
|
||||
// Unpack unary element.
|
||||
try {
|
||||
auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
|
||||
if (operandValue) {
|
||||
operands.push_back(operandValue);
|
||||
operandSegmentLengths.push_back(1);
|
||||
} else if (segmentSpec == 0) {
|
||||
// Allowed to be optional.
|
||||
operandSegmentLengths.push_back(0);
|
||||
} else {
|
||||
throw py::cast_error("was None and operand is not optional");
|
||||
}
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Value (" + err.what() +
|
||||
")")
|
||||
.str());
|
||||
}
|
||||
} else if (segmentSpec == -1) {
|
||||
// Unpack sequence by appending.
|
||||
try {
|
||||
if (std::get<0>(it.value()).is_none()) {
|
||||
// Treat it as an empty list.
|
||||
operandSegmentLengths.push_back(0);
|
||||
} else {
|
||||
// Unpack the list.
|
||||
auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
|
||||
for (py::object segmentItem : segment) {
|
||||
operands.push_back(py::cast<PyValue *>(segmentItem));
|
||||
if (!operands.back()) {
|
||||
throw py::cast_error("contained a None item");
|
||||
}
|
||||
}
|
||||
operandSegmentLengths.push_back(segment.size());
|
||||
}
|
||||
} catch (std::exception &err) {
|
||||
// NOTE: Sloppy to be using a catch-all here, but there are at least
|
||||
// three different unrelated exceptions that can be thrown in the
|
||||
// above "casts". Just keep the scope above small and catch them all.
|
||||
throw py::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Sequence of Values (" +
|
||||
err.what() + ")")
|
||||
.str());
|
||||
}
|
||||
} else {
|
||||
throw py::value_error("Unexpected segment spec");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge operand/result segment lengths into attributes if needed.
|
||||
if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
|
||||
// Dup.
|
||||
if (attributes) {
|
||||
attributes = py::dict(*attributes);
|
||||
} else {
|
||||
attributes = py::dict();
|
||||
}
|
||||
if (attributes->contains("result_segment_sizes") ||
|
||||
attributes->contains("operand_segment_sizes")) {
|
||||
throw py::value_error("Manually setting a 'result_segment_sizes' or "
|
||||
"'operand_segment_sizes' attribute is unsupported. "
|
||||
"Use Operation.create for such low-level access.");
|
||||
}
|
||||
|
||||
// Add result_segment_sizes attribute.
|
||||
if (!resultSegmentLengths.empty()) {
|
||||
int64_t size = resultSegmentLengths.size();
|
||||
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
|
||||
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
|
||||
resultSegmentLengths.size(), resultSegmentLengths.data());
|
||||
(*attributes)["result_segment_sizes"] =
|
||||
PyAttribute(context, segmentLengthAttr);
|
||||
}
|
||||
|
||||
// Add operand_segment_sizes attribute.
|
||||
if (!operandSegmentLengths.empty()) {
|
||||
int64_t size = operandSegmentLengths.size();
|
||||
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
|
||||
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
|
||||
operandSegmentLengths.size(), operandSegmentLengths.data());
|
||||
(*attributes)["operand_segment_sizes"] =
|
||||
PyAttribute(context, segmentLengthAttr);
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate to create.
|
||||
return PyOperation::create(std::move(name), /*operands=*/std::move(operands),
|
||||
/*results=*/std::move(resultTypes),
|
||||
/*attributes=*/std::move(attributes),
|
||||
/*successors=*/std::move(successors),
|
||||
/*regions=*/*regions, location, maybeIp);
|
||||
}
|
||||
|
||||
PyOpView::PyOpView(py::object operationObject)
|
||||
// Casting through the PyOperationBase base-class and then back to the
|
||||
// Operation lets us accept any PyOperationBase subclass.
|
||||
|
@ -3397,17 +3665,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
"Context that owns the Operation")
|
||||
.def_property_readonly("opview", &PyOperation::createOpView);
|
||||
|
||||
py::class_<PyOpView, PyOperationBase>(m, "OpView")
|
||||
.def(py::init<py::object>())
|
||||
.def_property_readonly("operation", &PyOpView::getOperationObject)
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
[](PyOpView &self) {
|
||||
return self.getOperation().getContext().getObject();
|
||||
},
|
||||
"Context that owns the Operation")
|
||||
.def("__str__",
|
||||
[](PyOpView &self) { return py::str(self.getOperationObject()); });
|
||||
auto opViewClass =
|
||||
py::class_<PyOpView, PyOperationBase>(m, "OpView")
|
||||
.def(py::init<py::object>())
|
||||
.def_property_readonly("operation", &PyOpView::getOperationObject)
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
[](PyOpView &self) {
|
||||
return self.getOperation().getContext().getObject();
|
||||
},
|
||||
"Context that owns the Operation")
|
||||
.def("__str__", [](PyOpView &self) {
|
||||
return py::str(self.getOperationObject());
|
||||
});
|
||||
opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
|
||||
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
|
||||
opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
|
||||
opViewClass.attr("_ods_build_default") = classmethod(
|
||||
&PyOpView::odsBuildDefault, py::arg("cls"),
|
||||
py::arg("operands") = py::none(), py::arg("results") = py::none(),
|
||||
py::arg("attributes") = py::none(), py::arg("successors") = py::none(),
|
||||
py::arg("regions") = py::none(), py::arg("loc") = py::none(),
|
||||
py::arg("ip") = py::none(),
|
||||
"Builds a specific, generated OpView based on class level attributes.");
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyRegion.
|
||||
|
|
|
@ -497,6 +497,14 @@ public:
|
|||
|
||||
pybind11::object getOperationObject() { return operationObject; }
|
||||
|
||||
static pybind11::object
|
||||
odsBuildDefault(pybind11::object cls, pybind11::list operandList,
|
||||
pybind11::list resultTypeList,
|
||||
llvm::Optional<pybind11::dict> attributes,
|
||||
llvm::Optional<std::vector<PyBlock *>> successors,
|
||||
llvm::Optional<int> regions, DefaultingPyLocation location,
|
||||
pybind11::object maybeIp);
|
||||
|
||||
private:
|
||||
PyOperation &operation; // For efficient, cast-free access from C++
|
||||
pybind11::object operationObject; // Holds the reference.
|
||||
|
|
|
@ -0,0 +1,210 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
from mlir.ir import *
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
|
||||
|
||||
def add_dummy_value():
|
||||
return Operation.create(
|
||||
"custom.value",
|
||||
results=[IntegerType.get_signless(32)]).result
|
||||
|
||||
|
||||
def testOdsBuildDefaultImplicitRegions():
|
||||
|
||||
class TestFixedRegionsOp(OpView):
|
||||
OPERATION_NAME = "custom.test_op"
|
||||
_ODS_REGIONS = (2, True)
|
||||
|
||||
class TestVariadicRegionsOp(OpView):
|
||||
OPERATION_NAME = "custom.test_any_regions_op"
|
||||
_ODS_REGIONS = (2, False)
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[])
|
||||
# CHECK: NUM_REGIONS: 2
|
||||
print(f"NUM_REGIONS: {len(op.regions)}")
|
||||
# Including a regions= that matches should be fine.
|
||||
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=2)
|
||||
print(f"NUM_REGIONS: {len(op.regions)}")
|
||||
# Reject greater than.
|
||||
try:
|
||||
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=3)
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
|
||||
print(f"ERROR:{e}")
|
||||
# Reject less than.
|
||||
try:
|
||||
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=1)
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
|
||||
print(f"ERROR:{e}")
|
||||
|
||||
# If no regions specified for a variadic region op, build the minimum.
|
||||
op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[])
|
||||
# CHECK: DEFAULT_NUM_REGIONS: 2
|
||||
print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
|
||||
# Should also accept an explicit regions= that matches the minimum.
|
||||
op = TestVariadicRegionsOp._ods_build_default(
|
||||
operands=[], results=[], regions=2)
|
||||
# CHECK: EQ_NUM_REGIONS: 2
|
||||
print(f"EQ_NUM_REGIONS: {len(op.regions)}")
|
||||
# And accept greater than minimum.
|
||||
# Should also accept an explicit regions= that matches the minimum.
|
||||
op = TestVariadicRegionsOp._ods_build_default(
|
||||
operands=[], results=[], regions=3)
|
||||
# CHECK: GT_NUM_REGIONS: 3
|
||||
print(f"GT_NUM_REGIONS: {len(op.regions)}")
|
||||
# Should reject less than minimum.
|
||||
try:
|
||||
op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[], regions=1)
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
|
||||
print(f"ERROR:{e}")
|
||||
|
||||
|
||||
|
||||
run(testOdsBuildDefaultImplicitRegions)
|
||||
|
||||
|
||||
def testOdsBuildDefaultNonVariadic():
|
||||
|
||||
class TestOp(OpView):
|
||||
OPERATION_NAME = "custom.test_op"
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
t0 = IntegerType.get_signless(8)
|
||||
t1 = IntegerType.get_signless(16)
|
||||
op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1])
|
||||
# CHECK: %[[V0:.+]] = "custom.value"
|
||||
# CHECK: %[[V1:.+]] = "custom.value"
|
||||
# CHECK: "custom.test_op"(%[[V0]], %[[V1]])
|
||||
# CHECK-NOT: operand_segment_sizes
|
||||
# CHECK-NOT: result_segment_sizes
|
||||
# CHECK-SAME: : (i32, i32) -> (i8, i16)
|
||||
print(m)
|
||||
|
||||
run(testOdsBuildDefaultNonVariadic)
|
||||
|
||||
|
||||
def testOdsBuildDefaultSizedVariadic():
|
||||
|
||||
class TestOp(OpView):
|
||||
OPERATION_NAME = "custom.test_op"
|
||||
_ODS_OPERAND_SEGMENTS = [1, -1, 0]
|
||||
_ODS_RESULT_SEGMENTS = [-1, 0, 1]
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
v2 = add_dummy_value()
|
||||
v3 = add_dummy_value()
|
||||
t0 = IntegerType.get_signless(8)
|
||||
t1 = IntegerType.get_signless(16)
|
||||
t2 = IntegerType.get_signless(32)
|
||||
t3 = IntegerType.get_signless(64)
|
||||
# CHECK: %[[V0:.+]] = "custom.value"
|
||||
# CHECK: %[[V1:.+]] = "custom.value"
|
||||
# CHECK: %[[V2:.+]] = "custom.value"
|
||||
# CHECK: %[[V3:.+]] = "custom.value"
|
||||
# CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
|
||||
# CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64>
|
||||
# CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64>
|
||||
# CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, [v1, v2], v3],
|
||||
results=[[t0, t1], t2, t3])
|
||||
|
||||
# Now test with optional omitted.
|
||||
# CHECK: "custom.test_op"(%[[V0]])
|
||||
# CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
|
||||
# CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
|
||||
# CHECK-SAME: (i32) -> i64
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, None, None],
|
||||
results=[None, None, t3])
|
||||
print(m)
|
||||
|
||||
# And verify that errors are raised for None in a required operand.
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[None, None, None],
|
||||
results=[None, None, t3])
|
||||
except ValueError as e:
|
||||
# CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
|
||||
print(f"OPERAND_CAST_ERROR:{e}")
|
||||
|
||||
# And verify that errors are raised for None in a required result.
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, None, None],
|
||||
results=[None, None, None])
|
||||
except ValueError as e:
|
||||
# CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
|
||||
print(f"RESULT_CAST_ERROR:{e}")
|
||||
|
||||
# Variadic lists with None elements should reject.
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, [None], None],
|
||||
results=[None, None, t3])
|
||||
except ValueError as e:
|
||||
# CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
|
||||
print(f"OPERAND_LIST_CAST_ERROR:{e}")
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, None, None],
|
||||
results=[[None], None, t3])
|
||||
except ValueError as e:
|
||||
# CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
|
||||
print(f"RESULT_LIST_CAST_ERROR:{e}")
|
||||
|
||||
run(testOdsBuildDefaultSizedVariadic)
|
||||
|
||||
|
||||
def testOdsBuildDefaultCastError():
|
||||
|
||||
class TestOp(OpView):
|
||||
OPERATION_NAME = "custom.test_op"
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
t0 = IntegerType.get_signless(8)
|
||||
t1 = IntegerType.get_signless(16)
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[None, v1],
|
||||
results=[t0, t1])
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
|
||||
print(f"ERROR: {e}")
|
||||
try:
|
||||
op = TestOp._ods_build_default(
|
||||
operands=[v0, v1],
|
||||
results=[t0, None])
|
||||
except ValueError as e:
|
||||
# CHECK: Result 1 of operation "custom.test_op" must be a Type
|
||||
print(f"ERROR: {e}")
|
||||
|
||||
run(testOdsBuildDefaultCastError)
|
|
@ -17,23 +17,18 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
|
||||
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
|
||||
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
[AttrSizedOperandSegments]> {
|
||||
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operand_segment_sizes_ods = _ods_array.array('L')
|
||||
// CHECK: operands += [*variadic1]
|
||||
// CHECK: operand_segment_sizes_ods.append(len(variadic1))
|
||||
// CHECK: operands.append(variadic1)
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: operand_segment_sizes_ods.append(1)
|
||||
// CHECK: if variadic2 is not None: operands.append(variadic2)
|
||||
// CHECK: operand_segment_sizes_ods.append(0 if variadic2 is None else 1)
|
||||
// CHECK: attributes["operand_segment_sizes"] = _ods_ir.DenseElementsAttr.get(operand_segment_sizes_ods,
|
||||
// CHECK: context=_ods_get_default_loc_context(loc))
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -63,23 +58,18 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
|
||||
// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
|
||||
def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
||||
[AttrSizedResultSegments]> {
|
||||
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: result_segment_sizes_ods = _ods_array.array('L')
|
||||
// CHECK: if variadic1 is not None: results.append(variadic1)
|
||||
// CHECK: result_segment_sizes_ods.append(0 if variadic1 is None else 1)
|
||||
// CHECK: results.append(non_variadic)
|
||||
// CHECK: result_segment_sizes_ods.append(1) # non_variadic
|
||||
// CHECK: if variadic2 is not None: results.append(variadic2)
|
||||
// CHECK: result_segment_sizes_ods.append(0 if variadic2 is None else 1)
|
||||
// CHECK: attributes["result_segment_sizes"] = _ods_ir.DenseElementsAttr.get(result_segment_sizes_ods,
|
||||
// CHECK: context=_ods_get_default_loc_context(loc))
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -110,6 +100,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttributedOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
|
||||
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
|
||||
// CHECK-NOT: _ODS_RESULT_SEGMENTS
|
||||
def AttributedOp : TestOp<"attributed_op"> {
|
||||
// CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
|
@ -120,8 +112,8 @@ def AttributedOp : TestOp<"attributed_op"> {
|
|||
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: attributes["in"] = in_
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -148,6 +140,8 @@ def AttributedOp : TestOp<"attributed_op"> {
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
|
||||
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
|
||||
// CHECK-NOT: _ODS_RESULT_SEGMENTS
|
||||
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
||||
// CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
|
@ -158,8 +152,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
|||
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: if is_ is not None: attributes["is"] = is_
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -183,8 +177,8 @@ def EmptyOp : TestOp<"empty">;
|
|||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.empty", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
|
@ -201,8 +195,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
|
|||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -223,15 +217,17 @@ def MissingNamesOp : TestOp<"missing_names"> {
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
|
||||
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
|
||||
// CHECK-NOT: _ODS_RESULT_SEGMENTS
|
||||
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
|
||||
// CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: operands += [*variadic]
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: operands.extend(variadic)
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -248,15 +244,17 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
|
|||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
|
||||
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
|
||||
// CHECK-NOT: _ODS_RESULT_SEGMENTS
|
||||
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
|
||||
// CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: results += [*variadic]
|
||||
// CHECK: results.extend(variadic)
|
||||
// CHECK: results.append(non_variadic)
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.one_variadic_result", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -280,8 +278,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
|
|||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(in_)
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
@ -348,8 +346,8 @@ def SimpleOp : TestOp<"simple"> {
|
|||
// CHECK: results.append(f64)
|
||||
// CHECK: operands.append(i32)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: super().__init__(_ods_ir.Operation.create(
|
||||
// CHECK: "test.simple", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: super().__init__(self._ods_build_default(
|
||||
// CHECK: attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
|
|
|
@ -26,7 +26,6 @@ using namespace mlir::tblgen;
|
|||
constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
import array as _ods_array
|
||||
from . import _cext as _ods_cext
|
||||
from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
@ -51,6 +50,25 @@ class {0}(_ods_ir.OpView):
|
|||
OPERATION_NAME = "{1}"
|
||||
)Py";
|
||||
|
||||
/// Template for class level declarations of operand and result
|
||||
/// segment specs.
|
||||
/// {0} is either "OPERAND" or "RESULT"
|
||||
/// {1} is the segment spec
|
||||
/// Each segment spec is either None (default) or an array of integers
|
||||
/// where:
|
||||
/// 1 = single element (expect non sequence operand/result)
|
||||
/// -1 = operand/result is a sequence corresponding to a variadic
|
||||
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
|
||||
_ODS_{0}_SEGMENTS = {1}
|
||||
)Py";
|
||||
|
||||
/// Template for class level declarations of the _ODS_REGIONS spec:
|
||||
/// {0} is the minimum number of regions
|
||||
/// {1} is the Python bool literal for hasNoVariadicRegions
|
||||
constexpr const char *opClassRegionSpecTemplate = R"Py(
|
||||
_ODS_REGIONS = ({0}, {1})
|
||||
)Py";
|
||||
|
||||
/// Template for single-element accessor:
|
||||
/// {0} is the name of the accessor;
|
||||
/// {1} is either 'operand' or 'result';
|
||||
|
@ -446,18 +464,17 @@ static void emitAttributeAccessors(const Operator &op,
|
|||
}
|
||||
|
||||
/// Template for the default auto-generated builder.
|
||||
/// {0} is the operation name;
|
||||
/// {1} is a comma-separated list of builder arguments, including the trailing
|
||||
/// {0} is a comma-separated list of builder arguments, including the trailing
|
||||
/// `loc` and `ip`;
|
||||
/// {2} is the code populating `operands`, `results` and `attributes` fields.
|
||||
/// {1} is the code populating `operands`, `results` and `attributes` fields.
|
||||
constexpr const char *initTemplate = R"Py(
|
||||
def __init__(self, {1}):
|
||||
def __init__(self, {0}):
|
||||
operands = []
|
||||
results = []
|
||||
attributes = {{}
|
||||
{2}
|
||||
super().__init__(_ods_ir.Operation.create(
|
||||
"{0}", attributes=attributes, operands=operands, results=results,
|
||||
{1}
|
||||
super().__init__(self._ods_build_default(
|
||||
attributes=attributes, operands=operands, results=results,
|
||||
loc=loc, ip=ip))
|
||||
)Py";
|
||||
|
||||
|
@ -472,37 +489,10 @@ constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
|
|||
constexpr const char *optionalAppendTemplate =
|
||||
"if {1} is not None: {0}s.append({1})";
|
||||
|
||||
/// Template for appending a variadic element to the operand/result list.
|
||||
/// Template for appending a a list of elements to the operand/result list.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";
|
||||
|
||||
/// Template for setting up the segment sizes buffer.
|
||||
constexpr const char *segmentDeclarationTemplate =
|
||||
"{0}_segment_sizes_ods = _ods_array.array('L')";
|
||||
|
||||
/// Template for attaching segment sizes to the attribute list.
|
||||
constexpr const char *segmentAttributeTemplate =
|
||||
R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods,
|
||||
context=_ods_get_default_loc_context(loc)))Py";
|
||||
|
||||
/// Template for appending the unit size to the segment sizes.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *singleElementSegmentTemplate =
|
||||
"{0}_segment_sizes_ods.append(1) # {1}";
|
||||
|
||||
/// Template for appending 0/1 for an optional element to the segment sizes.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *optionalSegmentTemplate =
|
||||
"{0}_segment_sizes_ods.append(0 if {1} is None else 1)";
|
||||
|
||||
/// Template for appending the length of a variadic group to the segment sizes.
|
||||
/// {0} is either 'operand' or 'result';
|
||||
/// {1} is the field name.
|
||||
constexpr const char *variadicSegmentTemplate =
|
||||
"{0}_segment_sizes_ods.append(len({1}))";
|
||||
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
|
||||
|
||||
/// Template for setting an attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
|
@ -584,11 +574,7 @@ static void populateBuilderLines(
|
|||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement) {
|
||||
// The segment sizes buffer only has to be populated if there attr-sized
|
||||
// segments trait is present.
|
||||
bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
|
||||
if (includeSegments)
|
||||
builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
|
||||
|
||||
// For each element, find or generate a name.
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
|
@ -596,28 +582,28 @@ static void populateBuilderLines(
|
|||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString, segmentFormatString;
|
||||
llvm::StringRef formatString;
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleElementAppendTemplate;
|
||||
segmentFormatString = singleElementSegmentTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
formatString = optionalAppendTemplate;
|
||||
segmentFormatString = optionalSegmentTemplate;
|
||||
} else {
|
||||
assert(element.isVariadic() && "unhandled element group type");
|
||||
formatString = variadicAppendTemplate;
|
||||
segmentFormatString = variadicSegmentTemplate;
|
||||
// If emitting with sizedSegments, then we add the actual list typed
|
||||
// element using the singleElementAppendTemplate. Otherwise, we extend
|
||||
// the actual operands.
|
||||
if (sizedSegments) {
|
||||
// Append the list as is.
|
||||
formatString = singleElementAppendTemplate;
|
||||
} else {
|
||||
// Append the list elements.
|
||||
formatString = multiElementAppendTemplate;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the lines.
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
|
||||
if (includeSegments)
|
||||
builderLines.push_back(
|
||||
llvm::formatv(segmentFormatString.data(), kind, name));
|
||||
}
|
||||
|
||||
if (includeSegments)
|
||||
builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
|
||||
}
|
||||
|
||||
/// Emits a default builder constructing an operation from the list of its
|
||||
|
@ -645,8 +631,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
|
|||
|
||||
builderArgs.push_back("loc=None");
|
||||
builderArgs.push_back("ip=None");
|
||||
os << llvm::formatv(initTemplate, op.getOperationName(),
|
||||
llvm::join(builderArgs, ", "),
|
||||
os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
|
||||
llvm::join(builderLines, "\n "));
|
||||
}
|
||||
|
||||
|
@ -659,12 +644,52 @@ static void constructAttributeMapping(const llvm::RecordKeeper &records,
|
|||
}
|
||||
}
|
||||
|
||||
static void emitSegmentSpec(
|
||||
const Operator &op, const char *kind,
|
||||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement,
|
||||
raw_ostream &os) {
|
||||
std::string segmentSpec("[");
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (element.isVariableLength()) {
|
||||
segmentSpec.append("-1,");
|
||||
} else if (element.isOptional()) {
|
||||
segmentSpec.append("0,");
|
||||
} else {
|
||||
segmentSpec.append("1,");
|
||||
}
|
||||
}
|
||||
segmentSpec.append("]");
|
||||
|
||||
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
||||
}
|
||||
|
||||
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
|
||||
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
|
||||
// Note that the base OpView class defines this as (0, True).
|
||||
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
|
||||
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
|
||||
op.hasNoVariadicRegions() ? "True" : "False");
|
||||
}
|
||||
|
||||
/// Emits bindings for a specific Op to the given output stream.
|
||||
static void emitOpBindings(const Operator &op,
|
||||
const AttributeClasses &attributeClasses,
|
||||
raw_ostream &os) {
|
||||
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
|
||||
op.getOperationName());
|
||||
|
||||
// Sized segments.
|
||||
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
|
||||
emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
|
||||
}
|
||||
if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
|
||||
emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
|
||||
}
|
||||
|
||||
emitRegionAttributes(op, os);
|
||||
emitDefaultOpBuilder(op, os);
|
||||
emitOperandAccessors(op, os);
|
||||
emitAttributeAccessors(op, attributeClasses, os);
|
||||
|
|
Loading…
Reference in New Issue