[mlir][Python] Add casting constructor to Type and Attribute.

* This makes them consistent with custom types/attributes, whose constructors will do a type checked conversion. Of course, the base classes can represent everything so never error.
* More importantly, this makes it possible to subclass Type and Attribute out of tree in sensible ways.

Differential Revision: https://reviews.llvm.org/D101734
This commit is contained in:
Stella Laurenzo 2021-05-02 15:15:21 -07:00
parent 5fa9d41634
commit b57d6fe42e
3 changed files with 67 additions and 70 deletions

View File

@ -2255,6 +2255,10 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
// Delegate to the PyAttribute copy constructor, which will also lifetime
// extend the backing context which owns the MlirAttribute.
.def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
"Casts the passed attribute to the generic Attribute")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyAttribute::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
@ -2358,6 +2362,10 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of PyType.
//----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
// Delegate to the PyType copy constructor, which will also lifetime
// extend the backing context which owns the MlirType.
.def(py::init<PyType &>(), py::arg("cast_from_type"),
"Casts the passed type to the generic Type")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(

View File

@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
with Context() as ctx:
t = Attribute.parse('"hello"')
@ -22,12 +24,11 @@ def testParsePrint():
# CHECK: Attribute("hello")
print(repr(t))
run(testParsePrint)
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
with Context():
try:
@ -38,10 +39,9 @@ def testParseError():
else:
print("Exception not produced")
run(testParseError)
# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
with Context():
a1 = Attribute.parse('"attr1"')
@ -56,10 +56,19 @@ def testAttrEq():
# CHECK: a1 == None: False
print("a1 == None:", a1 == None)
run(testAttrEq)
# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
with Context():
a1 = Attribute.parse('"attr1"')
a2 = Attribute(a1)
# CHECK: a1 == a2: True
print("a1 == a2:", a1 == a2)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
with Context():
a1 = Attribute.parse('"attr1"')
@ -71,10 +80,9 @@ def testAttrEqDoesNotRaise():
# CHECK: True
print(a1 != None)
run(testAttrEqDoesNotRaise)
# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
with Context() as ctx:
a1 = Attribute.parse('"attr1"')
@ -85,10 +93,9 @@ def testAttrCapsule():
assert a2 == a1
assert a2.context is ctx
run(testAttrCapsule)
# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
with Context():
a1 = Attribute.parse('"attr1"')
@ -104,10 +111,9 @@ def testStandardAttrCasts():
else:
print("Exception not produced")
run(testStandardAttrCasts)
# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
with Context() as ctx:
d0 = AffineDimExpr.get(0)
@ -122,10 +128,9 @@ def testAffineMapAttr():
attr_parsed = Attribute.parse(str(attr_built))
assert attr_built == attr_parsed
run(testAffineMapAttr)
# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
with Context(), Location.unknown():
fattr = FloatAttr(Attribute.parse("42.0 : f32"))
@ -149,10 +154,9 @@ def testFloatAttr():
else:
print("Exception not produced")
run(testFloatAttr)
# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
with Context() as ctx:
iattr = IntegerAttr(Attribute.parse("42"))
@ -166,10 +170,9 @@ def testIntegerAttr():
print("default_get:", IntegerAttr.get(
IntegerType.get_signless(32), 42))
run(testIntegerAttr)
# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
with Context() as ctx:
battr = BoolAttr(Attribute.parse("true"))
@ -180,10 +183,9 @@ def testBoolAttr():
# CHECK: default_get: true
print("default_get:", BoolAttr.get(True))
run(testBoolAttr)
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
with Context() as ctx:
sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
@ -194,10 +196,9 @@ def testFlatSymbolRefAttr():
# CHECK: default_get: @foobar
print("default_get:", FlatSymbolRefAttr.get("foobar"))
run(testFlatSymbolRefAttr)
# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
with Context() as ctx:
sattr = StringAttr(Attribute.parse('"stringattr"'))
@ -211,10 +212,9 @@ def testStringAttr():
print("typed_get:", StringAttr.get_typed(
IntegerType.get_signless(32), "12345"))
run(testStringAttr)
# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
with Context():
a = Attribute.parse('"stringattr"')
@ -226,10 +226,9 @@ def testNamedAttr():
# CHECK: named: NamedAttribute(foobar="stringattr")
print("named:", named)
run(testNamedAttr)
# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
with Context():
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
@ -263,10 +262,8 @@ def testDenseIntAttr():
print(ShapedType(a.type).element_type)
run(testDenseIntAttr)
# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
with Context():
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
@ -286,10 +283,8 @@ def testDenseFPAttr():
print(ShapedType(a.type).element_type)
run(testDenseFPAttr)
# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
with Context():
dict_attr = {
@ -327,10 +322,8 @@ def testDictAttr():
assert False, "expected IndexError on accessing an out-of-bounds attribute"
run(testDictAttr)
# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
with Context():
raw = Attribute.parse("vector<4xf32>")
@ -341,10 +334,8 @@ def testTypeAttr():
print(ShapedType(type_attr.value).element_type)
run(testTypeAttr)
# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
with Context():
raw = Attribute.parse("[42, true, vector<4xf32>]")
@ -391,5 +382,4 @@ def testArrayAttr():
except RuntimeError as e:
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
print("Error: ", e)
run(testArrayAttr)

View File

@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
ctx = Context()
t = Type.parse("i32", ctx)
@ -22,12 +24,11 @@ def testParsePrint():
# CHECK: Type(i32)
print(repr(t))
run(testParsePrint)
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
ctx = Context()
try:
@ -38,10 +39,9 @@ def testParseError():
else:
print("Exception not produced")
run(testParseError)
# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
ctx = Context()
t1 = Type.parse("i32", ctx)
@ -56,10 +56,19 @@ def testTypeEq():
# CHECK: t1 == None: False
print("t1 == None:", t1 == None)
run(testTypeEq)
# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
ctx = Context()
t1 = Type.parse("i32", ctx)
t2 = Type(t1)
# CHECK: t1 == t2: True
print("t1 == t2:", t1 == t2)
# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
ctx = Context()
t1 = Type.parse("i32", ctx)
@ -71,10 +80,9 @@ def testTypeIsInstance():
# CHECK: True
print(F32Type.isinstance(t2))
run(testTypeIsInstance)
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
ctx = Context()
t1 = Type.parse("i32", ctx)
@ -86,10 +94,9 @@ def testTypeEqDoesNotRaise():
# CHECK: True
print(t1 != None)
run(testTypeEqDoesNotRaise)
# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
with Context() as ctx:
t1 = Type.parse("i32", ctx)
@ -100,10 +107,9 @@ def testTypeCapsule():
assert t2 == t1
assert t2.context is ctx
run(testTypeCapsule)
# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
ctx = Context()
t1 = Type.parse("i32", ctx)
@ -119,10 +125,9 @@ def testStandardTypeCasts():
else:
print("Exception not produced")
run(testStandardTypeCasts)
# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
with Context() as ctx:
i32 = IntegerType(Type.parse("i32"))
@ -158,17 +163,16 @@ def testIntegerType():
# CHECK: unsigned: ui64
print("unsigned:", IntegerType.get_unsigned(64))
run(testIntegerType)
# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
with Context() as ctx:
# CHECK: index type: index
print("index type:", IndexType.get())
run(testIndexType)
# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
with Context():
# CHECK: float: bf16
@ -180,17 +184,17 @@ def testFloatType():
# CHECK: float: f64
print("float:", F64Type.get())
run(testFloatType)
# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
with Context():
# CHECK: none type: none
print("none type:", NoneType.get())
run(testNoneType)
# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
with Context() as ctx:
complex_i32 = ComplexType(Type.parse("complex<i32>"))
@ -210,13 +214,12 @@ def testComplexType():
else:
print("Exception not produced")
run(testComplexType)
# CHECK-LABEL: TEST: testConcreteShapedType
# Shaped type is not a kind of builtin types, it is the base class for vectors,
# memrefs and tensors, so this test case uses an instance of vector to test the
# shaped type. The class hierarchy is preserved on the python side.
@run
def testConcreteShapedType():
with Context() as ctx:
vector = VectorType(Type.parse("vector<2x3xf32>"))
@ -239,20 +242,20 @@ def testConcreteShapedType():
# CHECK: isinstance(ShapedType): True
print("isinstance(ShapedType):", isinstance(vector, ShapedType))
run(testConcreteShapedType)
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
@run
def testAbstractShapedType():
ctx = Context()
vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
# CHECK: element type: f32
print("element type:", vector.element_type)
run(testAbstractShapedType)
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@ -269,9 +272,9 @@ def testVectorType():
else:
print("Exception not produced")
run(testVectorType)
# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@ -291,9 +294,9 @@ def testRankedTensorType():
else:
print("Exception not produced")
run(testRankedTensorType)
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@ -333,9 +336,9 @@ def testUnrankedTensorType():
else:
print("Exception not produced")
run(testUnrankedTensorType)
# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
with Context(), Location.unknown():
f32 = F32Type.get()
@ -369,9 +372,9 @@ def testMemRefType():
else:
print("Exception not produced")
run(testMemRefType)
# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
with Context(), Location.unknown():
f32 = F32Type.get()
@ -411,9 +414,9 @@ def testUnrankedMemRefType():
else:
print("Exception not produced")
run(testUnrankedMemRefType)
# CHECK-LABEL: TEST: testTupleType
@run
def testTupleType():
with Context() as ctx:
i32 = IntegerType(Type.parse("i32"))
@ -428,10 +431,9 @@ def testTupleType():
# CHECK: pos-th type in the tuple type: f32
print("pos-th type in the tuple type:", tuple_type.get_type(1))
run(testTupleType)
# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
with Context() as ctx:
input_types = [IntegerType.get_signless(32),
@ -442,6 +444,3 @@ def testFunctionType():
print("INPUTS:", func.inputs)
# CHECK: RESULTS: [Type(index)]
print("RESULTS:", func.results)
run(testFunctionType)