More progress on type checking ValueRanges (#118870)

Type checking Python is a pain. Here are my learnings:

* The types for heavily polymorphic code are going to be verbose; there is no way around it. I originally was hoping I could lean on polymorphism with a bounded TypeVar to compactly write signatures for many of the ValueRanges methods, but I ran into some mypy bugs that I could not work around. Writing out all the types explicitly and using `@overload` liberally works pretty well, so I recommend people do that instead of trying to do fancy things.
* Sympy is missing annotations for assumptions, because they are all metaprogrammed. I don't really relish maintaining a typeshed for sympy, so I wrote a small mypy plugin to add them in.
* GADT-style refinement is... just not a good idea in practice. Mypy easily gets confused about whether or not a return value from a refined section is allowed for the outer return type. So many of these have been replaced with less informative implementation types and more informative external types via overloads. Hopefully this is good for use sites.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118870
Approved by: https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
Edward Z. Yang 2024-02-05 09:22:07 -08:00 committed by PyTorch MergeBot
parent b92819a039
commit b816760a2f
4 changed files with 171 additions and 42 deletions

View File

@ -2,7 +2,7 @@
# test_run_mypy in test/test_type_hints.py uses this string)
[mypy]
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
plugins = mypy_plugins/check_mypy_version.py, mypy_plugins/sympy_mypy_plugin.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/normal
allow_redefinition = True

View File

@ -0,0 +1,59 @@
from mypy.plugin import Plugin
from mypy.plugins.common import add_attribute_to_class
from mypy.types import NoneType, UnionType
class SympyPlugin(Plugin):
    """Mypy plugin that attaches a hook to classes based on sympy's ``Basic``."""

    def get_base_class_hook(self, fullname: str):
        """Return the base-class hook for *fullname*, if there is one.

        ``add_assumptions`` is returned when *fullname* is exactly
        ``sympy.core.basic.Basic``; every other name yields ``None``.
        """
        is_sympy_basic = fullname == "sympy.core.basic.Basic"
        return add_assumptions if is_sympy_basic else None
def add_assumptions(ctx) -> None:
    """Add one ``is_<assumption>`` attribute per sympy assumption to ``ctx.cls``.

    Each attribute is declared with the type ``Union[bool, None]``. The
    assumption names are hard-coded below rather than read from sympy.
    """
    # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
    # (sympy is deliberately not imported, to speed up mypy plugin load time)
    assumption_names = [
        "hermitian",
        "prime",
        "noninteger",
        "negative",
        "antihermitian",
        "infinite",
        "finite",
        "irrational",
        "extended_positive",
        "nonpositive",
        "odd",
        "algebraic",
        "integer",
        "rational",
        "extended_real",
        "nonnegative",
        "transcendental",
        "extended_nonzero",
        "extended_negative",
        "composite",
        "complex",
        "imaginary",
        "nonzero",
        "zero",
        "even",
        "positive",
        "polar",
        "extended_nonpositive",
        "extended_nonnegative",
        "real",
        "commutative",
    ]
    for name in assumption_names:
        attribute = f"is_{name}"
        # A fresh Union[bool, None] is built for every attribute.
        optional_bool = UnionType([ctx.api.named_type("builtins.bool"), NoneType()])
        add_attribute_to_class(ctx.api, ctx.cls, attribute, optional_bool)
def plugin(version: str):
    """Return the plugin class defined in this module.

    *version* is accepted but not consulted; the same class is returned
    for every value.
    """
    return SympyPlugin

View File

@ -67,7 +67,7 @@ def _run_mypy() -> Dict[str, List[str]]:
directory,
]
)
assert not stderr, directory
assert not stderr, stderr
stdout = stdout.replace("*", "")
# Parse the output

View File

@ -8,7 +8,7 @@ import operator
import math
import logging
import torch
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, cast, Union
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, Union, overload, Callable, TYPE_CHECKING
from typing_extensions import TypeGuard
from torch._prims_common import dtype_to_type
@ -70,8 +70,25 @@ def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
return not vr.is_bool
ExprIn = Union[int, float, sympy.Expr]
BoolIn = Union[bool, SympyBoolean]
AllIn = Union[ExprIn, BoolIn]
ExprFn = Callable[[sympy.Expr], sympy.Expr]
ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
BoolFn = Callable[[SympyBoolean], SympyBoolean]
BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
AllFn = Union[ExprFn, BoolFn]
AllFn2 = Union[ExprFn2, BoolFn2]
@dataclasses.dataclass(frozen=True)
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]
# Although the type signature here suggests you can pass any
# sympy expression, in practice the analysis here only works
# with constant sympy expressions
@ -79,7 +96,15 @@ class ValueRanges(Generic[_T]):
upper: _T
is_bool: bool
def __init__(self, lower: Union[_T, bool, int, float], upper: Union[_T, bool, int, float]) -> None:
@overload
def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None:
...
@overload
def __init__(self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn) -> None:
...
def __init__(self, lower: AllIn, upper: AllIn) -> None:
lower = simple_sympify(lower)
upper = simple_sympify(upper)
# TODO: when the bounds have free variables, this may be
@ -92,15 +117,15 @@ class ValueRanges(Generic[_T]):
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
assert isinstance(upper, SympyBoolean) == self.is_bool
def boolify(self):
if self.is_bool:
def boolify(self) -> ValueRanges[SympyBoolean]:
if vr_is_bool(self):
return self
elif self == ValueRanges.unknown():
return ValueRanges.unknown_bool()
else:
raise AssertionError(f"not bool like {self}")
def __contains__(self, x):
def __contains__(self, x: AllIn) -> bool:
x = simple_sympify(x)
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
@ -109,30 +134,42 @@ class ValueRanges(Generic[_T]):
return self & other
# Intersection
def __and__(self: ValueRanges[_T], other: ValueRanges[_T]) -> ValueRanges[_T]:
@overload
def __and__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]:
...
@overload
def __and__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]:
...
def __and__(self: AllVR, other: AllVR) -> AllVR:
if other == ValueRanges.unknown():
return self
if self == ValueRanges.unknown():
return other
assert self.is_bool == other.is_bool, (self, other)
if vr_is_bool(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)))
elif vr_is_expr(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)))
if self.is_bool:
return ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))
else:
raise AssertionError("impossible")
return ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
# Union
def __or__(self, other) -> ValueRanges:
@overload
def __or__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]:
...
@overload
def __or__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]:
...
def __or__(self: AllVR, other: AllVR) -> AllVR:
if ValueRanges.unknown() in (self, other):
return ValueRanges.unknown()
assert self.is_bool == other.is_bool, (self, other)
if vr_is_bool(self):
return cast(ValueRanges[_T], ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)))
elif vr_is_expr(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)))
if self.is_bool:
return ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
else:
raise AssertionError("impossible")
return ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
def is_singleton(self) -> bool:
return self.lower == self.upper
@ -146,43 +183,76 @@ class ValueRanges(Generic[_T]):
def unknown_bool() -> ValueRanges[SympyBoolean]:
return ValueRanges(sympy.false, sympy.true)
@classmethod
def wrap(cls, arg):
@overload
@staticmethod
# work around the fact that bool and int overlap
def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap]
...
@overload
@staticmethod
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:
...
@staticmethod
def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
if isinstance(arg, ValueRanges):
return arg
return ValueRanges(arg, arg)
# arg is either ExprIn or BoolIn, but we don't know it here
return ValueRanges(arg, arg) # type: ignore[arg-type]
@classmethod
def increasing_map(cls, x, fn):
@staticmethod
def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
"""Increasing: x <= y => f(x) <= f(y)."""
x = cls.wrap(x)
x = ValueRanges.wrap(x)
return ValueRanges(fn(x.lower), fn(x.upper))
@classmethod
def decreasing_map(cls, x, fn):
"""Decreasing: x <= y => f(x) >= f(y)."""
x = cls.wrap(x)
return ValueRanges(fn(x.upper), fn(x.lower))
@overload
@staticmethod
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
...
@classmethod
def monotone_map(cls, x, fn):
@overload
@staticmethod
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:
...
@staticmethod
def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
"""Decreasing: x <= y => f(x) >= f(y)."""
x = ValueRanges.wrap(x)
# consistently either Expr or Bool, but we don't know it here
return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type]
@staticmethod
def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
"""It's increasing or decreasing."""
x = cls.wrap(x)
x = ValueRanges.wrap(x)
l = fn(x.lower)
u = fn(x.upper)
return ValueRanges(min(l, u), max(l, u))
@classmethod
def convex_min_zero_map(cls, x, fn):
@staticmethod
def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
"""Fn is convex and has a minimum at 0."""
x = ValueRanges.wrap(x)
if 0 in x:
return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
else:
return cls.monotone_map(x, fn)
return ValueRanges.monotone_map(x, fn)
@classmethod
def coordinatewise_increasing_map(cls, x, y, fn):
@overload
@staticmethod
def coordinatewise_increasing_map(x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2) -> ExprVR:
...
@overload
@staticmethod
def coordinatewise_increasing_map(x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2) -> BoolVR:
...
@staticmethod
def coordinatewise_increasing_map(x: Union[AllIn, AllVR], y: Union[AllIn, AllVR], fn: AllFn2) -> AllVR:
"""
It's increasing on each coordinate.
@ -190,10 +260,10 @@ class ValueRanges(Generic[_T]):
For every 1 <= i <= n and x_i <= y_i we have that
f(x1, .., xn) <= f(x1, , yi, ..., xn)
"""
x, y = cls.wrap(x), cls.wrap(y)
x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
return ValueRanges(
fn(x.lower, y.lower),
fn(x.upper, y.upper),
fn(x.lower, y.lower), # type: ignore[arg-type]
fn(x.upper, y.upper), # type: ignore[arg-type]
)
@classmethod
@ -450,7 +520,7 @@ class SymPyValueRangeAnalysis:
b = ValueRanges.wrap(b)
# Performs upcasting first
def fn_(x, y):
def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr:
# Poorman's version of upcasting in Sympy
# Inf is not a float...
if x.is_Integer and y.is_Integer: