@@ -338,14 +338,12 @@ def __init__(
338
338
self .value = value
339
339
self .per_page = per_page
340
340
self .n_max = n_max
341
- # The current number of results retrieved.
342
341
self .n = 0
343
342
344
343
self ._next_value = value
344
+ self ._session = _get_requests_session ()
345
345
346
346
def __iter__ (self ):
347
- self .n = 0
348
-
349
347
return self
350
348
351
349
def _is_max (self ):
@@ -358,13 +356,22 @@ def __next__(self):
358
356
raise StopIteration
359
357
360
358
if self .method == "cursor" :
361
- pagination_params = { "cursor" : self ._next_value }
359
+ self . endpoint_class . _add_params ( "cursor" , self ._next_value )
362
360
elif self .method == "page" :
363
- pagination_params = { "page" : self ._next_value }
361
+ self . endpoint_class . _add_params ( "page" , self ._next_value )
364
362
else :
365
- raise ValueError ()
363
+ raise ValueError ("Method should be 'cursor' or 'page'" )
364
+
365
+ if self .per_page is not None and (
366
+ not isinstance (self .per_page , int )
367
+ or (self .per_page < 1 or self .per_page > 200 )
368
+ ):
369
+ raise ValueError ("per_page should be a integer between 1 and 200" )
370
+
371
+ if self .per_page is not None :
372
+ self .endpoint_class ._add_params ("per-page" , self .per_page )
366
373
367
- r = self .endpoint_class .get ( per_page = self .per_page , ** pagination_params )
374
+ r = self .endpoint_class ._get_from_url ( self .endpoint_class . url , self . _session )
368
375
369
376
if self .method == "cursor" :
370
377
self ._next_value = r .meta ["next_cursor" ]
@@ -501,8 +508,11 @@ def count(self):
501
508
"""
502
509
return self .get (per_page = 1 ).meta ["count" ]
503
510
504
- def _get_from_url (self , url ):
505
- res = _get_requests_session ().get (url , auth = OpenAlexAuth (config ))
511
+ def _get_from_url (self , url , session = None ):
512
+ if session is None :
513
+ session = _get_requests_session ()
514
+
515
+ res = session .get (url , auth = OpenAlexAuth (config ))
506
516
507
517
if res .status_code == 403 :
508
518
if (
@@ -528,8 +538,10 @@ def _get_from_url(self, url):
528
538
raise ValueError ("Unknown response format" )
529
539
530
540
def get (self , return_meta = False , page = None , per_page = None , cursor = None ):
531
- if per_page is not None and (per_page < 1 or per_page > 200 ):
532
- raise ValueError ("per_page should be a number between 1 and 200." )
541
+ if per_page is not None and (
542
+ not isinstance (per_page , int ) or (per_page < 1 or per_page > 200 )
543
+ ):
544
+ raise ValueError ("per_page should be an integer between 1 and 200" )
533
545
534
546
if not isinstance (self .params , (str , list )):
535
547
self ._add_params ("per-page" , per_page )
@@ -570,7 +582,7 @@ def paginate(self, method="cursor", page=1, per_page=None, cursor="*", n_max=100
570
582
Paginator object.
571
583
"""
572
584
if method == "cursor" :
573
- if self .params is not None and self .params .get ("sample" ):
585
+ if isinstance ( self .params , dict ) and self .params .get ("sample" ):
574
586
raise ValueError ("method should be 'page' when using sample" )
575
587
value = cursor
576
588
elif method == "page" :
0 commit comments