Skip to content

Commit 65ace1e

Browse files
Merge pull request #302 from egraphs-good/fix-python-3.10
Fix python 3.10 compat
2 parents d7971e4 + fd6e1ff commit 65ace1e

File tree

10 files changed

+143
-39
lines changed

10 files changed

+143
-39
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ jobs:
4242
- uses: astral-sh/setup-uv@v6
4343
with:
4444
enable-cache: true
45-
# Run on oldest Python version to catch more errors
46-
python-version: "3.10"
4745
- uses: dtolnay/rust-toolchain@1.79.0
4846
- uses: Swatinem/rust-cache@v2
4947
- run: uv sync --extra test --locked

docs/changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7-
- fix using `f64Like` when not importing star (also properly includes removal of `Callable` special case from previous release).
7+
- Fix using `f64Like` when not importing star (also properly includes removal of `Callable` special case from previous release).
8+
- Fix Python 3.10 compatibility
89

910
## 10.0.1 (2025-04-06)
1011

python/egglog/declarations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __egg_decls__(self) -> Declarations:
9393
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
9494
# instead raise explicitly
9595
except AttributeError as err:
96-
msg = f"Cannot resolve declerations for {self}"
96+
msg = f"Cannot resolve declarations for {self}"
9797
raise RuntimeError(msg) from err
9898

9999

@@ -308,11 +308,11 @@ class ClassTypeVarRef:
308308
module: str
309309

310310
def to_just(self) -> JustTypeRef:
311-
msg = "egglog does not support generic classes yet."
311+
msg = f"{self}: egglog does not support generic classes yet."
312312
raise NotImplementedError(msg)
313313

314314
def __str__(self) -> str:
315-
return f"{self.module}.{self.name}"
315+
return str(self.to_type_var())
316316

317317
@classmethod
318318
def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:

python/egglog/egraph.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
ClassVar,
1717
Generic,
1818
Literal,
19-
Never,
2019
TypeAlias,
2120
TypedDict,
2221
TypeVar,
@@ -26,7 +25,7 @@
2625
)
2726

2827
import graphviz
29-
from typing_extensions import ParamSpec, Self, Unpack, assert_never
28+
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
3029

3130
from . import bindings
3231
from .conversion import *
@@ -36,6 +35,7 @@
3635
from .pretty import pretty_decl
3736
from .runtime import *
3837
from .thunk import *
38+
from .version_compat import *
3939

4040
if TYPE_CHECKING:
4141
from .builtins import String, Unit
@@ -169,8 +169,9 @@ def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, ad
169169
except bindings.EggSmolError as err:
170170
if display:
171171
egraph.display()
172-
err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})")
173-
raise
172+
raise add_note(
173+
f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})", err
174+
) from None
174175
return egraph
175176

176177

@@ -492,8 +493,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
492493
reverse_args=reverse_args,
493494
)
494495
except Exception as e:
495-
e.add_note(f"Error processing {cls_name}.{method_name}")
496-
raise
496+
raise add_note(f"Error processing {cls_name}.{method_name}", e) from None
497497

498498
if not builtin and not isinstance(ref, InitRef) and not mutates:
499499
add_default_funcs.append(add_rewrite)
@@ -627,7 +627,7 @@ def _fn_decl(
627627
)
628628
decls |= merged
629629

630-
# defer this in generator so it doesnt resolve for builtins eagerly
630+
# defer this in generator so it doesn't resolve for builtins eagerly
631631
args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
632632
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
633633
res_thunk: Callable[[], object]
@@ -671,7 +671,7 @@ def _fn_decl(
671671
)
672672
res_ref = ref
673673
decls.set_function_decl(ref, decl)
674-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
674+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
675675
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
676676

677677

@@ -1040,8 +1040,7 @@ def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport:
10401040
bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
10411041
)
10421042
except BaseException as e:
1043-
e.add_note("Extracting: " + str(expr))
1044-
raise
1043+
raise add_note("Extracting: " + str(expr), e) # noqa: B904
10451044
extract_report = self._egraph.extract_report()
10461045
if not extract_report:
10471046
msg = "No extract report saved"

python/egglog/exp/array_api.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969

7070
from egglog import *
7171
from egglog.runtime import RuntimeExpr
72+
from egglog.version_compat import add_note
7273

7374
from .program_gen import *
7475

@@ -1198,13 +1199,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
11981199

11991200
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
12001201

1201-
converter(NDArray, IndexKey, IndexKey.ndarray)
1202-
converter(Value, NDArray, NDArray.scalar)
1202+
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
1203+
converter(Value, NDArray, lambda v: NDArray.scalar(v))
12031204
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
12041205
# to prefer upcasting in the other direction when we can, which is safer at runtime
12051206
converter(NDArray, Value, lambda n: n.to_value(), 100)
1206-
converter(TupleValue, NDArray, NDArray.vector)
1207-
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
1207+
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
1208+
converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
12081209

12091210

12101211
@array_api_ruleset.register
@@ -1383,8 +1384,8 @@ def int(cls, value: Int) -> IntOrTuple: ...
13831384
def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
13841385

13851386

1386-
converter(Int, IntOrTuple, IntOrTuple.int)
1387-
converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
1387+
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
1388+
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
13881389

13891390

13901391
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1395,7 +1396,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
13951396

13961397

13971398
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
1398-
converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
1399+
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
13991400

14001401

14011402
@function
@@ -1980,6 +1981,5 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
19801981
extracted = egraph.extract(prim_expr)
19811982
except BaseException as e:
19821983
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1983-
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1984-
raise
1984+
raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
19851985
return extracted.eval() # type: ignore[attr-defined]

python/egglog/runtime.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .pretty import *
2323
from .thunk import Thunk
2424
from .type_constraint_solver import *
25+
from .version_compat import *
2526

2627
if TYPE_CHECKING:
2728
from collections.abc import Iterable
@@ -249,8 +250,7 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
249250
try:
250251
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
251252
except Exception as e:
252-
e.add_note(f"Error processing class {self.__egg_tp__.name}")
253-
raise
253+
raise add_note(f"Error processing class {self.__egg_tp__.name}", e) from None
254254

255255
preserved_methods = cls_decl.preserved_methods
256256
if name in preserved_methods:
@@ -281,6 +281,9 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
281281
def __str__(self) -> str:
282282
return str(self.__egg_tp__)
283283

284+
def __repr__(self) -> str:
285+
return str(self)
286+
284287
# Make hashable so can go in Union
285288
def __hash__(self) -> int:
286289
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
@@ -315,8 +318,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs:
315318
try:
316319
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
317320
except Exception as e:
318-
e.add_note(f"Failed to find callable {self}")
319-
raise
321+
raise add_note(f"Failed to find callable {self}", e) # noqa: B904
320322
decls = self.__egg_decls__.copy()
321323
# Special case function application bc we dont support variadic generics yet generally
322324
if signature == "fn-app":

python/egglog/thunk.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ class Thunk(Generic[T, Unpack[TS]]):
4141
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
4242

4343
@classmethod
44-
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
44+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], context: str | None = None) -> Thunk[T, Unpack[TS]]:
4545
"""
4646
Create a thunk based on some functions and some partial args.
4747
4848
If the function is called while it is being resolved recursively it will raise an exception.
4949
"""
50-
return cls(Unresolved(fn, args))
50+
return cls(Unresolved(fn, args, context))
5151

5252
@classmethod
5353
def value(cls, value: T) -> Thunk[T]:
@@ -57,12 +57,12 @@ def __call__(self) -> T:
5757
match self.state:
5858
case Resolved(value):
5959
return value
60-
case Unresolved(fn, args):
60+
case Unresolved(fn, args, context):
6161
self.state = Resolving()
6262
try:
6363
res = fn(*args)
6464
except Exception as e:
65-
self.state = Error(e)
65+
self.state = Error(e, context)
6666
raise e from None
6767
else:
6868
self.state = Resolved(res)
@@ -83,6 +83,7 @@ class Resolved(Generic[T]):
8383
class Unresolved(Generic[T, Unpack[TS]]):
8484
fn: Callable[[Unpack[TS]], T]
8585
args: tuple[Unpack[TS]]
86+
context: str | None
8687

8788

8889
@dataclass
@@ -93,3 +94,4 @@ class Resolving:
9394
@dataclass
9495
class Error:
9596
e: Exception
97+
context: str | None

python/egglog/type_constraint_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def substitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None = None) ->
107107
try:
108108
return self._cls_typevar_index_to_type[cls_name][tp]
109109
except KeyError as e:
110-
raise TypeConstraintError(f"Not enough bound typevars for {tp} in class {cls_name}") from e
110+
raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_name}") from e
111111
case TypeRefWithVars(name, args):
112112
return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_name) for arg in args))
113113
assert_never(tp)

python/egglog/version_compat.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import collections
2+
import sys
3+
import types
4+
import typing
5+
6+
BEFORE_3_11 = sys.version_info < (3, 11)
7+
8+
__all__ = ["add_note"]
9+
10+
11+
def add_note(message: str, exc: BaseException) -> BaseException:
12+
"""
13+
Backwards compatible add_note for Python <= 3.10
14+
"""
15+
if BEFORE_3_11:
16+
return exc
17+
exc.add_note(message)
18+
return exc
19+
20+
21+
# For Python version 3.10 need to monkeypatch this function so that RuntimeClass type parameters
22+
# will be collected as typevars
23+
if BEFORE_3_11:
24+
25+
@typing.no_type_check
26+
def _collect_type_vars_monkeypatch(types_, typevar_types=None):
27+
"""
28+
Collect all type variable contained
29+
in types in order of first appearance (lexicographic order). For example::
30+
31+
_collect_type_vars((T, List[S, T])) == (T, S)
32+
"""
33+
from .runtime import RuntimeClass
34+
35+
if typevar_types is None:
36+
typevar_types = typing.TypeVar
37+
tvars = []
38+
for t in types_:
39+
if isinstance(t, typevar_types) and t not in tvars:
40+
tvars.append(t)
41+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
42+
if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)): # type: ignore[name-defined]
43+
tvars.extend([t for t in t.__parameters__ if t not in tvars])
44+
return tuple(tvars)
45+
46+
typing._collect_type_vars = _collect_type_vars_monkeypatch # type: ignore[attr-defined]
47+
48+
@typing.no_type_check
49+
@typing._tp_cache
50+
def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
51+
from .runtime import RuntimeClass
52+
53+
if self.__origin__ in (typing.Generic, typing.Protocol):
54+
# Can't subscript Generic[...] or Protocol[...].
55+
raise TypeError(f"Cannot subscript already-subscripted {self}")
56+
if not isinstance(params, tuple):
57+
params = (params,)
58+
params = tuple(typing._type_convert(p) for p in params)
59+
if self._paramspec_tvars and any(isinstance(t, typing.ParamSpec) for t in self.__parameters__):
60+
params = typing._prepare_paramspec_params(self, params)
61+
else:
62+
typing._check_generic(self, params, len(self.__parameters__))
63+
64+
subst = dict(zip(self.__parameters__, params, strict=False))
65+
new_args = []
66+
for arg in self.__args__:
67+
if isinstance(arg, self._typevar_types):
68+
if isinstance(arg, typing.ParamSpec):
69+
arg = subst[arg] # noqa: PLW2901
70+
if not typing._is_param_expr(arg):
71+
raise TypeError(f"Expected a list of types, an ellipsis, ParamSpec, or Concatenate. Got {arg}")
72+
else:
73+
arg = subst[arg] # noqa: PLW2901
74+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
75+
elif isinstance(arg, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
76+
subparams = arg.__parameters__
77+
if subparams:
78+
subargs = tuple(subst[x] for x in subparams)
79+
arg = arg[subargs] # noqa: PLW2901
80+
# Required to flatten out the args for CallableGenericAlias
81+
if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple):
82+
new_args.extend(arg)
83+
else:
84+
new_args.append(arg)
85+
return self.copy_with(tuple(new_args))
86+
87+
typing._GenericAlias.__getitem__ = __getitem__monkeypatch # type: ignore[attr-defined]

0 commit comments

Comments
 (0)