[mlir][python] Provide more convenient wrappers for std.ConstantOp

Constructing a ConstantOp using the default-generated API is verbose and
requires to specify the constant type twice: for the result type of the
operation and for the type of the attribute. It also requires to explicitly
construct the attribute. Provide custom constructors that take the type once
and accept a raw value instead of the attribute. This requires dynamic dispatch
based on type in the constructor. Also provide the corresponding accessors to
raw values.

In addition, provide a "refinement" class ConstantIndexOp similar to what
exists in C++. Unlike other "op view" Python classes, operations cannot be
automatically downcasted to this class since it does not correspond to a
specific operation name. It only exists to simplify construction of the
operation.

Depends On D110946

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110947
This commit is contained in:
Alex Zinenko 2021-10-04 11:38:53 +02:00
parent ed9e52f3af
commit 3a3a09f654
3 changed files with 138 additions and 1 deletions

View File

@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/StandardOps.td
SOURCES dialects/std.py
SOURCES
dialects/std.py
dialects/_std_ops_ext.py
DIALECT_NAME std)
declare_mlir_dialect_python_bindings(

View File

@ -0,0 +1,71 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from .builtin import FuncOp
from ._ods_common import get_default_loc_context as _get_default_loc_context
from typing import Any, List, Optional, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
class ConstantOp:
"""Specialization for the constant op class."""
def __init__(self,
result: Type,
value: Union[int, float, Attribute],
*,
loc=None,
ip=None):
if isinstance(value, int):
super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(result, value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)),
value,
loc=loc,
ip=ip)
@property
def type(self):
return self.results[0].type
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")

View File

@ -0,0 +1,64 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import std
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
# CHECK-LABEL: TEST: testConstantOp
@constructAndPrintInModule
def testConstantOp():
c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
c3 = std.ConstantOp(F32Type.get(), 3.14)
c4 = std.ConstantOp(F64Type.get(), 1.23)
# CHECK: 42
print(c1.literal_value)
# CHECK: 100
print(c2.literal_value)
# CHECK: 3.140000104904175
print(c3.literal_value)
# CHECK: 1.23
print(c4.literal_value)
# CHECK: = constant 42 : i32
# CHECK: = constant 100 : i64
# CHECK: = constant 3.140000e+00 : f32
# CHECK: = constant 1.230000e+00 : f64
# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
int_type = IntegerType.get_signless(32)
vec_type = VectorType.get([2, 2], int_type)
c1 = std.ConstantOp(vec_type,
DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
try:
print(c1.literal_value)
except ValueError as e:
assert "only integer and float constants have literal values" in str(e)
else:
assert False
# CHECK: = constant dense<42> : vector<2x2xi32>
# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
c1 = std.ConstantOp.create_index(10)
# CHECK: 10
print(c1.literal_value)
# CHECK: = constant 10 : index