@@ -26,9 +26,7 @@ def factorial(n):
26
26
import enum
27
27
import functools
28
28
import itertools
29
- import types
30
- from typing import Any , Callable , Dict , Iterable , List , Optional , Type , Union
31
-
29
+ from typing import Any , Callable , Dict , List , Optional , Type , Union
32
30
33
31
# Dunder methods from https://docs.python.org/3/reference/datamodel.html.
34
32
_NUMERIC_DUNDER_METH_BASE_NAMES : List [str ] = [
@@ -48,74 +46,83 @@ def factorial(n):
48
46
"or" ,
49
47
]
50
48
51
- _NUMERIC_DUNDER_METH_NAMES : List [str ] = _NUMERIC_DUNDER_METH_BASE_NAMES + [
49
+ _NUMERIC_RIGHT_DUNDER_METH_NAMES : List [str ] = [
52
50
f"r{ name } " for name in _NUMERIC_DUNDER_METH_BASE_NAMES
53
- ] + [
54
- f"i{ name } " for name in _NUMERIC_DUNDER_METH_BASE_NAMES
55
- ] + [
56
- "neg" ,
57
- "pos" ,
58
- "abs" ,
59
- "invert" ,
60
- "complex" ,
61
- "int" ,
62
- "float" ,
63
- "index" ,
64
- "round" ,
65
- "trunc" ,
66
- "floor" ,
67
- "ciel"
68
51
]
52
+ _NUMERIC_RIGHT_DUNDER_METH_NAMES_SET_WITH_UNDERSCORES = {
53
+ f"__{ meth_name } __" for meth_name in _NUMERIC_RIGHT_DUNDER_METH_NAMES
54
+ }
69
55
70
- _DUNDER_METH_NAMES : List [str ] = [
71
- # Cannot be overridden because they will break functionality:
72
- # "new",
73
- # "init",
74
- # "del",
75
- # "getattribute",
76
- # "setattr",
77
- # "get",
78
- # "set",
79
- # "delete",
80
- # "set_name",
81
- # "init_subclass",
82
- # "prepare",
83
- #
84
- # getattr and delattr have custom overrides (see below).
85
- "repr" ,
86
- "str" ,
87
- "bytes" ,
88
- "format" ,
89
- "lt" ,
90
- "le" ,
91
- "eq" ,
92
- "ne" ,
93
- "gt" ,
94
- "ge" ,
95
- "hash" ,
96
- "bool" ,
97
- "dir" ,
98
- "instancecheck" ,
99
- "subclasscheck" ,
100
- "class_getitem" ,
101
- "call" ,
102
- "len" ,
103
- "length_hint" ,
104
- "getitem" ,
105
- "setitem" ,
106
- "delitem" ,
107
- "missing" ,
108
- "iter" ,
109
- "reversed" ,
110
- "contains" ,
111
- "enter" ,
112
- "exit" ,
113
- "await" ,
114
- "aiter" ,
115
- "anext" ,
116
- "aenter" ,
117
- "aexit" ,
118
- ] + _NUMERIC_DUNDER_METH_NAMES
56
+ _NUMERIC_DUNDER_METH_NAMES : List [str ] = \
57
+ _NUMERIC_DUNDER_METH_BASE_NAMES \
58
+ + _NUMERIC_RIGHT_DUNDER_METH_NAMES \
59
+ + [f"i{ name } " for name in _NUMERIC_DUNDER_METH_BASE_NAMES ] \
60
+ + [
61
+ "neg" ,
62
+ "pos" ,
63
+ "abs" ,
64
+ "invert" ,
65
+ "complex" ,
66
+ "int" ,
67
+ "float" ,
68
+ "index" ,
69
+ "round" ,
70
+ "trunc" ,
71
+ "floor" ,
72
+ "ciel"
73
+ ]
74
+
75
+ _DUNDER_METH_NAMES : List [str ] = \
76
+ [
77
+ # Cannot be overridden because they will break functionality:
78
+ # "new",
79
+ # "init",
80
+ # "del",
81
+ # "getattribute",
82
+ # "setattr",
83
+ # "get",
84
+ # "set",
85
+ # "delete",
86
+ # "set_name",
87
+ # "init_subclass",
88
+ # "prepare",
89
+ #
90
+ # getattr and delattr have custom overrides (see below).
91
+ "repr" ,
92
+ "str" ,
93
+ "bytes" ,
94
+ "format" ,
95
+ "lt" ,
96
+ "le" ,
97
+ "eq" ,
98
+ "ne" ,
99
+ "gt" ,
100
+ "ge" ,
101
+ "hash" ,
102
+ "bool" ,
103
+ "dir" ,
104
+ "instancecheck" ,
105
+ "subclasscheck" ,
106
+ "class_getitem" ,
107
+ "call" ,
108
+ "len" ,
109
+ "length_hint" ,
110
+ "getitem" ,
111
+ "setitem" ,
112
+ "delitem" ,
113
+ "missing" ,
114
+ "iter" ,
115
+ "reversed" ,
116
+ "contains" ,
117
+ "enter" ,
118
+ "exit" ,
119
+ "await" ,
120
+ "aiter" ,
121
+ "anext" ,
122
+ "aenter" ,
123
+ "aexit" ,
124
+ ] \
125
+ + _NUMERIC_DUNDER_METH_NAMES
119
126
120
127
121
128
@dataclass
@@ -151,8 +158,8 @@ def __init__(self, _accessing_object: _ArgsAndKwargsStore):
151
158
self ._accessing_object = _accessing_object
152
159
self ._last_arg_index = len (self ._accessing_object ._args ) - 1
153
160
self .length = (
154
- self ._last_arg_index
155
- + len (self ._accessing_object ._kwargs ) + 1
161
+ self ._last_arg_index
162
+ + len (self ._accessing_object ._kwargs ) + 1
156
163
)
157
164
self ._kwargs_index_key_map = {
158
165
index : key for index , key in zip (
@@ -181,13 +188,13 @@ class TailCall(abc.ABC, _FuncStore, _ArgsAndKwargsStore):
181
188
def _to_string (self ) -> str :
182
189
return f"{ tail_recursive (self ._func )} .tail_call({ self ._args_and_kwargs_string } )"
183
190
184
- @ abc .abstractmethod
191
+ @abc .abstractmethod
185
192
def _resolve (self ):
186
193
"""Lazily and sequentially evaluates recursive tail calls while maintaining same size of callstack."""
187
194
...
188
195
189
196
190
- class TailCallWithoutNestedCallResolutionAndDunderOverloads (TailCall ):
197
+ class TailCallBase (TailCall ):
191
198
192
199
def _resolve (self ):
193
200
resolution_value = self ._func (* self ._args , ** self ._kwargs )
@@ -199,7 +206,7 @@ def _resolve(self):
199
206
return resolution_value
200
207
201
208
202
- @ dataclass (init = False )
209
+ @dataclass (init = False )
203
210
class _TailCallStackItem :
204
211
tail_call : TailCall
205
212
indexed_args_and_kwargs : _IndexedArgsAndKwargsAccess
@@ -211,7 +218,7 @@ def __init__(self, tail_call: TailCall):
211
218
self .resolving_arg_or_kwarg_with_index = None
212
219
213
220
214
- @ dataclass (init = False )
221
+ @dataclass (init = False )
215
222
class _TailCallStack :
216
223
stack : List [_TailCallStackItem ]
217
224
length : int
@@ -220,7 +227,7 @@ def __init__(self, initial_item: TailCall):
220
227
self .stack = [_TailCallStackItem (initial_item )]
221
228
self .length = 1
222
229
223
- @ property
230
+ @property
224
231
def last_item (self ):
225
232
return self .stack [- 1 ]
226
233
@@ -244,16 +251,57 @@ def set_arg_or_kwarg_of_last_item_to_resolution(self, resolution: Any):
244
251
245
252
246
253
@dataclass
247
- class TailCallWithNestedCallResolutionAndDunderOverloads (TailCall ):
254
+ class TailCallWithNestedCallResolution (TailCall ):
255
+
256
+ def _resolve (self ):
257
+ tail_call_stack = _TailCallStack (initial_item = self )
258
+ while True :
259
+ if tail_call_stack .last_item .resolving_arg_or_kwarg_with_index is None :
260
+ start_arg_index = 0
261
+ else :
262
+ start_arg_index = tail_call_stack .last_item .resolving_arg_or_kwarg_with_index + 1
263
+ for arg_index in range (start_arg_index , tail_call_stack .last_item .indexed_args_and_kwargs .length ):
264
+ arg = tail_call_stack .last_item .indexed_args_and_kwargs .get (
265
+ arg_index
266
+ )
267
+ if isinstance (arg , TailCall ):
268
+ tail_call_stack .last_item .resolving_arg_or_kwarg_with_index = arg_index
269
+ tail_call_stack .push (arg )
270
+ break
271
+ # Else block is evaluated if loop is not broken out of.
272
+ else :
273
+ resolution = tail_call_stack .pop_item_resolution ()
274
+ if isinstance (resolution , TailCall ):
275
+ tail_call_stack .push (resolution )
276
+ elif tail_call_stack .length > 0 :
277
+ tail_call_stack .set_arg_or_kwarg_of_last_item_to_resolution (
278
+ resolution
279
+ )
280
+ else :
281
+ return resolution
282
+
283
+
284
+ class TailCallWithDunderOverloads (TailCallBase ):
248
285
249
286
@staticmethod
250
287
def _tail_call_dunder_meth_factory (dunder_meth_name : str ):
251
288
289
+ # If <self>.__r<operation>__(other) does not exist, try <other>.__<operation>__(self)
290
+ if dunder_meth_name in _NUMERIC_RIGHT_DUNDER_METH_NAMES_SET_WITH_UNDERSCORES :
291
+ def func (self , other , * args , ** kwargs ) -> Any :
292
+ try :
293
+ return getattr (self , dunder_meth_name )(other , * args , ** kwargs )
294
+ except AttributeError :
295
+ return getattr (other , f"__{ dunder_meth_name [3 :]} " )(self , * args , ** kwargs )
296
+ else :
297
+ # Ignore differing parameter signatures
298
+ def func (self , * args , ** kwargs ) -> Any : # type: ignore[misc]
299
+ return getattr (self , dunder_meth_name )(* args , ** kwargs )
300
+
252
301
def dunder_meth (self , * args , ** kwargs ):
253
302
tail_call_class = type (self )
254
303
return tail_call_class (
255
- _func = lambda self , * args , ** kwargs :
256
- getattr (self , dunder_meth_name )(* args , ** kwargs ),
304
+ _func = func ,
257
305
_args = [self ] + list (args ),
258
306
_kwargs = kwargs
259
307
)
@@ -288,51 +336,32 @@ def __delattr__(self, name):
288
336
_args = [self , name ], _kwargs = {}
289
337
)
290
338
291
- def _resolve (self ):
292
- tail_call_stack = _TailCallStack (initial_item = self )
293
- while True :
294
- if tail_call_stack .last_item .resolving_arg_or_kwarg_with_index is None :
295
- start_arg_index = 0
296
- else :
297
- start_arg_index = tail_call_stack .last_item .resolving_arg_or_kwarg_with_index + 1
298
- for arg_index in range (start_arg_index , tail_call_stack .last_item .indexed_args_and_kwargs .length ):
299
- arg = tail_call_stack .last_item .indexed_args_and_kwargs .get (
300
- arg_index
301
- )
302
- if isinstance (arg , TailCall ):
303
- tail_call_stack .last_item .resolving_arg_or_kwarg_with_index = arg_index
304
- tail_call_stack .push (arg )
305
- break
306
- # Else block is evaluated if loop is not broken out of.
307
- else :
308
- resolution = tail_call_stack .pop_item_resolution ()
309
- if isinstance (resolution , TailCall ):
310
- tail_call_stack .push (resolution )
311
- elif tail_call_stack .length > 0 :
312
- tail_call_stack .set_arg_or_kwarg_of_last_item_to_resolution (
313
- resolution
314
- )
315
- else :
316
- return resolution
317
339
340
+ class TailCallWithNestedCallResolutionAndDunderOverloads (TailCallWithNestedCallResolution , TailCallWithDunderOverloads ):
341
+ pass
318
342
319
- @ enum .unique
320
- class FeatureSet (enum .Enum ):
343
+
344
+ @enum .unique
345
+ class FeatureSet (enum .IntFlag ):
321
346
"""Different ways of resolving nested tail calls."""
322
347
323
- BASE : str = "base"
324
- FULL : str = "full"
348
+ BASE = 0
349
+ NESTED_CALLS = 1
350
+ OVERLOADING = 2
351
+ FULL = NESTED_CALLS | OVERLOADING
325
352
326
353
327
354
FEATURE_SET_TAILCALL_SUBCLASS_MAP : Dict [FeatureSet , Type [TailCall ]] = {
328
- FeatureSet .BASE : TailCallWithoutNestedCallResolutionAndDunderOverloads ,
355
+ FeatureSet .BASE : TailCallBase ,
356
+ FeatureSet .NESTED_CALLS : TailCallWithNestedCallResolution ,
357
+ FeatureSet .OVERLOADING : TailCallWithDunderOverloads ,
358
+ FeatureSet .NESTED_CALLS | FeatureSet .OVERLOADING : TailCallWithNestedCallResolutionAndDunderOverloads ,
329
359
FeatureSet .FULL : TailCallWithNestedCallResolutionAndDunderOverloads ,
330
360
}
331
361
332
362
333
- @ dataclass (init = False )
363
+ @dataclass (init = False )
334
364
class TailCallable (_FuncStore ):
335
-
336
365
feature_set : FeatureSet
337
366
338
367
def __init__ (self , _func : Callable [..., Any ], * , feature_set : Union [FeatureSet , str ] = FeatureSet .FULL ):
@@ -343,7 +372,10 @@ def __init__(self, _func: Callable[..., Any], *, feature_set: Union[FeatureSet,
343
372
if isinstance (feature_set , FeatureSet ):
344
373
self .feature_set = feature_set
345
374
else :
346
- self .feature_set = FeatureSet (feature_set )
375
+ try :
376
+ self .feature_set = getattr (FeatureSet , feature_set .upper ())
377
+ except AttributeError as err :
378
+ raise ValueError (f"'{ feature_set } ' is not a valid FeatureSet" ) from err
347
379
348
380
def __repr__ (self ) -> str :
349
381
return f"{ tail_recursive .__qualname__ } (func={ self ._func_repr } )"
@@ -367,9 +399,9 @@ def f():
367
399
368
400
369
401
def tail_recursive (
370
- _func : Optional [Callable [..., Any ]] = None ,
371
- * ,
372
- feature_set : Union [FeatureSet , str ] = FeatureSet .FULL
402
+ _func : Optional [Callable [..., Any ]] = None ,
403
+ * ,
404
+ feature_set : Union [FeatureSet , str ] = FeatureSet .FULL
373
405
):
374
406
"""A decorator that gives your functions the ability to be tail recursive.
375
407
@@ -403,6 +435,7 @@ def factorial(n):
403
435
Methods:
404
436
tail_call(*args, **kwargs)
405
437
"""
438
+
406
439
def decorator (func ):
407
440
return TailCallable (func , feature_set = feature_set )
408
441
0 commit comments