Skip to content

Commit 3e1dc82

Browse files
committed
Update documentation and tests relvant to documentation
1 parent 877cff6 commit 3e1dc82

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

README.md

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,91 @@ from the caller function (in this case also `factorial`). This means that the
4747
resources in the caller function's scope are free to be garbage collected and that its
4848
frame is popped from the call stack before we push the returned function on.
4949

50+
## Nested Calls
51+
52+
In the previous example the whole concept of an accumulator my not fit your mental model
53+
that well (it doesn't for me at least).
54+
Luckily calls to `tail_call` support nested calls (i.e. another `tail_call` passed as an
55+
argument).
56+
Taking this functionality into consideration we can refactor the previous example.
57+
58+
```python
59+
...
60+
61+
@tail_recursive
62+
def mul(a, b):
63+
return a * b
64+
65+
@tail_recursive
66+
def factorial(n):
67+
if n == 1:
68+
return n
69+
return mul.tail_call(n, factorial.tail_call(n - 1))
70+
71+
...
72+
```
73+
74+
This, however, comes a performance cost and can be disabled as follows.
75+
76+
```python
77+
@tail_recursive(nested_call_mode="do_not_resolve_nested_calls")
78+
def factorial(n, accumulator=1):
79+
if n == 1:
80+
return accumulator
81+
return factorial.tail_call(n - 1, n * accumulator)
82+
```
83+
84+
or
85+
86+
```python
87+
from tail_recursive import tail_recursive, NestedCallMode
88+
89+
...
90+
91+
@tail_recursive(nested_call_mode=NestedCallMode.DO_NOT_RESOLVE_NESTED_CALLS)
92+
def factorial(n, accumulator=1):
93+
...
94+
```
95+
96+
Similarly, use `nested_call_mode="resolve_nested_calls"` or `nested_call_mode=NestedCallMode.RESOLVE_NESTED_CALLS`
97+
to explicitly enable this feature.
98+
99+
## Current Limitations
100+
101+
### Return Values
102+
103+
Currently tail calls that are returned as an item in a tuple or other
104+
data structure are not evaluated.
105+
106+
The following will not evaluate the tail call.
107+
108+
```python
109+
from tail_recursive import tail_recursive
110+
111+
@tail_recursive
112+
def func(...):
113+
...
114+
return return_val1, func.tail_call(...)
115+
```
116+
117+
A workaround is to use factory functions.
118+
119+
```python
120+
from tail_recursive import tail_recursive
121+
122+
@tail_recursive
123+
def tuple_factory(*args):
124+
return tuple(args)
125+
126+
@tail_recursive
127+
def func(...):
128+
...
129+
return tuple_factory.tail_call(
130+
return_val1,
131+
func.tail_call(...)
132+
)
133+
```
134+
50135
## Other Packages
51136

52137
Check out [tco](https://github.com/baruchel/tco) for an alternative api with extra functionality.

tests.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,63 @@ def fibonacci(n):
273273

274274
n = sys.getrecursionlimit() + 1
275275
assert fibonacci(n) == non_recursive_fibonacci(n)
276+
277+
278+
def test_tail_call_as_part_for_datastructure_is_not_evaluated():
279+
280+
@tail_recursive
281+
def add(a, b):
282+
return a + b
283+
284+
@tail_recursive
285+
def getitem(obj, index):
286+
return obj[index]
287+
288+
@tail_recursive
289+
def square_and_triangular_numbers(n):
290+
square = n**2
291+
if n == 1:
292+
triangular_number = n
293+
else:
294+
triangular_number = add.tail_call(
295+
n,
296+
getitem.tail_call(
297+
square_and_triangular_numbers.tail_call(n - 1),
298+
1
299+
)
300+
)
301+
return square, triangular_number
302+
303+
assert square_and_triangular_numbers(3) != (9, 6)
304+
305+
306+
def test_tail_call_as_part_for_datastructure_with_factory_succeeds():
307+
308+
@tail_recursive
309+
def tuple_factory(*args):
310+
return tuple(args)
311+
312+
@tail_recursive
313+
def add(a, b):
314+
return a + b
315+
316+
@tail_recursive
317+
def getitem(obj, index):
318+
return obj[index]
319+
320+
@tail_recursive
321+
def square_and_triangular_numbers(n):
322+
square = n**2
323+
if n == 1:
324+
triangular_number = n
325+
else:
326+
triangular_number = add.tail_call(
327+
n,
328+
getitem.tail_call(
329+
square_and_triangular_numbers.tail_call(n - 1),
330+
1
331+
)
332+
)
333+
return tuple_factory.tail_call(square, triangular_number)
334+
335+
assert square_and_triangular_numbers(3) == (9, 6)

0 commit comments

Comments
 (0)