forked from OSchip/llvm-project
[mlir][python] Provide some methods and properties for API completeness
When writing the user-facing documentation, I noticed several inconsistencies and asymmetries in the Python API we provide. Fix them by adding: - the `owner` property to regions, similarly to blocks; - the `isinstance` method to any class derived from `PyConcreteAttr`, `PyConcreteValue` and `PyConreteAffineExpr`, similar to `PyConcreteType` to enable `isa`-like calls without having to handle exceptions; - a mechanism to create the first block in the region as we could only create blocks relative to other blocks, with is impossible in an empty region. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111556
This commit is contained in:
parent
e845ca2ff1
commit
78f2dae00d
|
@ -612,8 +612,22 @@ operations (unlike in C++ that supports detached regions).
|
|||
|
||||
Blocks can be created within a given region and inserted before or after another
|
||||
block of the same region using `create_before()`, `create_after()` methods of
|
||||
the `Block` class. They are not expected to exist outside of regions (unlike in
|
||||
C++ that supports detached blocks).
|
||||
the `Block` class, or the `create_at_start()` static method of the same class.
|
||||
They are not expected to exist outside of regions (unlike in C++ that supports
|
||||
detached blocks).
|
||||
|
||||
```python
|
||||
from mlir.ir import Block, Context, Operation
|
||||
|
||||
with Context():
|
||||
op = Operation.create("generic.op", regions=1)
|
||||
|
||||
# Create the first block in the region.
|
||||
entry_block = Block.create_at_start(op.regions[0])
|
||||
|
||||
# Create further blocks.
|
||||
other_block = entry_block.create_after()
|
||||
```
|
||||
|
||||
Blocks can be used to create `InsertionPoint`s, which can point to the beginning
|
||||
or the end of the block, or just before its terminator. It is common for
|
||||
|
|
|
@ -99,6 +99,9 @@ public:
|
|||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
|
||||
cls.def(py::init<PyAffineExpr &>());
|
||||
cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAffineExpr);
|
||||
});
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
|
|
@ -1548,6 +1548,9 @@ public:
|
|||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
|
||||
cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
|
||||
cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
|
||||
return DerivedTy::isaFunction(otherValue);
|
||||
});
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
@ -2248,6 +2251,12 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
return PyBlockList(self.getParentOperation(), self.get());
|
||||
},
|
||||
"Returns a forward-optimized sequence of blocks.")
|
||||
.def_property_readonly(
|
||||
"owner",
|
||||
[](PyRegion &self) {
|
||||
return self.getParentOperation()->createOpView();
|
||||
},
|
||||
"Returns the operation owning this region.")
|
||||
.def(
|
||||
"__iter__",
|
||||
[](PyRegion &self) {
|
||||
|
@ -2291,6 +2300,23 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
return PyOperationList(self.getParentOperation(), self.get());
|
||||
},
|
||||
"Returns a forward-optimized sequence of operations.")
|
||||
.def_static(
|
||||
"create_at_start",
|
||||
[](PyRegion &parent, py::list pyArgTypes) {
|
||||
parent.checkValid();
|
||||
llvm::SmallVector<MlirType, 4> argTypes;
|
||||
argTypes.reserve(pyArgTypes.size());
|
||||
for (auto &pyArg : pyArgTypes) {
|
||||
argTypes.push_back(pyArg.cast<PyType &>());
|
||||
}
|
||||
|
||||
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
|
||||
mlirRegionInsertOwnedBlock(parent, 0, block);
|
||||
return PyBlock(parent.getParentOperation(), block);
|
||||
},
|
||||
py::arg("parent"), py::arg("pyArgTypes") = py::list(),
|
||||
"Creates and returns a new Block at the beginning of the given "
|
||||
"region (with given argument types).")
|
||||
.def(
|
||||
"create_before",
|
||||
[](PyBlock &self, py::args pyArgTypes) {
|
||||
|
|
|
@ -533,6 +533,7 @@ public:
|
|||
: parentOperation(std::move(parentOperation)), region(region) {
|
||||
assert(!mlirRegionIsNull(region) && "python region cannot be null");
|
||||
}
|
||||
operator MlirRegion() const { return region; }
|
||||
|
||||
MlirRegion get() { return region; }
|
||||
PyOperationRef &getParentOperation() { return parentOperation; }
|
||||
|
@ -681,6 +682,9 @@ public:
|
|||
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
|
||||
pybind11::module_local());
|
||||
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
|
||||
cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAttr);
|
||||
});
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
@ -764,6 +768,7 @@ class PyValue {
|
|||
public:
|
||||
PyValue(PyOperationRef parentOperation, MlirValue value)
|
||||
: parentOperation(parentOperation), value(value) {}
|
||||
operator MlirValue() const { return value; }
|
||||
|
||||
MlirValue get() { return value; }
|
||||
PyOperationRef &getParentOperation() { return parentOperation; }
|
||||
|
|
|
@ -8,9 +8,11 @@ def run(f):
|
|||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprCapsule
|
||||
@run
|
||||
def testAffineExprCapsule():
|
||||
with Context() as ctx:
|
||||
affine_expr = AffineExpr.get_constant(42)
|
||||
|
@ -24,10 +26,9 @@ def testAffineExprCapsule():
|
|||
assert affine_expr == affine_expr_2
|
||||
assert affine_expr_2.context == ctx
|
||||
|
||||
run(testAffineExprCapsule)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprEq
|
||||
@run
|
||||
def testAffineExprEq():
|
||||
with Context():
|
||||
a1 = AffineExpr.get_constant(42)
|
||||
|
@ -44,10 +45,9 @@ def testAffineExprEq():
|
|||
# CHECK: False
|
||||
print(a1 == "foo")
|
||||
|
||||
run(testAffineExprEq)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprContext
|
||||
@run
|
||||
def testAffineExprContext():
|
||||
with Context():
|
||||
a1 = AffineExpr.get_constant(42)
|
||||
|
@ -61,6 +61,7 @@ run(testAffineExprContext)
|
|||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprConstant
|
||||
@run
|
||||
def testAffineExprConstant():
|
||||
with Context():
|
||||
a1 = AffineExpr.get_constant(42)
|
||||
|
@ -77,10 +78,9 @@ def testAffineExprConstant():
|
|||
|
||||
assert a1 == a2
|
||||
|
||||
run(testAffineExprConstant)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprDim
|
||||
@run
|
||||
def testAffineExprDim():
|
||||
with Context():
|
||||
d1 = AffineExpr.get_dim(1)
|
||||
|
@ -100,10 +100,9 @@ def testAffineExprDim():
|
|||
assert d1 == d11
|
||||
assert d1 != d2
|
||||
|
||||
run(testAffineExprDim)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprSymbol
|
||||
@run
|
||||
def testAffineExprSymbol():
|
||||
with Context():
|
||||
s1 = AffineExpr.get_symbol(1)
|
||||
|
@ -123,10 +122,9 @@ def testAffineExprSymbol():
|
|||
assert s1 == s11
|
||||
assert s1 != s2
|
||||
|
||||
run(testAffineExprSymbol)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineAddExpr
|
||||
@run
|
||||
def testAffineAddExpr():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -143,10 +141,9 @@ def testAffineAddExpr():
|
|||
assert d12.lhs == d1
|
||||
assert d12.rhs == d2
|
||||
|
||||
run(testAffineAddExpr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineMulExpr
|
||||
@run
|
||||
def testAffineMulExpr():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -163,10 +160,9 @@ def testAffineMulExpr():
|
|||
assert expr.lhs == d1
|
||||
assert expr.rhs == c2
|
||||
|
||||
run(testAffineMulExpr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineModExpr
|
||||
@run
|
||||
def testAffineModExpr():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -183,10 +179,9 @@ def testAffineModExpr():
|
|||
assert expr.lhs == d1
|
||||
assert expr.rhs == c2
|
||||
|
||||
run(testAffineModExpr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineFloorDivExpr
|
||||
@run
|
||||
def testAffineFloorDivExpr():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -198,10 +193,9 @@ def testAffineFloorDivExpr():
|
|||
assert expr.lhs == d1
|
||||
assert expr.rhs == c2
|
||||
|
||||
run(testAffineFloorDivExpr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineCeilDivExpr
|
||||
@run
|
||||
def testAffineCeilDivExpr():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -213,10 +207,9 @@ def testAffineCeilDivExpr():
|
|||
assert expr.lhs == d1
|
||||
assert expr.rhs == c2
|
||||
|
||||
run(testAffineCeilDivExpr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineExprSub
|
||||
@run
|
||||
def testAffineExprSub():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -232,9 +225,8 @@ def testAffineExprSub():
|
|||
# CHECK: -1
|
||||
print(rhs.rhs)
|
||||
|
||||
run(testAffineExprSub)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testClassHierarchy
|
||||
@run
|
||||
def testClassHierarchy():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
|
@ -272,4 +264,28 @@ def testClassHierarchy():
|
|||
# CHECK: Cannot cast affine expression to AffineBinaryExpr
|
||||
print(e)
|
||||
|
||||
run(testClassHierarchy)
|
||||
# CHECK-LABEL: TEST: testIsInstance
|
||||
@run
|
||||
def testIsInstance():
|
||||
with Context():
|
||||
d1 = AffineDimExpr.get(1)
|
||||
c2 = AffineConstantExpr.get(2)
|
||||
add = AffineAddExpr.get(d1, c2)
|
||||
mul = AffineMulExpr.get(d1, c2)
|
||||
|
||||
# CHECK: True
|
||||
print(AffineDimExpr.isinstance(d1))
|
||||
# CHECK: False
|
||||
print(AffineConstantExpr.isinstance(d1))
|
||||
# CHECK: True
|
||||
print(AffineConstantExpr.isinstance(c2))
|
||||
# CHECK: False
|
||||
print(AffineMulExpr.isinstance(c2))
|
||||
# CHECK: True
|
||||
print(AffineAddExpr.isinstance(add))
|
||||
# CHECK: False
|
||||
print(AffineMulExpr.isinstance(add))
|
||||
# CHECK: True
|
||||
print(AffineMulExpr.isinstance(mul))
|
||||
# CHECK: False
|
||||
print(AffineAddExpr.isinstance(mul))
|
||||
|
|
|
@ -89,6 +89,18 @@ def testAttrCast():
|
|||
print("a1 == a2:", a1 == a2)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrIsInstance
|
||||
@run
|
||||
def testAttrIsInstance():
|
||||
with Context():
|
||||
a1 = Attribute.parse("42")
|
||||
a2 = Attribute.parse("[42]")
|
||||
assert IntegerAttr.isinstance(a1)
|
||||
assert not IntegerAttr.isinstance(a2)
|
||||
assert not ArrayAttr.isinstance(a1)
|
||||
assert ArrayAttr.isinstance(a2)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
|
||||
@run
|
||||
def testAttrEqDoesNotRaise():
|
||||
|
|
|
@ -51,3 +51,22 @@ def testBlockCreation():
|
|||
print(module.operation)
|
||||
# Ensure region back references are coherent.
|
||||
assert entry_block.region == middle_block.region == successor_block.region
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFirstBlockCreation
|
||||
# CHECK: func @test(%{{.*}}: f32)
|
||||
# CHECK: return
|
||||
@run
|
||||
def testFirstBlockCreation():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp("test", ([f32], []))
|
||||
entry_block = Block.create_at_start(func.operation.regions[0], [f32])
|
||||
with InsertionPoint(entry_block):
|
||||
std.ReturnOp([])
|
||||
|
||||
print(module)
|
||||
assert module.operation.verify()
|
||||
assert func.body.blocks[0] == entry_block
|
||||
|
|
|
@ -11,10 +11,12 @@ def run(f):
|
|||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# Verify iterator based traversal of the op/region/block hierarchy.
|
||||
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
|
||||
@run
|
||||
def testTraverseOpRegionBlockIterators():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -69,11 +71,9 @@ def testTraverseOpRegionBlockIterators():
|
|||
walk_operations("", op)
|
||||
|
||||
|
||||
run(testTraverseOpRegionBlockIterators)
|
||||
|
||||
|
||||
# Verify index based traversal of the op/region/block hierarchy.
|
||||
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
|
||||
@run
|
||||
def testTraverseOpRegionBlockIndices():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -111,10 +111,30 @@ def testTraverseOpRegionBlockIndices():
|
|||
walk_operations("", module.operation)
|
||||
|
||||
|
||||
run(testTraverseOpRegionBlockIndices)
|
||||
# CHECK-LABEL: TEST: testBlockAndRegionOwners
|
||||
@run
|
||||
def testBlockAndRegionOwners():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = Module.parse(
|
||||
r"""
|
||||
builtin.module {
|
||||
builtin.func @f() {
|
||||
std.return
|
||||
}
|
||||
}
|
||||
""", ctx)
|
||||
|
||||
assert module.operation.regions[0].owner == module.operation
|
||||
assert module.operation.regions[0].blocks[0].owner == module.operation
|
||||
|
||||
func = module.body.operations[0]
|
||||
assert func.operation.regions[0].owner == func
|
||||
assert func.operation.regions[0].blocks[0].owner == func
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockArgumentList
|
||||
@run
|
||||
def testBlockArgumentList():
|
||||
with Context() as ctx:
|
||||
module = Module.parse(
|
||||
|
@ -158,10 +178,8 @@ def testBlockArgumentList():
|
|||
print("Type: ", t)
|
||||
|
||||
|
||||
run(testBlockArgumentList)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationOperands
|
||||
@run
|
||||
def testOperationOperands():
|
||||
with Context() as ctx:
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -181,10 +199,10 @@ def testOperationOperands():
|
|||
print(f"Operand {i}, type {operand.type}")
|
||||
|
||||
|
||||
run(testOperationOperands)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationOperandsSlice
|
||||
@run
|
||||
def testOperationOperandsSlice():
|
||||
with Context() as ctx:
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -239,10 +257,10 @@ def testOperationOperandsSlice():
|
|||
print(operand)
|
||||
|
||||
|
||||
run(testOperationOperandsSlice)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationOperandsSet
|
||||
@run
|
||||
def testOperationOperandsSet():
|
||||
with Context() as ctx, Location.unknown(ctx):
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -271,10 +289,10 @@ def testOperationOperandsSet():
|
|||
print(consumer.operands[0])
|
||||
|
||||
|
||||
run(testOperationOperandsSet)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDetachedOperation
|
||||
@run
|
||||
def testDetachedOperation():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -295,10 +313,8 @@ def testDetachedOperation():
|
|||
# TODO: Check successors once enough infra exists to do it properly.
|
||||
|
||||
|
||||
run(testDetachedOperation)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationInsertionPoint
|
||||
@run
|
||||
def testOperationInsertionPoint():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -335,10 +351,8 @@ def testOperationInsertionPoint():
|
|||
assert False, "expected insert of attached op to raise"
|
||||
|
||||
|
||||
run(testOperationInsertionPoint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationWithRegion
|
||||
@run
|
||||
def testOperationWithRegion():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -377,10 +391,8 @@ def testOperationWithRegion():
|
|||
print(module)
|
||||
|
||||
|
||||
run(testOperationWithRegion)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationResultList
|
||||
@run
|
||||
def testOperationResultList():
|
||||
ctx = Context()
|
||||
module = Module.parse(
|
||||
|
@ -407,10 +419,10 @@ def testOperationResultList():
|
|||
print(f"Result type {t}")
|
||||
|
||||
|
||||
run(testOperationResultList)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationResultListSlice
|
||||
@run
|
||||
def testOperationResultListSlice():
|
||||
with Context() as ctx:
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -458,10 +470,10 @@ def testOperationResultListSlice():
|
|||
print(f"Result {res.result_number}, type {res.type}")
|
||||
|
||||
|
||||
run(testOperationResultListSlice)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationAttributes
|
||||
@run
|
||||
def testOperationAttributes():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -506,10 +518,10 @@ def testOperationAttributes():
|
|||
assert False, "expected IndexError on accessing an out-of-bounds attribute"
|
||||
|
||||
|
||||
run(testOperationAttributes)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationPrint
|
||||
@run
|
||||
def testOperationPrint():
|
||||
ctx = Context()
|
||||
module = Module.parse(
|
||||
|
@ -553,10 +565,10 @@ def testOperationPrint():
|
|||
use_local_scope=True)
|
||||
|
||||
|
||||
run(testOperationPrint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testKnownOpView
|
||||
@run
|
||||
def testKnownOpView():
|
||||
with Context(), Location.unknown():
|
||||
Context.current.allow_unregistered_dialects = True
|
||||
|
@ -586,10 +598,8 @@ def testKnownOpView():
|
|||
print(repr(custom))
|
||||
|
||||
|
||||
run(testKnownOpView)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSingleResultProperty
|
||||
@run
|
||||
def testSingleResultProperty():
|
||||
with Context(), Location.unknown():
|
||||
Context.current.allow_unregistered_dialects = True
|
||||
|
@ -620,10 +630,8 @@ def testSingleResultProperty():
|
|||
print(module.body.operations[2])
|
||||
|
||||
|
||||
run(testSingleResultProperty)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testPrintInvalidOperation
|
||||
@run
|
||||
def testPrintInvalidOperation():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
|
@ -639,10 +647,8 @@ def testPrintInvalidOperation():
|
|||
print(f".verify = {module.operation.verify()}")
|
||||
|
||||
|
||||
run(testPrintInvalidOperation)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
|
||||
@run
|
||||
def testCreateWithInvalidAttributes():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
|
@ -670,10 +676,8 @@ def testCreateWithInvalidAttributes():
|
|||
print(e)
|
||||
|
||||
|
||||
run(testCreateWithInvalidAttributes)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationName
|
||||
@run
|
||||
def testOperationName():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -691,10 +695,8 @@ def testOperationName():
|
|||
print(op.operation.name)
|
||||
|
||||
|
||||
run(testOperationName)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCapsuleConversions
|
||||
@run
|
||||
def testCapsuleConversions():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -706,10 +708,8 @@ def testCapsuleConversions():
|
|||
assert m2 is m
|
||||
|
||||
|
||||
run(testCapsuleConversions)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationErase
|
||||
@run
|
||||
def testOperationErase():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -728,6 +728,3 @@ def testOperationErase():
|
|||
|
||||
# Ensure we can create another operation
|
||||
Operation.create("custom.op2")
|
||||
|
||||
|
||||
run(testOperationErase)
|
||||
|
|
|
@ -9,9 +9,11 @@ def run(f):
|
|||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCapsuleConversions
|
||||
@run
|
||||
def testCapsuleConversions():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -24,10 +26,8 @@ def testCapsuleConversions():
|
|||
assert value2 == value
|
||||
|
||||
|
||||
run(testCapsuleConversions)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOpResultOwner
|
||||
@run
|
||||
def testOpResultOwner():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
@ -37,4 +37,21 @@ def testOpResultOwner():
|
|||
assert op.result.owner == op
|
||||
|
||||
|
||||
run(testOpResultOwner)
|
||||
# CHECK-LABEL: TEST: testValueIsInstance
|
||||
@run
|
||||
def testValueIsInstance():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = Module.parse(
|
||||
r"""
|
||||
func @foo(%arg0: f32) {
|
||||
%0 = "some_dialect.some_op"() : () -> f64
|
||||
return
|
||||
}""", ctx)
|
||||
func = module.body.operations[0]
|
||||
assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
|
||||
assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
|
||||
|
||||
op = func.regions[0].blocks[0].operations[0]
|
||||
assert not BlockArgument.isinstance(op.results[0])
|
||||
assert OpResult.isinstance(op.results[0])
|
||||
|
|
Loading…
Reference in New Issue