[mlir][python] Factor out standalone OpView._ods_build_default class method.

* This allows us to hoist trait level information for regions and sized-variadic to class level attributes (_ODS_REGIONS, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS).
* Eliminates some splicey python generated code in favor of a native helper for it.
* Makes it possible to implement custom, variadic and region based builders with one line of python, without needing to manually code access to the segment attributes.
* Needs follow-on work for region based callbacks and support for SingleBlockImplicitTerminator.
* A follow-up will actually add ODS support for generating custom Python builders that delegate to this new method.
* Also includes the start of an e2e sample for constructing linalg ops where this limitation was discovered (working progressively through this example and cleaning up as I go).

Differential Revision: https://reviews.llvm.org/D94738
This commit is contained in:
Stella Laurenzo 2021-01-18 11:27:19 -08:00
parent cbdde495ba
commit 71b6b010e6
7 changed files with 713 additions and 104 deletions

View File

@ -365,7 +365,7 @@ for the canonical way to use this facility.
Each dialect with a mapping to python requires that an appropriate
`{DIALECT_NAMESPACE}.py` wrapper module is created. This is done by invoking
`mlir-tablegen` on a python-bindings specific tablegen wrapper that includes
`mlir-tblgen` on a python-bindings specific tablegen wrapper that includes
the boilerplate and actual dialect specific `td` file. An example, for the
`StandardOps` (which is assigned the namespace `std` as a special case):
@ -383,7 +383,7 @@ In the main repository, building the wrapper is done via the CMake function
`add_mlir_dialect_python_bindings`, which invokes:
```
mlir-tablegen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
mlir-tblgen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \
{PYTHON_BINDING_TD_FILE}
```
@ -411,7 +411,8 @@ The wrapper module tablegen emitter outputs:
Note: In order to avoid naming conflicts, all internal names used by the wrapper
module are prefixed by `_ods_`.
Each concrete `OpView` subclass further defines several attributes:
Each concrete `OpView` subclass further defines several public-intended
attributes:
* `OPERATION_NAME` attribute with the `str` fully qualified operation name
(i.e. `std.absf`).
@ -421,6 +422,20 @@ Each concrete `OpView` subclass further defines several attributes:
for unnamed of each).
* `@property` getter, setter and deleter for each declared attribute.
It further emits additional private-intended attributes meant for subclassing
and customization (default cases omit these attributes in favor of the
defaults on `OpView`):
* `_ODS_REGIONS`: A specification on the number and types of regions.
Currently a tuple of (min_region_count, has_no_variadic_regions). Note that
the API does some light validation on this but the primary purpose is to
capture sufficient information to perform other default building and region
accessor generation.
* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box value which
indicates the structure of either the operand or results with respect to
variadics. Used by `OpView._ods_build_default` to decode operand and result
lists that contain lists.
#### Builders
Presently, only a single, default builder is mapped to the `__init__` method.

View File

@ -0,0 +1,73 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This is a work in progress example to do end2end build and code generation
# of a small linalg program with configuration options. It is currently non
# functional and is being used to elaborate the APIs.
from typing import Tuple
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
# TODO: This should be in the core API.
def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
"""Creates a |func| op.
TODO: This should really be in the MLIR API.
Returns:
(operation, entry_block)
"""
attrs = {
"type": TypeAttr.get(func_type),
"sym_name": StringAttr.get(name),
}
op = Operation.create("func", regions=1, attributes=attrs)
body_region = op.regions[0]
entry_block = body_region.blocks.append(*func_type.inputs)
return op, entry_block
# TODO: Generate customs builder vs patching one in.
def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
super(linalg.MatmulOp, self).__init__(
self._ods_build_default(operands=[[lhs, rhs], [result]],
results=[],
loc=loc,
ip=ip))
# TODO: Implement support for SingleBlockImplicitTerminator
block = self.regions[0].blocks.append()
with InsertionPoint(block):
linalg.YieldOp(values=[])
linalg.MatmulOp.__init__ = PatchMatmulOpInit
def build_matmul_func(func_name, m, k, n, dtype):
lhs_type = MemRefType.get(dtype, [m, k])
rhs_type = MemRefType.get(dtype, [k, n])
result_type = MemRefType.get(dtype, [m, n])
# TODO: There should be a one-liner for this.
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
_, entry = FuncOp(func_name, func_type)
lhs, rhs, result = entry.arguments
with InsertionPoint(entry):
linalg.MatmulOp(lhs, rhs, result)
std.ReturnOp([])
def run():
with Context() as c, Location.unknown():
module = Module.create()
# TODO: This at_block_terminator vs default construct distinction feels
# wrong and is error-prone.
with InsertionPoint.at_block_terminator(module.body):
build_matmul_func('main', 18, 32, 96, F32Type.get())
print(module)
print(module.operation.get_asm(print_generic_op_form=True))
if __name__ == '__main__': run()

View File

@ -130,6 +130,13 @@ equivalent to printing the operation that produced it.
// Utilities.
//------------------------------------------------------------------------------
// Helper for creating an @classmethod.
template <class Func, typename... Args>
py::object classmethod(Func f, Args... args) {
py::object cf = py::cpp_function(f, args...);
return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
}
/// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
@ -1027,6 +1034,267 @@ py::object PyOperation::createOpView() {
return py::cast(PyOpView(getRef().getObject()));
}
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
py::object
PyOpView::odsBuildDefault(py::object cls, py::list operandList,
py::list resultTypeList,
llvm::Optional<py::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
llvm::Optional<int> regions,
DefaultingPyLocation location, py::object maybeIp) {
PyMlirContextRef context = location->getContext();
// Class level operation construction metadata.
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
// Operand and result segment specs are either none, which does no
// variadic unpacking, or a list of ints with segment sizes, where each
// element is either a positive number (typically 1 for a scalar) or -1 to
// indicate that it is derived from the length of the same-indexed operand
// or result (implying that it is a list at that position).
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
std::vector<uint64_t> operandSegmentLengths;
std::vector<uint64_t> resultSegmentLengths;
// Validate/determine region count.
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
if (!regions) {
regions = opMinRegionCount;
}
if (*regions < opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
// Unpack results.
std::vector<PyType *> resultTypes;
resultTypes.reserve(resultTypeList.size());
if (resultSegmentSpecObj.is_none()) {
// Non-variadic result unpacking.
for (auto it : llvm::enumerate(resultTypeList)) {
try {
resultTypes.push_back(py::cast<PyType *>(it.value()));
if (!resultTypes.back())
throw py::cast_error();
} catch (py::cast_error &err) {
throw py::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Type (" + err.what() + ")")
.str());
}
}
} else {
// Sized result unpacking.
auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
if (resultSegmentSpec.size() != resultTypeList.size()) {
throw py::value_error((llvm::Twine("Operation \"") + name +
"\" requires " +
llvm::Twine(resultSegmentSpec.size()) +
"result segments but was provided " +
llvm::Twine(resultTypeList.size()))
.str());
}
resultSegmentLengths.reserve(resultTypeList.size());
for (auto it :
llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
int segmentSpec = std::get<1>(it.value());
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
try {
auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
if (resultType) {
resultTypes.push_back(resultType);
resultSegmentLengths.push_back(1);
} else if (segmentSpec == 0) {
// Allowed to be optional.
resultSegmentLengths.push_back(0);
} else {
throw py::cast_error("was None and result is not optional");
}
} catch (py::cast_error &err) {
throw py::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Type (" + err.what() +
")")
.str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
try {
if (std::get<0>(it.value()).is_none()) {
// Treat it as an empty list.
resultSegmentLengths.push_back(0);
} else {
// Unpack the list.
auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
for (py::object segmentItem : segment) {
resultTypes.push_back(py::cast<PyType *>(segmentItem));
if (!resultTypes.back()) {
throw py::cast_error("contained a None item");
}
}
resultSegmentLengths.push_back(segment.size());
}
} catch (std::exception &err) {
// NOTE: Sloppy to be using a catch-all here, but there are at least
// three different unrelated exceptions that can be thrown in the
// above "casts". Just keep the scope above small and catch them all.
throw py::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Sequence of Types (" +
err.what() + ")")
.str());
}
} else {
throw py::value_error("Unexpected segment spec");
}
}
}
// Unpack operands.
std::vector<PyValue *> operands;
operands.reserve(operands.size());
if (operandSegmentSpecObj.is_none()) {
// Non-sized operand unpacking.
for (auto it : llvm::enumerate(operandList)) {
try {
operands.push_back(py::cast<PyValue *>(it.value()));
if (!operands.back())
throw py::cast_error();
} catch (py::cast_error &err) {
throw py::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() + ")")
.str());
}
}
} else {
// Sized operand unpacking.
auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
if (operandSegmentSpec.size() != operandList.size()) {
throw py::value_error((llvm::Twine("Operation \"") + name +
"\" requires " +
llvm::Twine(operandSegmentSpec.size()) +
"operand segments but was provided " +
llvm::Twine(operandList.size()))
.str());
}
operandSegmentLengths.reserve(operandList.size());
for (auto it :
llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
int segmentSpec = std::get<1>(it.value());
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
try {
auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
if (operandValue) {
operands.push_back(operandValue);
operandSegmentLengths.push_back(1);
} else if (segmentSpec == 0) {
// Allowed to be optional.
operandSegmentLengths.push_back(0);
} else {
throw py::cast_error("was None and operand is not optional");
}
} catch (py::cast_error &err) {
throw py::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() +
")")
.str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
try {
if (std::get<0>(it.value()).is_none()) {
// Treat it as an empty list.
operandSegmentLengths.push_back(0);
} else {
// Unpack the list.
auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
for (py::object segmentItem : segment) {
operands.push_back(py::cast<PyValue *>(segmentItem));
if (!operands.back()) {
throw py::cast_error("contained a None item");
}
}
operandSegmentLengths.push_back(segment.size());
}
} catch (std::exception &err) {
// NOTE: Sloppy to be using a catch-all here, but there are at least
// three different unrelated exceptions that can be thrown in the
// above "casts". Just keep the scope above small and catch them all.
throw py::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Sequence of Values (" +
err.what() + ")")
.str());
}
} else {
throw py::value_error("Unexpected segment spec");
}
}
}
// Merge operand/result segment lengths into attributes if needed.
if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
// Dup.
if (attributes) {
attributes = py::dict(*attributes);
} else {
attributes = py::dict();
}
if (attributes->contains("result_segment_sizes") ||
attributes->contains("operand_segment_sizes")) {
throw py::value_error("Manually setting a 'result_segment_sizes' or "
"'operand_segment_sizes' attribute is unsupported. "
"Use Operation.create for such low-level access.");
}
// Add result_segment_sizes attribute.
if (!resultSegmentLengths.empty()) {
int64_t size = resultSegmentLengths.size();
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
resultSegmentLengths.size(), resultSegmentLengths.data());
(*attributes)["result_segment_sizes"] =
PyAttribute(context, segmentLengthAttr);
}
// Add operand_segment_sizes attribute.
if (!operandSegmentLengths.empty()) {
int64_t size = operandSegmentLengths.size();
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
operandSegmentLengths.size(), operandSegmentLengths.data());
(*attributes)["operand_segment_sizes"] =
PyAttribute(context, segmentLengthAttr);
}
}
// Delegate to create.
return PyOperation::create(std::move(name), /*operands=*/std::move(operands),
/*results=*/std::move(resultTypes),
/*attributes=*/std::move(attributes),
/*successors=*/std::move(successors),
/*regions=*/*regions, location, maybeIp);
}
PyOpView::PyOpView(py::object operationObject)
// Casting through the PyOperationBase base-class and then back to the
// Operation lets us accept any PyOperationBase subclass.
@ -3397,17 +3665,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"Context that owns the Operation")
.def_property_readonly("opview", &PyOperation::createOpView);
py::class_<PyOpView, PyOperationBase>(m, "OpView")
.def(py::init<py::object>())
.def_property_readonly("operation", &PyOpView::getOperationObject)
.def_property_readonly(
"context",
[](PyOpView &self) {
return self.getOperation().getContext().getObject();
},
"Context that owns the Operation")
.def("__str__",
[](PyOpView &self) { return py::str(self.getOperationObject()); });
auto opViewClass =
py::class_<PyOpView, PyOperationBase>(m, "OpView")
.def(py::init<py::object>())
.def_property_readonly("operation", &PyOpView::getOperationObject)
.def_property_readonly(
"context",
[](PyOpView &self) {
return self.getOperation().getContext().getObject();
},
"Context that owns the Operation")
.def("__str__", [](PyOpView &self) {
return py::str(self.getOperationObject());
});
opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
opViewClass.attr("_ods_build_default") = classmethod(
&PyOpView::odsBuildDefault, py::arg("cls"),
py::arg("operands") = py::none(), py::arg("results") = py::none(),
py::arg("attributes") = py::none(), py::arg("successors") = py::none(),
py::arg("regions") = py::none(), py::arg("loc") = py::none(),
py::arg("ip") = py::none(),
"Builds a specific, generated OpView based on class level attributes.");
//----------------------------------------------------------------------------
// Mapping of PyRegion.

View File

@ -497,6 +497,14 @@ public:
pybind11::object getOperationObject() { return operationObject; }
static pybind11::object
odsBuildDefault(pybind11::object cls, pybind11::list operandList,
pybind11::list resultTypeList,
llvm::Optional<pybind11::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
llvm::Optional<int> regions, DefaultingPyLocation location,
pybind11::object maybeIp);
private:
PyOperation &operation; // For efficient, cast-free access from C++
pybind11::object operationObject; // Holds the reference.

View File

@ -0,0 +1,210 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
def add_dummy_value():
return Operation.create(
"custom.value",
results=[IntegerType.get_signless(32)]).result
def testOdsBuildDefaultImplicitRegions():
class TestFixedRegionsOp(OpView):
OPERATION_NAME = "custom.test_op"
_ODS_REGIONS = (2, True)
class TestVariadicRegionsOp(OpView):
OPERATION_NAME = "custom.test_any_regions_op"
_ODS_REGIONS = (2, False)
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint.at_block_terminator(m.body):
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[])
# CHECK: NUM_REGIONS: 2
print(f"NUM_REGIONS: {len(op.regions)}")
# Including a regions= that matches should be fine.
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=2)
print(f"NUM_REGIONS: {len(op.regions)}")
# Reject greater than.
try:
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=3)
except ValueError as e:
# CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
print(f"ERROR:{e}")
# Reject less than.
try:
op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=1)
except ValueError as e:
# CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
print(f"ERROR:{e}")
# If no regions specified for a variadic region op, build the minimum.
op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[])
# CHECK: DEFAULT_NUM_REGIONS: 2
print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
# Should also accept an explicit regions= that matches the minimum.
op = TestVariadicRegionsOp._ods_build_default(
operands=[], results=[], regions=2)
# CHECK: EQ_NUM_REGIONS: 2
print(f"EQ_NUM_REGIONS: {len(op.regions)}")
# And accept greater than minimum.
# Should also accept an explicit regions= that matches the minimum.
op = TestVariadicRegionsOp._ods_build_default(
operands=[], results=[], regions=3)
# CHECK: GT_NUM_REGIONS: 3
print(f"GT_NUM_REGIONS: {len(op.regions)}")
# Should reject less than minimum.
try:
op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[], regions=1)
except ValueError as e:
# CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
print(f"ERROR:{e}")
run(testOdsBuildDefaultImplicitRegions)
def testOdsBuildDefaultNonVariadic():
class TestOp(OpView):
OPERATION_NAME = "custom.test_op"
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint.at_block_terminator(m.body):
v0 = add_dummy_value()
v1 = add_dummy_value()
t0 = IntegerType.get_signless(8)
t1 = IntegerType.get_signless(16)
op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1])
# CHECK: %[[V0:.+]] = "custom.value"
# CHECK: %[[V1:.+]] = "custom.value"
# CHECK: "custom.test_op"(%[[V0]], %[[V1]])
# CHECK-NOT: operand_segment_sizes
# CHECK-NOT: result_segment_sizes
# CHECK-SAME: : (i32, i32) -> (i8, i16)
print(m)
run(testOdsBuildDefaultNonVariadic)
def testOdsBuildDefaultSizedVariadic():
class TestOp(OpView):
OPERATION_NAME = "custom.test_op"
_ODS_OPERAND_SEGMENTS = [1, -1, 0]
_ODS_RESULT_SEGMENTS = [-1, 0, 1]
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint.at_block_terminator(m.body):
v0 = add_dummy_value()
v1 = add_dummy_value()
v2 = add_dummy_value()
v3 = add_dummy_value()
t0 = IntegerType.get_signless(8)
t1 = IntegerType.get_signless(16)
t2 = IntegerType.get_signless(32)
t3 = IntegerType.get_signless(64)
# CHECK: %[[V0:.+]] = "custom.value"
# CHECK: %[[V1:.+]] = "custom.value"
# CHECK: %[[V2:.+]] = "custom.value"
# CHECK: %[[V3:.+]] = "custom.value"
# CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
# CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64>
# CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64>
# CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
op = TestOp._ods_build_default(
operands=[v0, [v1, v2], v3],
results=[[t0, t1], t2, t3])
# Now test with optional omitted.
# CHECK: "custom.test_op"(%[[V0]])
# CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
# CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
# CHECK-SAME: (i32) -> i64
op = TestOp._ods_build_default(
operands=[v0, None, None],
results=[None, None, t3])
print(m)
# And verify that errors are raised for None in a required operand.
try:
op = TestOp._ods_build_default(
operands=[None, None, None],
results=[None, None, t3])
except ValueError as e:
# CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
print(f"OPERAND_CAST_ERROR:{e}")
# And verify that errors are raised for None in a required result.
try:
op = TestOp._ods_build_default(
operands=[v0, None, None],
results=[None, None, None])
except ValueError as e:
# CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
print(f"RESULT_CAST_ERROR:{e}")
# Variadic lists with None elements should reject.
try:
op = TestOp._ods_build_default(
operands=[v0, [None], None],
results=[None, None, t3])
except ValueError as e:
# CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
print(f"OPERAND_LIST_CAST_ERROR:{e}")
try:
op = TestOp._ods_build_default(
operands=[v0, None, None],
results=[[None], None, t3])
except ValueError as e:
# CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
print(f"RESULT_LIST_CAST_ERROR:{e}")
run(testOdsBuildDefaultSizedVariadic)
def testOdsBuildDefaultCastError():
class TestOp(OpView):
OPERATION_NAME = "custom.test_op"
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint.at_block_terminator(m.body):
v0 = add_dummy_value()
v1 = add_dummy_value()
t0 = IntegerType.get_signless(8)
t1 = IntegerType.get_signless(16)
try:
op = TestOp._ods_build_default(
operands=[None, v1],
results=[t0, t1])
except ValueError as e:
# CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
print(f"ERROR: {e}")
try:
op = TestOp._ods_build_default(
operands=[v0, v1],
results=[t0, None])
except ValueError as e:
# CHECK: Result 1 of operation "custom.test_op" must be a Type
print(f"ERROR: {e}")
run(testOdsBuildDefaultCastError)

View File

@ -17,23 +17,18 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operand_segment_sizes_ods = _ods_array.array('L')
// CHECK: operands += [*variadic1]
// CHECK: operand_segment_sizes_ods.append(len(variadic1))
// CHECK: operands.append(variadic1)
// CHECK: operands.append(non_variadic)
// CHECK: operand_segment_sizes_ods.append(1)
// CHECK: if variadic2 is not None: operands.append(variadic2)
// CHECK: operand_segment_sizes_ods.append(0 if variadic2 is None else 1)
// CHECK: attributes["operand_segment_sizes"] = _ods_ir.DenseElementsAttr.get(operand_segment_sizes_ods,
// CHECK: context=_ods_get_default_loc_context(loc))
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -63,23 +58,18 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: result_segment_sizes_ods = _ods_array.array('L')
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: result_segment_sizes_ods.append(0 if variadic1 is None else 1)
// CHECK: results.append(non_variadic)
// CHECK: result_segment_sizes_ods.append(1) # non_variadic
// CHECK: if variadic2 is not None: results.append(variadic2)
// CHECK: result_segment_sizes_ods.append(0 if variadic2 is None else 1)
// CHECK: attributes["result_segment_sizes"] = _ods_ir.DenseElementsAttr.get(result_segment_sizes_ods,
// CHECK: context=_ods_get_default_loc_context(loc))
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -110,6 +100,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
// CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
// CHECK: operands = []
@ -120,8 +112,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -148,6 +140,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
// CHECK: operands = []
@ -158,8 +152,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -183,8 +177,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.empty", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @_ods_cext.register_operation(_Dialect)
@ -201,8 +195,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -223,15 +217,17 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(non_variadic)
// CHECK: operands += [*variadic]
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results,
// CHECK: operands.extend(variadic)
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -248,15 +244,17 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: results += [*variadic]
// CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.one_variadic_result", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -280,8 +278,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(in_)
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property
@ -348,8 +346,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: results.append(f64)
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
// CHECK: super().__init__(_ods_ir.Operation.create(
// CHECK: "test.simple", attributes=attributes, operands=operands, results=results,
// CHECK: super().__init__(self._ods_build_default(
// CHECK: attributes=attributes, operands=operands, results=results,
// CHECK: loc=loc, ip=ip))
// CHECK: @property

View File

@ -26,7 +26,6 @@ using namespace mlir::tblgen;
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
import array as _ods_array
from . import _cext as _ods_cext
from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir
@ -51,6 +50,25 @@ class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
/// segment specs.
/// {0} is either "OPERAND" or "RESULT"
/// {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
/// 1 = single element (expect non sequence operand/result)
/// -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
_ODS_{0}_SEGMENTS = {1}
)Py";
/// Template for class level declarations of the _ODS_REGIONS spec:
/// {0} is the minimum number of regions
/// {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
_ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
@ -446,18 +464,17 @@ static void emitAttributeAccessors(const Operator &op,
}
/// Template for the default auto-generated builder.
/// {0} is the operation name;
/// {1} is a comma-separated list of builder arguments, including the trailing
/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
/// {2} is the code populating `operands`, `results` and `attributes` fields.
/// {1} is the code populating `operands`, `results` and `attributes` fields.
constexpr const char *initTemplate = R"Py(
def __init__(self, {1}):
def __init__(self, {0}):
operands = []
results = []
attributes = {{}
{2}
super().__init__(_ods_ir.Operation.create(
"{0}", attributes=attributes, operands=operands, results=results,
{1}
super().__init__(self._ods_build_default(
attributes=attributes, operands=operands, results=results,
loc=loc, ip=ip))
)Py";
@ -472,37 +489,10 @@ constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
constexpr const char *optionalAppendTemplate =
"if {1} is not None: {0}s.append({1})";
/// Template for appending a variadic element to the operand/result list.
/// Template for appending a a list of elements to the operand/result list.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";
/// Template for setting up the segment sizes buffer.
constexpr const char *segmentDeclarationTemplate =
"{0}_segment_sizes_ods = _ods_array.array('L')";
/// Template for attaching segment sizes to the attribute list.
constexpr const char *segmentAttributeTemplate =
R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods,
context=_ods_get_default_loc_context(loc)))Py";
/// Template for appending the unit size to the segment sizes.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *singleElementSegmentTemplate =
"{0}_segment_sizes_ods.append(1) # {1}";
/// Template for appending 0/1 for an optional element to the segment sizes.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *optionalSegmentTemplate =
"{0}_segment_sizes_ods.append(0 if {1} is None else 1)";
/// Template for appending the length of a variadic group to the segment sizes.
/// {0} is either 'operand' or 'result';
/// {1} is the field name.
constexpr const char *variadicSegmentTemplate =
"{0}_segment_sizes_ods.append(len({1}))";
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
@ -584,11 +574,7 @@ static void populateBuilderLines(
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
// The segment sizes buffer only has to be populated if there attr-sized
// segments trait is present.
bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
if (includeSegments)
builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = getNumElements(op); i < e; ++i) {
@ -596,28 +582,28 @@ static void populateBuilderLines(
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString, segmentFormatString;
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleElementAppendTemplate;
segmentFormatString = singleElementSegmentTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendTemplate;
segmentFormatString = optionalSegmentTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
formatString = variadicAppendTemplate;
segmentFormatString = variadicSegmentTemplate;
// If emitting with sizedSegments, then we add the actual list typed
// element using the singleElementAppendTemplate. Otherwise, we extend
// the actual operands.
if (sizedSegments) {
// Append the list as is.
formatString = singleElementAppendTemplate;
} else {
// Append the list elements.
formatString = multiElementAppendTemplate;
}
}
// Add the lines.
builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
if (includeSegments)
builderLines.push_back(
llvm::formatv(segmentFormatString.data(), kind, name));
}
if (includeSegments)
builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
}
/// Emits a default builder constructing an operation from the list of its
@ -645,8 +631,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
builderArgs.push_back("loc=None");
builderArgs.push_back("ip=None");
os << llvm::formatv(initTemplate, op.getOperationName(),
llvm::join(builderArgs, ", "),
os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
llvm::join(builderLines, "\n "));
}
@ -659,12 +644,52 @@ static void constructAttributeMapping(const llvm::RecordKeeper &records,
}
}
static void emitSegmentSpec(
const Operator &op, const char *kind,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement,
raw_ostream &os) {
std::string segmentSpec("[");
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength()) {
segmentSpec.append("-1,");
} else if (element.isOptional()) {
segmentSpec.append("0,");
} else {
segmentSpec.append("1,");
}
}
segmentSpec.append("]");
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
// Note that the base OpView class defines this as (0, True).
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
op.hasNoVariadicRegions() ? "True" : "False");
}
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op,
const AttributeClasses &attributeClasses,
raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
op.getOperationName());
// Sized segments.
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
}
if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
}
emitRegionAttributes(op, os);
emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, attributeClasses, os);