forked from OSchip/llvm-project
[mlir][Python] Add casting constructor to Type and Attribute.
* This makes them consistent with custom types/attributes, whose constructors will do a type checked conversion. Of course, the base classes can represent everything so never error. * More importantly, this makes it possible to subclass Type and Attribute out of tree in sensible ways. Differential Revision: https://reviews.llvm.org/D101734
This commit is contained in:
parent
5fa9d41634
commit
b57d6fe42e
|
@ -2255,6 +2255,10 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
// Mapping of PyAttribute.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
// Delegate to the PyAttribute copy constructor, which will also lifetime
|
||||
// extend the backing context which owns the MlirAttribute.
|
||||
.def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
|
||||
"Casts the passed attribute to the generic Attribute")
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyAttribute::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
|
||||
|
@ -2358,6 +2362,10 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
// Mapping of PyType.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyType>(m, "Type")
|
||||
// Delegate to the PyType copy constructor, which will also lifetime
|
||||
// extend the backing context which owns the MlirType.
|
||||
.def(py::init<PyType &>(), py::arg("cast_from_type"),
|
||||
"Casts the passed type to the generic Type")
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
|
||||
.def_static(
|
||||
|
|
|
@ -8,9 +8,11 @@ def run(f):
|
|||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParsePrint
|
||||
@run
|
||||
def testParsePrint():
|
||||
with Context() as ctx:
|
||||
t = Attribute.parse('"hello"')
|
||||
|
@ -22,12 +24,11 @@ def testParsePrint():
|
|||
# CHECK: Attribute("hello")
|
||||
print(repr(t))
|
||||
|
||||
run(testParsePrint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParseError
|
||||
# TODO: Hook the diagnostic manager to capture a more meaningful error
|
||||
# message.
|
||||
@run
|
||||
def testParseError():
|
||||
with Context():
|
||||
try:
|
||||
|
@ -38,10 +39,9 @@ def testParseError():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testParseError)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEq
|
||||
@run
|
||||
def testAttrEq():
|
||||
with Context():
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
|
@ -56,10 +56,19 @@ def testAttrEq():
|
|||
# CHECK: a1 == None: False
|
||||
print("a1 == None:", a1 == None)
|
||||
|
||||
run(testAttrEq)
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrCast
|
||||
@run
|
||||
def testAttrCast():
|
||||
with Context():
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
a2 = Attribute(a1)
|
||||
# CHECK: a1 == a2: True
|
||||
print("a1 == a2:", a1 == a2)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
|
||||
@run
|
||||
def testAttrEqDoesNotRaise():
|
||||
with Context():
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
|
@ -71,10 +80,9 @@ def testAttrEqDoesNotRaise():
|
|||
# CHECK: True
|
||||
print(a1 != None)
|
||||
|
||||
run(testAttrEqDoesNotRaise)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrCapsule
|
||||
@run
|
||||
def testAttrCapsule():
|
||||
with Context() as ctx:
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
|
@ -85,10 +93,9 @@ def testAttrCapsule():
|
|||
assert a2 == a1
|
||||
assert a2.context is ctx
|
||||
|
||||
run(testAttrCapsule)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStandardAttrCasts
|
||||
@run
|
||||
def testStandardAttrCasts():
|
||||
with Context():
|
||||
a1 = Attribute.parse('"attr1"')
|
||||
|
@ -104,10 +111,9 @@ def testStandardAttrCasts():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testStandardAttrCasts)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineMapAttr
|
||||
@run
|
||||
def testAffineMapAttr():
|
||||
with Context() as ctx:
|
||||
d0 = AffineDimExpr.get(0)
|
||||
|
@ -122,10 +128,9 @@ def testAffineMapAttr():
|
|||
attr_parsed = Attribute.parse(str(attr_built))
|
||||
assert attr_built == attr_parsed
|
||||
|
||||
run(testAffineMapAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFloatAttr
|
||||
@run
|
||||
def testFloatAttr():
|
||||
with Context(), Location.unknown():
|
||||
fattr = FloatAttr(Attribute.parse("42.0 : f32"))
|
||||
|
@ -149,10 +154,9 @@ def testFloatAttr():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testFloatAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIntegerAttr
|
||||
@run
|
||||
def testIntegerAttr():
|
||||
with Context() as ctx:
|
||||
iattr = IntegerAttr(Attribute.parse("42"))
|
||||
|
@ -166,10 +170,9 @@ def testIntegerAttr():
|
|||
print("default_get:", IntegerAttr.get(
|
||||
IntegerType.get_signless(32), 42))
|
||||
|
||||
run(testIntegerAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBoolAttr
|
||||
@run
|
||||
def testBoolAttr():
|
||||
with Context() as ctx:
|
||||
battr = BoolAttr(Attribute.parse("true"))
|
||||
|
@ -180,10 +183,9 @@ def testBoolAttr():
|
|||
# CHECK: default_get: true
|
||||
print("default_get:", BoolAttr.get(True))
|
||||
|
||||
run(testBoolAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
|
||||
@run
|
||||
def testFlatSymbolRefAttr():
|
||||
with Context() as ctx:
|
||||
sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
|
||||
|
@ -194,10 +196,9 @@ def testFlatSymbolRefAttr():
|
|||
# CHECK: default_get: @foobar
|
||||
print("default_get:", FlatSymbolRefAttr.get("foobar"))
|
||||
|
||||
run(testFlatSymbolRefAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStringAttr
|
||||
@run
|
||||
def testStringAttr():
|
||||
with Context() as ctx:
|
||||
sattr = StringAttr(Attribute.parse('"stringattr"'))
|
||||
|
@ -211,10 +212,9 @@ def testStringAttr():
|
|||
print("typed_get:", StringAttr.get_typed(
|
||||
IntegerType.get_signless(32), "12345"))
|
||||
|
||||
run(testStringAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedAttr
|
||||
@run
|
||||
def testNamedAttr():
|
||||
with Context():
|
||||
a = Attribute.parse('"stringattr"')
|
||||
|
@ -226,10 +226,9 @@ def testNamedAttr():
|
|||
# CHECK: named: NamedAttribute(foobar="stringattr")
|
||||
print("named:", named)
|
||||
|
||||
run(testNamedAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseIntAttr
|
||||
@run
|
||||
def testDenseIntAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
|
||||
|
@ -263,10 +262,8 @@ def testDenseIntAttr():
|
|||
print(ShapedType(a.type).element_type)
|
||||
|
||||
|
||||
run(testDenseIntAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseFPAttr
|
||||
@run
|
||||
def testDenseFPAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
|
||||
|
@ -286,10 +283,8 @@ def testDenseFPAttr():
|
|||
print(ShapedType(a.type).element_type)
|
||||
|
||||
|
||||
run(testDenseFPAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDictAttr
|
||||
@run
|
||||
def testDictAttr():
|
||||
with Context():
|
||||
dict_attr = {
|
||||
|
@ -327,10 +322,8 @@ def testDictAttr():
|
|||
assert False, "expected IndexError on accessing an out-of-bounds attribute"
|
||||
|
||||
|
||||
|
||||
run(testDictAttr)
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeAttr
|
||||
@run
|
||||
def testTypeAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("vector<4xf32>")
|
||||
|
@ -341,10 +334,8 @@ def testTypeAttr():
|
|||
print(ShapedType(type_attr.value).element_type)
|
||||
|
||||
|
||||
run(testTypeAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testArrayAttr
|
||||
@run
|
||||
def testArrayAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("[42, true, vector<4xf32>]")
|
||||
|
@ -391,5 +382,4 @@ def testArrayAttr():
|
|||
except RuntimeError as e:
|
||||
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
|
||||
print("Error: ", e)
|
||||
run(testArrayAttr)
|
||||
|
||||
|
|
|
@ -8,9 +8,11 @@ def run(f):
|
|||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParsePrint
|
||||
@run
|
||||
def testParsePrint():
|
||||
ctx = Context()
|
||||
t = Type.parse("i32", ctx)
|
||||
|
@ -22,12 +24,11 @@ def testParsePrint():
|
|||
# CHECK: Type(i32)
|
||||
print(repr(t))
|
||||
|
||||
run(testParsePrint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testParseError
|
||||
# TODO: Hook the diagnostic manager to capture a more meaningful error
|
||||
# message.
|
||||
@run
|
||||
def testParseError():
|
||||
ctx = Context()
|
||||
try:
|
||||
|
@ -38,10 +39,9 @@ def testParseError():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testParseError)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeEq
|
||||
@run
|
||||
def testTypeEq():
|
||||
ctx = Context()
|
||||
t1 = Type.parse("i32", ctx)
|
||||
|
@ -56,10 +56,19 @@ def testTypeEq():
|
|||
# CHECK: t1 == None: False
|
||||
print("t1 == None:", t1 == None)
|
||||
|
||||
run(testTypeEq)
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeCast
|
||||
@run
|
||||
def testTypeCast():
|
||||
ctx = Context()
|
||||
t1 = Type.parse("i32", ctx)
|
||||
t2 = Type(t1)
|
||||
# CHECK: t1 == t2: True
|
||||
print("t1 == t2:", t1 == t2)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeIsInstance
|
||||
@run
|
||||
def testTypeIsInstance():
|
||||
ctx = Context()
|
||||
t1 = Type.parse("i32", ctx)
|
||||
|
@ -71,10 +80,9 @@ def testTypeIsInstance():
|
|||
# CHECK: True
|
||||
print(F32Type.isinstance(t2))
|
||||
|
||||
run(testTypeIsInstance)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
|
||||
@run
|
||||
def testTypeEqDoesNotRaise():
|
||||
ctx = Context()
|
||||
t1 = Type.parse("i32", ctx)
|
||||
|
@ -86,10 +94,9 @@ def testTypeEqDoesNotRaise():
|
|||
# CHECK: True
|
||||
print(t1 != None)
|
||||
|
||||
run(testTypeEqDoesNotRaise)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeCapsule
|
||||
@run
|
||||
def testTypeCapsule():
|
||||
with Context() as ctx:
|
||||
t1 = Type.parse("i32", ctx)
|
||||
|
@ -100,10 +107,9 @@ def testTypeCapsule():
|
|||
assert t2 == t1
|
||||
assert t2.context is ctx
|
||||
|
||||
run(testTypeCapsule)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStandardTypeCasts
|
||||
@run
|
||||
def testStandardTypeCasts():
|
||||
ctx = Context()
|
||||
t1 = Type.parse("i32", ctx)
|
||||
|
@ -119,10 +125,9 @@ def testStandardTypeCasts():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testStandardTypeCasts)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIntegerType
|
||||
@run
|
||||
def testIntegerType():
|
||||
with Context() as ctx:
|
||||
i32 = IntegerType(Type.parse("i32"))
|
||||
|
@ -158,17 +163,16 @@ def testIntegerType():
|
|||
# CHECK: unsigned: ui64
|
||||
print("unsigned:", IntegerType.get_unsigned(64))
|
||||
|
||||
run(testIntegerType)
|
||||
|
||||
# CHECK-LABEL: TEST: testIndexType
|
||||
@run
|
||||
def testIndexType():
|
||||
with Context() as ctx:
|
||||
# CHECK: index type: index
|
||||
print("index type:", IndexType.get())
|
||||
|
||||
run(testIndexType)
|
||||
|
||||
# CHECK-LABEL: TEST: testFloatType
|
||||
@run
|
||||
def testFloatType():
|
||||
with Context():
|
||||
# CHECK: float: bf16
|
||||
|
@ -180,17 +184,17 @@ def testFloatType():
|
|||
# CHECK: float: f64
|
||||
print("float:", F64Type.get())
|
||||
|
||||
run(testFloatType)
|
||||
|
||||
# CHECK-LABEL: TEST: testNoneType
|
||||
@run
|
||||
def testNoneType():
|
||||
with Context():
|
||||
# CHECK: none type: none
|
||||
print("none type:", NoneType.get())
|
||||
|
||||
run(testNoneType)
|
||||
|
||||
# CHECK-LABEL: TEST: testComplexType
|
||||
@run
|
||||
def testComplexType():
|
||||
with Context() as ctx:
|
||||
complex_i32 = ComplexType(Type.parse("complex<i32>"))
|
||||
|
@ -210,13 +214,12 @@ def testComplexType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testComplexType)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConcreteShapedType
|
||||
# Shaped type is not a kind of builtin types, it is the base class for vectors,
|
||||
# memrefs and tensors, so this test case uses an instance of vector to test the
|
||||
# shaped type. The class hierarchy is preserved on the python side.
|
||||
@run
|
||||
def testConcreteShapedType():
|
||||
with Context() as ctx:
|
||||
vector = VectorType(Type.parse("vector<2x3xf32>"))
|
||||
|
@ -239,20 +242,20 @@ def testConcreteShapedType():
|
|||
# CHECK: isinstance(ShapedType): True
|
||||
print("isinstance(ShapedType):", isinstance(vector, ShapedType))
|
||||
|
||||
run(testConcreteShapedType)
|
||||
|
||||
# CHECK-LABEL: TEST: testAbstractShapedType
|
||||
# Tests that ShapedType operates as an abstract base class of a concrete
|
||||
# shaped type (using vector as an example).
|
||||
@run
|
||||
def testAbstractShapedType():
|
||||
ctx = Context()
|
||||
vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
|
||||
# CHECK: element type: f32
|
||||
print("element type:", vector.element_type)
|
||||
|
||||
run(testAbstractShapedType)
|
||||
|
||||
# CHECK-LABEL: TEST: testVectorType
|
||||
@run
|
||||
def testVectorType():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
@ -269,9 +272,9 @@ def testVectorType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testVectorType)
|
||||
|
||||
# CHECK-LABEL: TEST: testRankedTensorType
|
||||
@run
|
||||
def testRankedTensorType():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
@ -291,9 +294,9 @@ def testRankedTensorType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testRankedTensorType)
|
||||
|
||||
# CHECK-LABEL: TEST: testUnrankedTensorType
|
||||
@run
|
||||
def testUnrankedTensorType():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
@ -333,9 +336,9 @@ def testUnrankedTensorType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testUnrankedTensorType)
|
||||
|
||||
# CHECK-LABEL: TEST: testMemRefType
|
||||
@run
|
||||
def testMemRefType():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
@ -369,9 +372,9 @@ def testMemRefType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testMemRefType)
|
||||
|
||||
# CHECK-LABEL: TEST: testUnrankedMemRefType
|
||||
@run
|
||||
def testUnrankedMemRefType():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
@ -411,9 +414,9 @@ def testUnrankedMemRefType():
|
|||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testUnrankedMemRefType)
|
||||
|
||||
# CHECK-LABEL: TEST: testTupleType
|
||||
@run
|
||||
def testTupleType():
|
||||
with Context() as ctx:
|
||||
i32 = IntegerType(Type.parse("i32"))
|
||||
|
@ -428,10 +431,9 @@ def testTupleType():
|
|||
# CHECK: pos-th type in the tuple type: f32
|
||||
print("pos-th type in the tuple type:", tuple_type.get_type(1))
|
||||
|
||||
run(testTupleType)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFunctionType
|
||||
@run
|
||||
def testFunctionType():
|
||||
with Context() as ctx:
|
||||
input_types = [IntegerType.get_signless(32),
|
||||
|
@ -442,6 +444,3 @@ def testFunctionType():
|
|||
print("INPUTS:", func.inputs)
|
||||
# CHECK: RESULTS: [Type(index)]
|
||||
print("RESULTS:", func.results)
|
||||
|
||||
|
||||
run(testFunctionType)
|
||||
|
|
Loading…
Reference in New Issue