forked from OSchip/llvm-project
[mlir][python] Provide more convenient wrappers for std.ConstantOp
Constructing a ConstantOp using the default-generated API is verbose and requires to specify the constant type twice: for the result type of the operation and for the type of the attribute. It also requires to explicitly construct the attribute. Provide custom constructors that take the type once and accept a raw value instead of the attribute. This requires dynamic dispatch based on type in the constructor. Also provide the corresponding accessors to raw values. In addition, provide a "refinement" class ConstantIndexOp similar to what exists in C++. Unlike other "op view" Python classes, operations cannot be automatically downcasted to this class since it does not correspond to a specific operation name. It only exists to simplify construction of the operation. Depends On D110946 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110947
This commit is contained in:
parent
ed9e52f3af
commit
3a3a09f654
|
@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
TD_FILE dialects/StandardOps.td
|
TD_FILE dialects/StandardOps.td
|
||||||
SOURCES dialects/std.py
|
SOURCES
|
||||||
|
dialects/std.py
|
||||||
|
dialects/_std_ops_ext.py
|
||||||
DIALECT_NAME std)
|
DIALECT_NAME std)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..ir import *
|
||||||
|
from .builtin import FuncOp
|
||||||
|
from ._ods_common import get_default_loc_context as _get_default_loc_context
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError("Error loading imports from extension module") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _isa(obj: Any, cls: type):
|
||||||
|
try:
|
||||||
|
cls(obj)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_any_of(obj: Any, classes: List[type]):
|
||||||
|
return any(_isa(obj, cls) for cls in classes)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_integer_like_type(type: Type):
|
||||||
|
return _is_any_of(type, [IntegerType, IndexType])
|
||||||
|
|
||||||
|
|
||||||
|
def _is_float_type(type: Type):
|
||||||
|
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantOp:
|
||||||
|
"""Specialization for the constant op class."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
result: Type,
|
||||||
|
value: Union[int, float, Attribute],
|
||||||
|
*,
|
||||||
|
loc=None,
|
||||||
|
ip=None):
|
||||||
|
if isinstance(value, int):
|
||||||
|
super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||||
|
elif isinstance(value, float):
|
||||||
|
super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||||
|
else:
|
||||||
|
super().__init__(result, value, loc=loc, ip=ip)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_index(cls, value: int, *, loc=None, ip=None):
|
||||||
|
"""Create an index-typed constant."""
|
||||||
|
return cls(
|
||||||
|
IndexType.get(context=_get_default_loc_context(loc)),
|
||||||
|
value,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self):
|
||||||
|
return self.results[0].type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def literal_value(self) -> Union[int, float]:
|
||||||
|
if _is_integer_like_type(self.type):
|
||||||
|
return IntegerAttr(self.value).value
|
||||||
|
elif _is_float_type(self.type):
|
||||||
|
return FloatAttr(self.value).value
|
||||||
|
else:
|
||||||
|
raise ValueError("only integer and float constants have literal values")
|
|
@ -0,0 +1,64 @@
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
from mlir.ir import *
|
||||||
|
from mlir.dialects import std
|
||||||
|
|
||||||
|
|
||||||
|
def constructAndPrintInModule(f):
|
||||||
|
print("\nTEST:", f.__name__)
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
f()
|
||||||
|
print(module)
|
||||||
|
return f
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testConstantOp
|
||||||
|
|
||||||
|
@constructAndPrintInModule
|
||||||
|
def testConstantOp():
|
||||||
|
c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
|
||||||
|
c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
|
||||||
|
c3 = std.ConstantOp(F32Type.get(), 3.14)
|
||||||
|
c4 = std.ConstantOp(F64Type.get(), 1.23)
|
||||||
|
# CHECK: 42
|
||||||
|
print(c1.literal_value)
|
||||||
|
|
||||||
|
# CHECK: 100
|
||||||
|
print(c2.literal_value)
|
||||||
|
|
||||||
|
# CHECK: 3.140000104904175
|
||||||
|
print(c3.literal_value)
|
||||||
|
|
||||||
|
# CHECK: 1.23
|
||||||
|
print(c4.literal_value)
|
||||||
|
|
||||||
|
# CHECK: = constant 42 : i32
|
||||||
|
# CHECK: = constant 100 : i64
|
||||||
|
# CHECK: = constant 3.140000e+00 : f32
|
||||||
|
# CHECK: = constant 1.230000e+00 : f64
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testVectorConstantOp
|
||||||
|
@constructAndPrintInModule
|
||||||
|
def testVectorConstantOp():
|
||||||
|
int_type = IntegerType.get_signless(32)
|
||||||
|
vec_type = VectorType.get([2, 2], int_type)
|
||||||
|
c1 = std.ConstantOp(vec_type,
|
||||||
|
DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
|
||||||
|
try:
|
||||||
|
print(c1.literal_value)
|
||||||
|
except ValueError as e:
|
||||||
|
assert "only integer and float constants have literal values" in str(e)
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
# CHECK: = constant dense<42> : vector<2x2xi32>
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testConstantIndexOp
|
||||||
|
@constructAndPrintInModule
|
||||||
|
def testConstantIndexOp():
|
||||||
|
c1 = std.ConstantOp.create_index(10)
|
||||||
|
# CHECK: 10
|
||||||
|
print(c1.literal_value)
|
||||||
|
|
||||||
|
# CHECK: = constant 10 : index
|
Loading…
Reference in New Issue