Skip to content

Commit 1517794

Browse files
authored
Add session reuse for pagination (#65)
1 parent 3f89081 commit 1517794

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

pyalex/api.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,12 @@ def __init__(
338338
self.value = value
339339
self.per_page = per_page
340340
self.n_max = n_max
341-
# The current number of results retrieved.
342341
self.n = 0
343342

344343
self._next_value = value
344+
self._session = _get_requests_session()
345345

346346
def __iter__(self):
347-
self.n = 0
348-
349347
return self
350348

351349
def _is_max(self):
@@ -358,13 +356,22 @@ def __next__(self):
358356
raise StopIteration
359357

360358
if self.method == "cursor":
361-
pagination_params = {"cursor": self._next_value}
359+
self.endpoint_class._add_params("cursor", self._next_value)
362360
elif self.method == "page":
363-
pagination_params = {"page": self._next_value}
361+
self.endpoint_class._add_params("page", self._next_value)
364362
else:
365-
raise ValueError()
363+
raise ValueError("Method should be 'cursor' or 'page'")
364+
365+
if self.per_page is not None and (
366+
not isinstance(self.per_page, int)
367+
or (self.per_page < 1 or self.per_page > 200)
368+
):
369+
raise ValueError("per_page should be a integer between 1 and 200")
370+
371+
if self.per_page is not None:
372+
self.endpoint_class._add_params("per-page", self.per_page)
366373

367-
r = self.endpoint_class.get(per_page=self.per_page, **pagination_params)
374+
r = self.endpoint_class._get_from_url(self.endpoint_class.url, self._session)
368375

369376
if self.method == "cursor":
370377
self._next_value = r.meta["next_cursor"]
@@ -501,8 +508,11 @@ def count(self):
501508
"""
502509
return self.get(per_page=1).meta["count"]
503510

504-
def _get_from_url(self, url):
505-
res = _get_requests_session().get(url, auth=OpenAlexAuth(config))
511+
def _get_from_url(self, url, session=None):
512+
if session is None:
513+
session = _get_requests_session()
514+
515+
res = session.get(url, auth=OpenAlexAuth(config))
506516

507517
if res.status_code == 403:
508518
if (
@@ -528,8 +538,10 @@ def _get_from_url(self, url):
528538
raise ValueError("Unknown response format")
529539

530540
def get(self, return_meta=False, page=None, per_page=None, cursor=None):
531-
if per_page is not None and (per_page < 1 or per_page > 200):
532-
raise ValueError("per_page should be a number between 1 and 200.")
541+
if per_page is not None and (
542+
not isinstance(per_page, int) or (per_page < 1 or per_page > 200)
543+
):
544+
raise ValueError("per_page should be an integer between 1 and 200")
533545

534546
if not isinstance(self.params, (str, list)):
535547
self._add_params("per-page", per_page)
@@ -570,7 +582,7 @@ def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=100
570582
Paginator object.
571583
"""
572584
if method == "cursor":
573-
if self.params is not None and self.params.get("sample"):
585+
if isinstance(self.params, dict) and self.params.get("sample"):
574586
raise ValueError("method should be 'page' when using sample")
575587
value = cursor
576588
elif method == "page":

tests/test_paging.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
pyalex.config.max_retries = 10
88

99

10+
def test_cursor_no_filter():
11+
assert len(list(pyalex.Works().paginate(per_page=200, n_max=1000))) == 5
12+
13+
1014
def test_cursor():
1115
query = Authors().search_filter(display_name="einstein")
1216

@@ -73,6 +77,28 @@ def test_paginate_counts():
7377
assert r.meta["count"] == n_p_page >= n_p_default == n_p_cursor
7478

7579

80+
def test_paginate_per_page():
81+
assert all(len(page) <= 10 for page in Authors().paginate(per_page=10, n_max=50))
82+
83+
84+
def test_paginate_per_page_200():
85+
assert all(len(page) == 200 for page in Authors().paginate(per_page=200, n_max=400))
86+
87+
88+
def test_paginate_per_page_none():
89+
assert all(len(page) == 25 for page in Authors().paginate(n_max=500))
90+
91+
92+
def test_paginate_per_page_1000():
93+
with pytest.raises(ValueError):
94+
assert next(Authors().paginate(per_page=1000))
95+
96+
97+
def test_paginate_per_page_str():
98+
with pytest.raises(ValueError):
99+
assert next(Authors().paginate(per_page="100"))
100+
101+
76102
def test_paginate_instance():
77103
p_default = Authors().search_filter(display_name="einstein").paginate(per_page=200)
78104
assert isinstance(p_default, Paginator)

tests/test_pyalex.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,20 @@ def test_per_page():
7676
assert len(Works().filter(publication_year=2020).get(per_page=200)) == 200
7777

7878

79+
def test_per_page_none():
80+
assert len(Works().filter(publication_year=2020).get(per_page=None)) == 25
81+
82+
83+
def test_per_page_1000():
84+
with pytest.raises(ValueError):
85+
Works().filter(publication_year=2020).get(per_page=1000)
86+
87+
88+
def test_per_page_str():
89+
with pytest.raises(ValueError):
90+
Works().filter(publication_year=2020).get(per_page="100")
91+
92+
7993
def test_W4238809453_works():
8094
assert isinstance(Works()["W4238809453"], Work)
8195
assert Works()["W4238809453"]["doi"] == "https://doi.org/10.1001/jama.264.8.944b"

0 commit comments

Comments
 (0)