Skip to content

Commit d62bc60

Browse files
author
oscar.butler
committed
Use flags for feature sets
1 parent 1298441 commit d62bc60

File tree

2 files changed

+243
-118
lines changed

2 files changed

+243
-118
lines changed

tail_recursive/__init__.py

Lines changed: 148 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ def factorial(n):
2626
import enum
2727
import functools
2828
import itertools
29-
import types
30-
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
31-
29+
from typing import Any, Callable, Dict, List, Optional, Type, Union
3230

3331
# Dunder methods from https://docs.python.org/3/reference/datamodel.html.
3432
_NUMERIC_DUNDER_METH_BASE_NAMES: List[str] = [
@@ -48,74 +46,83 @@ def factorial(n):
4846
"or",
4947
]
5048

51-
_NUMERIC_DUNDER_METH_NAMES: List[str] = _NUMERIC_DUNDER_METH_BASE_NAMES + [
49+
_NUMERIC_RIGHT_DUNDER_METH_NAMES: List[str] = [
5250
f"r{name}" for name in _NUMERIC_DUNDER_METH_BASE_NAMES
53-
] + [
54-
f"i{name}" for name in _NUMERIC_DUNDER_METH_BASE_NAMES
55-
] + [
56-
"neg",
57-
"pos",
58-
"abs",
59-
"invert",
60-
"complex",
61-
"int",
62-
"float",
63-
"index",
64-
"round",
65-
"trunc",
66-
"floor",
67-
"ciel"
6851
]
52+
_NUMERIC_RIGHT_DUNDER_METH_NAMES_SET_WITH_UNDERSCORES = {
53+
f"__{meth_name}__" for meth_name in _NUMERIC_RIGHT_DUNDER_METH_NAMES
54+
}
6955

70-
_DUNDER_METH_NAMES: List[str] = [
71-
# Cannot be overridden because they will break functionality:
72-
# "new",
73-
# "init",
74-
# "del",
75-
# "getattribute",
76-
# "setattr",
77-
# "get",
78-
# "set",
79-
# "delete",
80-
# "set_name",
81-
# "init_subclass",
82-
# "prepare",
83-
#
84-
# getattr and delattr have custom overrides (see below).
85-
"repr",
86-
"str",
87-
"bytes",
88-
"format",
89-
"lt",
90-
"le",
91-
"eq",
92-
"ne",
93-
"gt",
94-
"ge",
95-
"hash",
96-
"bool",
97-
"dir",
98-
"instancecheck",
99-
"subclasscheck",
100-
"class_getitem",
101-
"call",
102-
"len",
103-
"length_hint",
104-
"getitem",
105-
"setitem",
106-
"delitem",
107-
"missing",
108-
"iter",
109-
"reversed",
110-
"contains",
111-
"enter",
112-
"exit",
113-
"await",
114-
"aiter",
115-
"anext",
116-
"aenter",
117-
"aexit",
118-
] + _NUMERIC_DUNDER_METH_NAMES
56+
_NUMERIC_DUNDER_METH_NAMES: List[str] = \
57+
_NUMERIC_DUNDER_METH_BASE_NAMES \
58+
+ _NUMERIC_RIGHT_DUNDER_METH_NAMES \
59+
+ [f"i{name}" for name in _NUMERIC_DUNDER_METH_BASE_NAMES] \
60+
+ [
61+
"neg",
62+
"pos",
63+
"abs",
64+
"invert",
65+
"complex",
66+
"int",
67+
"float",
68+
"index",
69+
"round",
70+
"trunc",
71+
"floor",
72+
"ciel"
73+
]
74+
75+
_DUNDER_METH_NAMES: List[str] = \
76+
[
77+
# Cannot be overridden because they will break functionality:
78+
# "new",
79+
# "init",
80+
# "del",
81+
# "getattribute",
82+
# "setattr",
83+
# "get",
84+
# "set",
85+
# "delete",
86+
# "set_name",
87+
# "init_subclass",
88+
# "prepare",
89+
#
90+
# getattr and delattr have custom overrides (see below).
91+
"repr",
92+
"str",
93+
"bytes",
94+
"format",
95+
"lt",
96+
"le",
97+
"eq",
98+
"ne",
99+
"gt",
100+
"ge",
101+
"hash",
102+
"bool",
103+
"dir",
104+
"instancecheck",
105+
"subclasscheck",
106+
"class_getitem",
107+
"call",
108+
"len",
109+
"length_hint",
110+
"getitem",
111+
"setitem",
112+
"delitem",
113+
"missing",
114+
"iter",
115+
"reversed",
116+
"contains",
117+
"enter",
118+
"exit",
119+
"await",
120+
"aiter",
121+
"anext",
122+
"aenter",
123+
"aexit",
124+
] \
125+
+ _NUMERIC_DUNDER_METH_NAMES
119126

120127

121128
@dataclass
@@ -151,8 +158,8 @@ def __init__(self, _accessing_object: _ArgsAndKwargsStore):
151158
self._accessing_object = _accessing_object
152159
self._last_arg_index = len(self._accessing_object._args) - 1
153160
self.length = (
154-
self._last_arg_index
155-
+ len(self._accessing_object._kwargs) + 1
161+
self._last_arg_index
162+
+ len(self._accessing_object._kwargs) + 1
156163
)
157164
self._kwargs_index_key_map = {
158165
index: key for index, key in zip(
@@ -181,13 +188,13 @@ class TailCall(abc.ABC, _FuncStore, _ArgsAndKwargsStore):
181188
def _to_string(self) -> str:
182189
return f"{tail_recursive(self._func)}.tail_call({self._args_and_kwargs_string})"
183190

184-
@ abc.abstractmethod
191+
@abc.abstractmethod
185192
def _resolve(self):
186193
"""Lazily and sequentially evaluates recursive tail calls while maintaining same size of callstack."""
187194
...
188195

189196

190-
class TailCallWithoutNestedCallResolutionAndDunderOverloads(TailCall):
197+
class TailCallBase(TailCall):
191198

192199
def _resolve(self):
193200
resolution_value = self._func(*self._args, **self._kwargs)
@@ -199,7 +206,7 @@ def _resolve(self):
199206
return resolution_value
200207

201208

202-
@ dataclass(init=False)
209+
@dataclass(init=False)
203210
class _TailCallStackItem:
204211
tail_call: TailCall
205212
indexed_args_and_kwargs: _IndexedArgsAndKwargsAccess
@@ -211,7 +218,7 @@ def __init__(self, tail_call: TailCall):
211218
self.resolving_arg_or_kwarg_with_index = None
212219

213220

214-
@ dataclass(init=False)
221+
@dataclass(init=False)
215222
class _TailCallStack:
216223
stack: List[_TailCallStackItem]
217224
length: int
@@ -220,7 +227,7 @@ def __init__(self, initial_item: TailCall):
220227
self.stack = [_TailCallStackItem(initial_item)]
221228
self.length = 1
222229

223-
@ property
230+
@property
224231
def last_item(self):
225232
return self.stack[-1]
226233

@@ -244,16 +251,57 @@ def set_arg_or_kwarg_of_last_item_to_resolution(self, resolution: Any):
244251

245252

246253
@dataclass
247-
class TailCallWithNestedCallResolutionAndDunderOverloads(TailCall):
254+
class TailCallWithNestedCallResolution(TailCall):
255+
256+
def _resolve(self):
257+
tail_call_stack = _TailCallStack(initial_item=self)
258+
while True:
259+
if tail_call_stack.last_item.resolving_arg_or_kwarg_with_index is None:
260+
start_arg_index = 0
261+
else:
262+
start_arg_index = tail_call_stack.last_item.resolving_arg_or_kwarg_with_index + 1
263+
for arg_index in range(start_arg_index, tail_call_stack.last_item.indexed_args_and_kwargs.length):
264+
arg = tail_call_stack.last_item.indexed_args_and_kwargs.get(
265+
arg_index
266+
)
267+
if isinstance(arg, TailCall):
268+
tail_call_stack.last_item.resolving_arg_or_kwarg_with_index = arg_index
269+
tail_call_stack.push(arg)
270+
break
271+
# Else block is evaluated if loop is not broken out of.
272+
else:
273+
resolution = tail_call_stack.pop_item_resolution()
274+
if isinstance(resolution, TailCall):
275+
tail_call_stack.push(resolution)
276+
elif tail_call_stack.length > 0:
277+
tail_call_stack.set_arg_or_kwarg_of_last_item_to_resolution(
278+
resolution
279+
)
280+
else:
281+
return resolution
282+
283+
284+
class TailCallWithDunderOverloads(TailCallBase):
248285

249286
@staticmethod
250287
def _tail_call_dunder_meth_factory(dunder_meth_name: str):
251288

289+
# If <self>.__r<operation>__(other) does not exist, try <other>.__<operation>__(self)
290+
if dunder_meth_name in _NUMERIC_RIGHT_DUNDER_METH_NAMES_SET_WITH_UNDERSCORES:
291+
def func(self, other, *args, **kwargs) -> Any:
292+
try:
293+
return getattr(self, dunder_meth_name)(other, *args, **kwargs)
294+
except AttributeError:
295+
return getattr(other, f"__{dunder_meth_name[3:]}")(self, *args, **kwargs)
296+
else:
297+
# Ignore differing parameter signatures
298+
def func(self, *args, **kwargs) -> Any: # type: ignore[misc]
299+
return getattr(self, dunder_meth_name)(*args, **kwargs)
300+
252301
def dunder_meth(self, *args, **kwargs):
253302
tail_call_class = type(self)
254303
return tail_call_class(
255-
_func=lambda self, *args, **kwargs:
256-
getattr(self, dunder_meth_name)(*args, **kwargs),
304+
_func=func,
257305
_args=[self] + list(args),
258306
_kwargs=kwargs
259307
)
@@ -288,51 +336,32 @@ def __delattr__(self, name):
288336
_args=[self, name], _kwargs={}
289337
)
290338

291-
def _resolve(self):
292-
tail_call_stack = _TailCallStack(initial_item=self)
293-
while True:
294-
if tail_call_stack.last_item.resolving_arg_or_kwarg_with_index is None:
295-
start_arg_index = 0
296-
else:
297-
start_arg_index = tail_call_stack.last_item.resolving_arg_or_kwarg_with_index + 1
298-
for arg_index in range(start_arg_index, tail_call_stack.last_item.indexed_args_and_kwargs.length):
299-
arg = tail_call_stack.last_item.indexed_args_and_kwargs.get(
300-
arg_index
301-
)
302-
if isinstance(arg, TailCall):
303-
tail_call_stack.last_item.resolving_arg_or_kwarg_with_index = arg_index
304-
tail_call_stack.push(arg)
305-
break
306-
# Else block is evaluated if loop is not broken out of.
307-
else:
308-
resolution = tail_call_stack.pop_item_resolution()
309-
if isinstance(resolution, TailCall):
310-
tail_call_stack.push(resolution)
311-
elif tail_call_stack.length > 0:
312-
tail_call_stack.set_arg_or_kwarg_of_last_item_to_resolution(
313-
resolution
314-
)
315-
else:
316-
return resolution
317339

340+
class TailCallWithNestedCallResolutionAndDunderOverloads(TailCallWithNestedCallResolution, TailCallWithDunderOverloads):
341+
pass
318342

319-
@ enum.unique
320-
class FeatureSet(enum.Enum):
343+
344+
@enum.unique
345+
class FeatureSet(enum.IntFlag):
321346
"""Different ways of resolving nested tail calls."""
322347

323-
BASE: str = "base"
324-
FULL: str = "full"
348+
BASE = 0
349+
NESTED_CALLS = 1
350+
OVERLOADING = 2
351+
FULL = NESTED_CALLS | OVERLOADING
325352

326353

327354
FEATURE_SET_TAILCALL_SUBCLASS_MAP: Dict[FeatureSet, Type[TailCall]] = {
328-
FeatureSet.BASE: TailCallWithoutNestedCallResolutionAndDunderOverloads,
355+
FeatureSet.BASE: TailCallBase,
356+
FeatureSet.NESTED_CALLS: TailCallWithNestedCallResolution,
357+
FeatureSet.OVERLOADING: TailCallWithDunderOverloads,
358+
FeatureSet.NESTED_CALLS | FeatureSet.OVERLOADING: TailCallWithNestedCallResolutionAndDunderOverloads,
329359
FeatureSet.FULL: TailCallWithNestedCallResolutionAndDunderOverloads,
330360
}
331361

332362

333-
@ dataclass(init=False)
363+
@dataclass(init=False)
334364
class TailCallable(_FuncStore):
335-
336365
feature_set: FeatureSet
337366

338367
def __init__(self, _func: Callable[..., Any], *, feature_set: Union[FeatureSet, str] = FeatureSet.FULL):
@@ -343,7 +372,10 @@ def __init__(self, _func: Callable[..., Any], *, feature_set: Union[FeatureSet,
343372
if isinstance(feature_set, FeatureSet):
344373
self.feature_set = feature_set
345374
else:
346-
self.feature_set = FeatureSet(feature_set)
375+
try:
376+
self.feature_set = getattr(FeatureSet, feature_set.upper())
377+
except AttributeError as err:
378+
raise ValueError(f"'{feature_set}' is not a valid FeatureSet") from err
347379

348380
def __repr__(self) -> str:
349381
return f"{tail_recursive.__qualname__}(func={self._func_repr})"
@@ -367,9 +399,9 @@ def f():
367399

368400

369401
def tail_recursive(
370-
_func: Optional[Callable[..., Any]] = None,
371-
*,
372-
feature_set: Union[FeatureSet, str] = FeatureSet.FULL
402+
_func: Optional[Callable[..., Any]] = None,
403+
*,
404+
feature_set: Union[FeatureSet, str] = FeatureSet.FULL
373405
):
374406
"""A decorator that gives your functions the ability to be tail recursive.
375407
@@ -403,6 +435,7 @@ def factorial(n):
403435
Methods:
404436
tail_call(*args, **kwargs)
405437
"""
438+
406439
def decorator(func):
407440
return TailCallable(func, feature_set=feature_set)
408441

0 commit comments

Comments
 (0)