More progress on type checking ValueRanges (#118870)
Type checking Python is a pain. Here are my learnings:

* The types for heavily polymorphic code are going to be verbose, no way around it. I originally was hoping I could lean on polymorphism with a bounded TypeVar to compactly write signatures for many of the ValueRanges methods, but I ran into some unworkaroundable mypy bugs. Writing out all the types explicitly and using `@overload` liberally works pretty well, so I recommend people do that instead of trying to do fancy things.
* Sympy is missing annotations for assumptions, because they are all metaprogrammed. I don't really relish maintaining a typeshed for sympy, so I wrote a small mypy plugin to add them in.
* GADT-style refinement is... just not a good idea in practice. Mypy easily gets confused about whether or not a return value from a refined section is allowed for the outer return type. So many of these have been replaced with less informative implementation types and more informative external types via overloads. Hopefully this is good for use sites.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118870
Approved by: https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
parent
b92819a039
commit
b816760a2f
2
mypy.ini
2
mypy.ini
|
@ -2,7 +2,7 @@
|
|||
# test_run_mypy in test/test_type_hints.py uses this string)
|
||||
|
||||
[mypy]
|
||||
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
|
||||
plugins = mypy_plugins/check_mypy_version.py, mypy_plugins/sympy_mypy_plugin.py, numpy.typing.mypy_plugin
|
||||
|
||||
cache_dir = .mypy_cache/normal
|
||||
allow_redefinition = True
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
from mypy.plugin import Plugin
|
||||
from mypy.plugins.common import add_attribute_to_class
|
||||
from mypy.types import NoneType, UnionType
|
||||
|
||||
|
||||
class SympyPlugin(Plugin):
    """Mypy plugin that registers a base-class hook for ``sympy.core.basic.Basic``."""

    def get_base_class_hook(self, fullname: str):
        """Return ``add_assumptions`` when *fullname* names sympy's ``Basic``.

        Any other fully qualified name yields ``None``, i.e. no hook.
        """
        is_sympy_basic = fullname == "sympy.core.basic.Basic"
        return add_assumptions if is_sympy_basic else None
|
||||
|
||||
|
||||
def add_assumptions(ctx) -> None:
    """Declare an ``is_<assumption>`` attribute on ``ctx.cls`` for each sympy assumption.

    Every attribute is registered through ``add_attribute_to_class`` with the
    type ``Union[builtins.bool, None]``.
    """
    # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
    # (sympy is deliberately not imported, to speed up mypy plugin load time)
    assumption_names = (
        "hermitian",
        "prime",
        "noninteger",
        "negative",
        "antihermitian",
        "infinite",
        "finite",
        "irrational",
        "extended_positive",
        "nonpositive",
        "odd",
        "algebraic",
        "integer",
        "rational",
        "extended_real",
        "nonnegative",
        "transcendental",
        "extended_nonzero",
        "extended_negative",
        "composite",
        "complex",
        "imaginary",
        "nonzero",
        "zero",
        "even",
        "positive",
        "polar",
        "extended_nonpositive",
        "extended_nonnegative",
        "real",
        "commutative",
    )
    for name in assumption_names:
        # A fresh type object is built per attribute, as in the original loop.
        optional_bool = UnionType([ctx.api.named_type("builtins.bool"), NoneType()])
        add_attribute_to_class(ctx.api, ctx.cls, f"is_{name}", optional_bool)
|
||||
|
||||
|
||||
def plugin(version: str):
    """Plugin entry point: return the ``SympyPlugin`` class.

    *version* is accepted but not used.
    """
    return SympyPlugin
|
|
@ -67,7 +67,7 @@ def _run_mypy() -> Dict[str, List[str]]:
|
|||
directory,
|
||||
]
|
||||
)
|
||||
assert not stderr, directory
|
||||
assert not stderr, stderr
|
||||
stdout = stdout.replace("*", "")
|
||||
|
||||
# Parse the output
|
||||
|
|
|
@ -8,7 +8,7 @@ import operator
|
|||
import math
|
||||
import logging
|
||||
import torch
|
||||
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, cast, Union
|
||||
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, Union, overload, Callable, TYPE_CHECKING
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from torch._prims_common import dtype_to_type
|
||||
|
@ -70,8 +70,25 @@ def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
|
|||
return not vr.is_bool
|
||||
|
||||
|
||||
ExprIn = Union[int, float, sympy.Expr]
|
||||
BoolIn = Union[bool, SympyBoolean]
|
||||
AllIn = Union[ExprIn, BoolIn]
|
||||
ExprFn = Callable[[sympy.Expr], sympy.Expr]
|
||||
ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
|
||||
BoolFn = Callable[[SympyBoolean], SympyBoolean]
|
||||
BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
|
||||
AllFn = Union[ExprFn, BoolFn]
|
||||
AllFn2 = Union[ExprFn2, BoolFn2]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ValueRanges(Generic[_T]):
|
||||
if TYPE_CHECKING:
|
||||
# ruff doesn't understand circular references but mypy does
|
||||
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
|
||||
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
|
||||
AllVR = Union[ExprVR, BoolVR]
|
||||
|
||||
# Although the type signature here suggests you can pass any
|
||||
# sympy expression, in practice the analysis here only works
|
||||
# with constant sympy expressions
|
||||
|
@ -79,7 +96,15 @@ class ValueRanges(Generic[_T]):
|
|||
upper: _T
|
||||
is_bool: bool
|
||||
|
||||
def __init__(self, lower: Union[_T, bool, int, float], upper: Union[_T, bool, int, float]) -> None:
|
||||
@overload
|
||||
def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn) -> None:
|
||||
...
|
||||
|
||||
def __init__(self, lower: AllIn, upper: AllIn) -> None:
|
||||
lower = simple_sympify(lower)
|
||||
upper = simple_sympify(upper)
|
||||
# TODO: when the bounds have free variables, this may be
|
||||
|
@ -92,15 +117,15 @@ class ValueRanges(Generic[_T]):
|
|||
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
|
||||
assert isinstance(upper, SympyBoolean) == self.is_bool
|
||||
|
||||
def boolify(self):
|
||||
if self.is_bool:
|
||||
def boolify(self) -> ValueRanges[SympyBoolean]:
|
||||
if vr_is_bool(self):
|
||||
return self
|
||||
elif self == ValueRanges.unknown():
|
||||
return ValueRanges.unknown_bool()
|
||||
else:
|
||||
raise AssertionError(f"not bool like {self}")
|
||||
|
||||
def __contains__(self, x):
|
||||
def __contains__(self, x: AllIn) -> bool:
|
||||
x = simple_sympify(x)
|
||||
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
|
||||
|
||||
|
@ -109,30 +134,42 @@ class ValueRanges(Generic[_T]):
|
|||
return self & other
|
||||
|
||||
# Intersection
|
||||
def __and__(self: ValueRanges[_T], other: ValueRanges[_T]) -> ValueRanges[_T]:
|
||||
@overload
|
||||
def __and__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __and__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]:
|
||||
...
|
||||
|
||||
def __and__(self: AllVR, other: AllVR) -> AllVR:
|
||||
if other == ValueRanges.unknown():
|
||||
return self
|
||||
if self == ValueRanges.unknown():
|
||||
return other
|
||||
assert self.is_bool == other.is_bool, (self, other)
|
||||
if vr_is_bool(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)))
|
||||
elif vr_is_expr(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)))
|
||||
if self.is_bool:
|
||||
return ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))
|
||||
else:
|
||||
raise AssertionError("impossible")
|
||||
return ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
|
||||
|
||||
# Union
|
||||
def __or__(self, other) -> ValueRanges:
|
||||
@overload
|
||||
def __or__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __or__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]:
|
||||
...
|
||||
|
||||
def __or__(self: AllVR, other: AllVR) -> AllVR:
|
||||
if ValueRanges.unknown() in (self, other):
|
||||
return ValueRanges.unknown()
|
||||
assert self.is_bool == other.is_bool, (self, other)
|
||||
if vr_is_bool(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)))
|
||||
elif vr_is_expr(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)))
|
||||
if self.is_bool:
|
||||
return ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
|
||||
else:
|
||||
raise AssertionError("impossible")
|
||||
return ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
|
||||
|
||||
def is_singleton(self) -> bool:
|
||||
return self.lower == self.upper
|
||||
|
@ -146,43 +183,76 @@ class ValueRanges(Generic[_T]):
|
|||
def unknown_bool() -> ValueRanges[SympyBoolean]:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, arg):
|
||||
@overload
|
||||
@staticmethod
|
||||
# work around the fact that bool and int overlap
|
||||
def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap]
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
|
||||
if isinstance(arg, ValueRanges):
|
||||
return arg
|
||||
return ValueRanges(arg, arg)
|
||||
# arg is either ExprIn or BoolIn, but we don't know it here
|
||||
return ValueRanges(arg, arg) # type: ignore[arg-type]
|
||||
|
||||
@classmethod
|
||||
def increasing_map(cls, x, fn):
|
||||
@staticmethod
|
||||
def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
||||
"""Increasing: x <= y => f(x) <= f(y)."""
|
||||
x = cls.wrap(x)
|
||||
x = ValueRanges.wrap(x)
|
||||
return ValueRanges(fn(x.lower), fn(x.upper))
|
||||
|
||||
@classmethod
|
||||
def decreasing_map(cls, x, fn):
|
||||
"""Decreasing: x <= y => f(x) >= f(y)."""
|
||||
x = cls.wrap(x)
|
||||
return ValueRanges(fn(x.upper), fn(x.lower))
|
||||
@overload
|
||||
@staticmethod
|
||||
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def monotone_map(cls, x, fn):
|
||||
@overload
|
||||
@staticmethod
|
||||
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
|
||||
"""Decreasing: x <= y => f(x) >= f(y)."""
|
||||
x = ValueRanges.wrap(x)
|
||||
# consistently either Expr or Bool, but we don't know it here
|
||||
return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type]
|
||||
|
||||
@staticmethod
|
||||
def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
||||
"""It's increasing or decreasing."""
|
||||
x = cls.wrap(x)
|
||||
x = ValueRanges.wrap(x)
|
||||
l = fn(x.lower)
|
||||
u = fn(x.upper)
|
||||
return ValueRanges(min(l, u), max(l, u))
|
||||
|
||||
@classmethod
|
||||
def convex_min_zero_map(cls, x, fn):
|
||||
@staticmethod
|
||||
def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
|
||||
"""Fn is convex and has a minimum at 0."""
|
||||
x = ValueRanges.wrap(x)
|
||||
if 0 in x:
|
||||
return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
|
||||
else:
|
||||
return cls.monotone_map(x, fn)
|
||||
return ValueRanges.monotone_map(x, fn)
|
||||
|
||||
@classmethod
|
||||
def coordinatewise_increasing_map(cls, x, y, fn):
|
||||
@overload
|
||||
@staticmethod
|
||||
def coordinatewise_increasing_map(x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2) -> ExprVR:
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def coordinatewise_increasing_map(x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2) -> BoolVR:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def coordinatewise_increasing_map(x: Union[AllIn, AllVR], y: Union[AllIn, AllVR], fn: AllFn2) -> AllVR:
|
||||
"""
|
||||
It's increasing on each coordinate.
|
||||
|
||||
|
@ -190,10 +260,10 @@ class ValueRanges(Generic[_T]):
|
|||
For every 1 <= i <= n and x_i <= y_i we have that
|
||||
f(x1, ..., xn) <= f(x1, ..., yi, ..., xn)
|
||||
"""
|
||||
x, y = cls.wrap(x), cls.wrap(y)
|
||||
x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
|
||||
return ValueRanges(
|
||||
fn(x.lower, y.lower),
|
||||
fn(x.upper, y.upper),
|
||||
fn(x.lower, y.lower), # type: ignore[arg-type]
|
||||
fn(x.upper, y.upper), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -450,7 +520,7 @@ class SymPyValueRangeAnalysis:
|
|||
b = ValueRanges.wrap(b)
|
||||
|
||||
# Performs upcasting first
|
||||
def fn_(x, y):
|
||||
def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr:
|
||||
# Poorman's version of upcasting in Sympy
|
||||
# Inf is not a float...
|
||||
if x.is_Integer and y.is_Integer:
|
||||
|
|
Loading…
Reference in New Issue