Skip to content

Commit 877cff6

Browse files
committed
Support nested tail calls
1 parent e6ce8bb commit 877cff6

File tree

2 files changed

+475
-124
lines changed

2 files changed

+475
-124
lines changed

tail_recursive/__init__.py

Lines changed: 252 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,256 @@
1-
"""Simple package for your tail recursion needs."""
1+
"""Simple package for your tail recursion needs.
22
3+
Use the ``tail_recursive`` decorator to define tail_recursive functions.
4+
5+
Example::
6+
7+
import sys
8+
from tail_recursive import tail_recursive
9+
10+
11+
@tail_recursive
12+
def mul(a, b):
13+
return a * b
14+
15+
@tail_recursive
16+
def factorial(n):
17+
if n == 1:
18+
return n
19+
# Nested tail calls are supported by default.
20+
return mul.tail_call(n, factorial.tail_call(n - 1))
21+
22+
23+
# Calls to tail recursive functions will not exceed the maximum recursion
24+
# depth, because functions are called sequentially.
25+
factorial(sys.getrecursionlimit() + 1)
26+
"""
27+
28+
import abc
329
from dataclasses import dataclass
4-
from functools import wraps
5-
from itertools import chain
6-
from typing import Any, Callable, Dict, Tuple
30+
import enum
31+
import functools
32+
import itertools
33+
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
34+
35+
36+
@dataclass
37+
class _FuncStore:
38+
func: Callable[..., Any]
39+
40+
@property
41+
def _func_repr(self) -> str:
42+
return repr(self.func)
43+
44+
45+
@dataclass
46+
class _ArgsAndKwargsStore:
47+
args: List[Any]
48+
kwargs: Dict[str, Any]
49+
50+
@property
51+
def _args_and_kwargs_string(self) -> str:
52+
return ', '.join(itertools.chain(
53+
(repr(arg) for arg in self.args),
54+
(f"{name}={repr(val)}" for name, val in self.kwargs.items())
55+
))
756

857

958
@dataclass(init=False)
10-
class tail_recursive:
59+
class _IndexedArgsAndKwargsAccess:
60+
length: int
61+
_accessing_object: _ArgsAndKwargsStore
62+
_last_arg_index: int
63+
_kwargs_index_key_map: Dict[int, str]
64+
65+
def __init__(self, _accessing_object: _ArgsAndKwargsStore):
66+
self._accessing_object = _accessing_object
67+
self._last_arg_index = len(self._accessing_object.args) - 1
68+
self.length = (
69+
self._last_arg_index
70+
+ len(self._accessing_object.kwargs) + 1
71+
)
72+
self._kwargs_index_key_map = {
73+
index: key for index, key in zip(
74+
range(self._last_arg_index + 1, self.length),
75+
self._accessing_object.kwargs.keys()
76+
)
77+
}
78+
79+
def get(self, index: int) -> Any:
80+
if index > self._last_arg_index:
81+
return self._accessing_object.kwargs[self._kwargs_index_key_map[index]]
82+
return self._accessing_object.args[index]
83+
84+
def set(self, index: int, val: Any) -> None:
85+
if index > self._last_arg_index:
86+
self._accessing_object.kwargs[self._kwargs_index_key_map[index]] = val
87+
else:
88+
self._accessing_object.args[index] = val
89+
90+
91+
# Mypy doesn't currently allow abstract dataclasses (see https://github.com/python/mypy/issues/5374).
92+
@dataclass # type: ignore
93+
class TailCall(abc.ABC, _FuncStore, _ArgsAndKwargsStore):
94+
"""Stores information necessary to lazily execute a function in the future."""
95+
96+
def __repr__(self) -> str:
97+
return f"{tail_recursive(self.func)}.tail_call({self._args_and_kwargs_string})"
98+
99+
@ abc.abstractmethod
100+
def resolve(self):
101+
"""Lazily and sequentially evaluates recursive tail calls while maintaining same size of callstack."""
102+
...
103+
104+
105+
class TailCallWithoutNestedCallResolution(TailCall):
106+
107+
def resolve(self):
108+
resolution_value = self.func(*self.args, **self.kwargs)
109+
while isinstance(resolution_value, TailCall):
110+
resolution_value = self.func(
111+
*resolution_value.args,
112+
**resolution_value.kwargs
113+
)
114+
return resolution_value
115+
116+
117+
@dataclass(init=False)
118+
class _TailCallStackItem:
119+
tail_call: TailCall
120+
indexed_args_and_kwargs: _IndexedArgsAndKwargsAccess
121+
resolving_arg_or_kwarg_with_index: Optional[int]
122+
123+
def __init__(self, tail_call: TailCall):
124+
self.tail_call = tail_call
125+
self.indexed_args_and_kwargs = _IndexedArgsAndKwargsAccess(tail_call)
126+
self.resolving_arg_or_kwarg_with_index = None
127+
128+
129+
@dataclass(init=False)
130+
class _TailCallStack:
131+
stack: List[_TailCallStackItem]
132+
length: int
133+
134+
def __init__(self, initial_item: TailCall):
135+
self.stack = [_TailCallStackItem(initial_item)]
136+
self.length = 1
137+
138+
@property
139+
def last_item(self):
140+
return self.stack[-1]
141+
142+
def push(self, item: TailCall):
143+
self.stack.append(_TailCallStackItem(item))
144+
self.length += 1
145+
146+
def pop_item_resolution(self):
147+
tail_call_with_fully_resolved_args_and_kwargs = self.stack.pop().tail_call
148+
self.length -= 1
149+
return tail_call_with_fully_resolved_args_and_kwargs.func(
150+
*tail_call_with_fully_resolved_args_and_kwargs.args,
151+
**tail_call_with_fully_resolved_args_and_kwargs.kwargs
152+
)
153+
154+
def set_arg_or_kwarg_of_last_item_to_resolution(self, resolution: Any):
155+
self.last_item.indexed_args_and_kwargs.set(
156+
self.last_item.resolving_arg_or_kwarg_with_index,
157+
resolution
158+
)
159+
160+
161+
class TailCallWithNestedCallResolution(TailCall):
162+
163+
def __init__(self, func: Callable[..., Any], args: List[Any], kwargs: Dict[str, Any]):
164+
# ``setattr`` stops mypy complaining.
165+
# Seems to be related to this issue https://github.com/python/mypy/issues/2427.
166+
setattr(self, "func", func)
167+
self.args = args
168+
self.kwargs = kwargs
169+
170+
def resolve(self):
171+
tail_call_stack = _TailCallStack(initial_item=self)
172+
while True:
173+
if tail_call_stack.last_item.resolving_arg_or_kwarg_with_index is None:
174+
start_arg_index = 0
175+
else:
176+
start_arg_index = tail_call_stack.last_item.resolving_arg_or_kwarg_with_index + 1
177+
for arg_index in range(start_arg_index, tail_call_stack.last_item.indexed_args_and_kwargs.length):
178+
arg = tail_call_stack.last_item.indexed_args_and_kwargs.get(
179+
arg_index
180+
)
181+
if isinstance(arg, TailCall):
182+
tail_call_stack.last_item.resolving_arg_or_kwarg_with_index = arg_index
183+
tail_call_stack.push(arg)
184+
break
185+
# Else block is evaluated if loop is not broken out of.
186+
else:
187+
resolution = tail_call_stack.pop_item_resolution()
188+
if isinstance(resolution, TailCall):
189+
tail_call_stack.push(resolution)
190+
elif tail_call_stack.length > 0:
191+
tail_call_stack.set_arg_or_kwarg_of_last_item_to_resolution(
192+
resolution
193+
)
194+
else:
195+
return resolution
196+
197+
198+
@enum.unique
199+
class NestedCallMode(enum.Enum):
200+
"""Different ways of resolving nested tail calls."""
201+
202+
DO_NOT_RESOLVE_NESTED_CALLS: str = "do_not_resolve_nested_calls"
203+
RESOLVE_NESTED_CALLS: str = "resolve_nested_calls"
204+
205+
206+
NESTED_CALL_MODE_TAILCALL_SUBCLASS_MAP: Dict[NestedCallMode, Type[TailCall]] = {
207+
NestedCallMode.DO_NOT_RESOLVE_NESTED_CALLS: TailCallWithoutNestedCallResolution,
208+
NestedCallMode.RESOLVE_NESTED_CALLS: TailCallWithNestedCallResolution,
209+
}
210+
211+
212+
@dataclass(init=False)
213+
class TailCallable(_FuncStore):
214+
215+
nested_call_mode: NestedCallMode
216+
217+
def __init__(self, func: Callable[..., Any], *, nested_call_mode: Union[NestedCallMode, str] = NestedCallMode.RESOLVE_NESTED_CALLS):
218+
functools.update_wrapper(self, func)
219+
# ``setattr`` stops mypy complaining.
220+
# Seems to be related to this issue https://github.com/python/mypy/issues/2427.
221+
setattr(self, "func", func)
222+
if isinstance(nested_call_mode, NestedCallMode):
223+
self.nested_call_mode = nested_call_mode
224+
else:
225+
self.nested_call_mode = NestedCallMode(nested_call_mode)
226+
227+
def __repr__(self) -> str:
228+
return f"{tail_recursive.__qualname__}(func={self._func_repr})"
229+
230+
def __call__(self, *args, **kwargs) -> Any:
231+
return self.tail_call(*args, **kwargs).resolve()
232+
233+
def tail_call(self, *args, **kwargs) -> TailCall:
234+
"""Passes arguments to a tail recursive function so that it may lazily called.
235+
236+
This method should be called as the single return value of a function. This
237+
enables the function to be called once the after its caller function has been
238+
garbage collected.
239+
240+
Example::
241+
242+
def f():
243+
return tail_recursive_function.tail_call(...)
244+
"""
245+
return NESTED_CALL_MODE_TAILCALL_SUBCLASS_MAP[self.nested_call_mode](func=self.func, args=list(args), kwargs=kwargs)
246+
247+
248+
def tail_recursive(_func=None, *, nested_call_mode=NestedCallMode.RESOLVE_NESTED_CALLS):
11249
"""A decorator that gives your functions the ability to be tail recursive.
12250
251+
Args:
252+
nested_call_mode: Defines the way in which nested calls are resolved.
253+
13254
Example::
14255
15256
# Pick a larger value if n is below your system's recursion limit.
@@ -39,63 +280,10 @@ def factorial(n, accumulator=1):
39280
Methods:
40281
tail_call(*args, **kwargs)
41282
"""
283+
def decorator(func):
284+
return TailCallable(func, nested_call_mode=nested_call_mode)
42285

43-
func: Callable[..., Any]
44-
args: Tuple[Any, ...]
45-
kwargs: Dict[str, Any]
46-
has_been_tail_called: bool
47-
48-
def __init__(self, func: Callable[..., Any]):
49-
"""Assigns the ``func`` attribute to the decorated function."""
50-
self.func = func # type: ignore
51-
self.has_been_tail_called = False
52-
53-
def __repr__(self) -> str:
54-
class_string: str = type(self).__qualname__
55-
func_string: str = repr(self.func)
56-
object_string: str = f"{class_string}(func={func_string})"
57-
if self.has_been_tail_called:
58-
args_string = ', '.join(chain(
59-
(repr(arg) for arg in self.args),
60-
(f"{name}={repr(val)}" for name, val in self.kwargs.items())
61-
))
62-
return f"{object_string}.tail_call({args_string})"
63-
return object_string
64-
65-
def __call__(self, *args, **kwargs) -> Any:
66-
@wraps(self.func)
67-
def wrapper(*args, **kwargs) -> Any:
68-
# If ``return_value`` is an instance of ``tail_recursive`` then
69-
# ``return_value`` will be reassigned to the return value of the
70-
# function stored as ``func`` called with the arguments set in the
71-
# call to tail_recursion.
72-
return_value: tail_recursive = self.tail_call(*args, **kwargs)
73-
while isinstance(
74-
(return_value := return_value.func(
75-
*return_value.args,
76-
**return_value.kwargs
77-
)),
78-
type(self)
79-
):
80-
pass
81-
# Once ``return_value`` is no longer an instance of ``tail_recursive``, it
82-
# is returned.
83-
return return_value
84-
return wrapper(*args, **kwargs)
85-
86-
def tail_call(self, *args, **kwargs) -> 'tail_recursive':
87-
"""Passes arguments to a tail recursive function so that it may lazily called.
88-
89-
This method should be called as the single return value of a function. This
90-
enables the function to be called once the after its caller function has been
91-
garbage collected.
92-
93-
Example::
94-
95-
def f():
96-
return tail_recursive_function.tail_call(...)
97-
"""
98-
self.args = args
99-
self.kwargs = kwargs
100-
self.has_been_tail_called = True
101-
return self
286+
if _func is None:
287+
return decorator
288+
else:
289+
return decorator(_func)

0 commit comments

Comments
 (0)