
Commit 6c2ce49

Heloise Stevance committed:
added chunk size options and retries if chunk_get_response fails.

1 parent dde5201

File tree

1 file changed (+78 -93 lines)


atlasapiclient/client.py

Lines changed: 78 additions & 93 deletions
@@ -19,6 +19,8 @@
 import warnings
 from abc import ABC
 import logging
+import time
+import random
 
 import requests
 import numpy as np
@@ -495,86 +497,55 @@ def save_response_to_json(self, output_dir=None):
 class RequestMultipleSourceData(APIClient):
     def __init__(self,
                  array_ids: np.array = None,
-                 mjdthreshold = None,
+                 mjdthreshold=None,
                  api_config_file: str = None,
+                 chunk_size: int = 100,
                  **kwargs,
                  ):
-        """READ - Request data for multiple sources. Contains a convenience
-        method to chunk the request in groups of 100 so don't get timed out by the server.
-
-        Parameters
-        ----------
-        array_ids:
-            The ATLAS_IDs. Can be an array a list or a tuple.
-        mjdthreshold;
-            The Lower MJD threshold (we don't have a higher one yet)
-        api_config_file:
-            By default will use you api_config_MINE.yaml file.
-        """
         super().__init__(api_config_file, **kwargs)
 
-        # ATLAS_ID ARRAY - CHECK VALIDITY AND ASSIGN
-        assert array_ids is not None, "You need to provide an array of object IDs"  # Check not None
-        # TODO: Checking that the array is a numpy array might not be necessary,
-        # we can convert pretty easily beforehand. Also we don't seem to be using
-        # any numpy features here, the ints are all converted to strings anyway?
-        assert isinstance(array_ids, np.ndarray), "array_ids must be a numpy array"  # check is a numpy array
-        assert len(array_ids) > 0, "array_ids must not be empty"  # check is not empty
-        self.array_ids = np.array([self.parse_atlas_id(str(x)) for x in array_ids])  # check each element is valid
+        assert array_ids is not None, "You need to provide an array of object IDs"
+        assert isinstance(array_ids, np.ndarray), "array_ids must be a numpy array"
+        assert len(array_ids) > 0, "array_ids must not be empty"
 
-        # MJD THRESHOLD AND URL - ASSIGN
+        self.array_ids = np.array([self.parse_atlas_id(str(x)) for x in array_ids])
         self.mjdthreshold = mjdthreshold
-        self.url= "objects/"
-        self.url = self.apiURL + self.url
-
-        # INITIALIZE RESPONSE AS A LIST
+        self.url = self.apiURL + "objects/"
         self.response_data = []
+        self.chunk_size = chunk_size  # NEW: assign chunk size
 
-    def chunk_get_response_quiet(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server.
-        Does not print out a progress bar.
-
-        Notes
-        ------
-        This is typically used in production to avoid spamming the logs.
+    def _run_chunked_requests(self, chunks, show_progress=False, max_retries=3, backoff_range=(1, 5)):
+        iterable = tqdm(chunks) if show_progress else chunks
 
-        Returns
-        -------
-        None
-        """
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
-
-        # Iterate over each chunk and make separate requests
-        for chunk in chunks:
+        for idx, chunk in enumerate(iterable):
             array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objects': array_ids_str,
-                            'mjd': self.mjdthreshold
-                            }
-
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
-
-    def chunk_get_response(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server.
-        Prints out a progress bar.
-
-        Notes
-        ------
-        This is typically used in human scripts and notebooks to see how long it'll take.
-        """
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
+            self.payload = {'objects': array_ids_str, 'mjd': self.mjdthreshold}
+
+            attempt = 0
+            while attempt < max_retries:
+                try:
+                    _response = self.get_response(inplace=False)
+                    self.response_data.extend(_response)
+                    break  # success
+                except Exception as e:
+                    attempt += 1
+                    wait = random.uniform(*backoff_range)
+                    logging.warning(
+                        f"[Chunk {idx+1}/{len(chunks)}] Retry {attempt}/{max_retries} due to error: {e}. Backing off {wait:.2f}s"
+                    )
+                    time.sleep(wait)
+            else:
+                logging.error(f"[Chunk {idx+1}/{len(chunks)}] Failed after {max_retries} retries. Skipping.")
 
-        # Iterate over each chunk and make separate requests
-        for chunk in tqdm(chunks):
-            array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objects': array_ids_str,
-                            'mjd': self.mjdthreshold
-                            }
+    def chunk_get_response_quiet(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size] for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
+        self._run_chunked_requests(chunks, show_progress=False, max_retries=max_retries, backoff_range=backoff_range)
 
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
+    def chunk_get_response(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size] for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
+        self._run_chunked_requests(chunks, show_progress=True, max_retries=max_retries, backoff_range=backoff_range)
 
     def save_response_to_json(self, output_dir=None):
         """Saves the response to INDIVIDUAL text files with the name [ATLAS_ID].json.
@@ -776,50 +747,64 @@ def __init__(self,
                  list_name: str = None,
                  get_response: bool = False,
                  api_config_file: str = None,
+                 chunk_size: int = 100,  # NEW: chunk size parameter
                  **kwargs,
                  ):
-        """WRITE - Remove ATLAS_IDs from a custom list
+        """
+        WRITE - Remove ATLAS_IDs from a custom list.
 
         Parameters
-        ------------
+        ----------
         array_ids: np.array, list, tuple
             The ATLAS IDs. Can be an array a list or a tuple.
         list_name: str
-            The name of the list you want (NOT THE NUMBER). e.g. 'mookodi'.
+            The name of the list you want (NOT THE NUMBER). e.g. 'mookodi'.
         get_response: bool
-            If True, will get the response on instanciation
+            If True, will get the response on instantiation.
         api_config_file: str
-            By default will use you api_config_MINE.yaml file.
+            Optional path to your API config file.
+        chunk_size: int
+            Max number of IDs per chunk. Default is 100.
         """
         super().__init__(api_config_file, **kwargs)
         self.url = self.apiURL + 'objectgroupsdelete/'
         self.array_ids = array_ids
-        self.object_group_id = self.dict_list_id[list_name][0]  # object group id is the number of the custom list
+        self.chunk_size = chunk_size  # NEW: store chunk size
+        self.object_group_id = self.dict_list_id[list_name][0]
+        self.response_data = []
 
-        # if self.array_ids smaller than 100 we can just create the payload and use get_response from ATLASAPIBase
-        # if self.array_ids is larger than 100 we need to call chunk_get_response(_quiet)
-        if self.array_ids.shape[0] > 100:
+        if self.array_ids.shape[0] > self.chunk_size:
             self.chunk_get_response_quiet()
         else:
-            self.payload = {'objectid': ','.join(map(str, self.array_ids)),
-                            'objectgroupid': self.object_group_id
-                            }
-
-
-    def chunk_get_response_quiet(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server. No progress bar."""
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
+            self.payload = {
+                'objectid': ','.join(map(str, self.array_ids)),
+                'objectgroupid': self.object_group_id
+            }
+
+    def chunk_get_response_quiet(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size]
+                  for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
 
-        # Iterate over each chunk and make separate requests
-        for chunk in chunks:
+        for idx, chunk in enumerate(chunks):
             array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objectid': array_ids_str,
-                            'objectgroupid': self.object_group_id
-                            }
-
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
+            self.payload = {'objectid': array_ids_str, 'objectgroupid': self.object_group_id}
+
+            attempt = 0
+            while attempt < max_retries:
+                try:
+                    _response = self.get_response(inplace=False)
+                    self.response_data.extend(_response)
+                    break
+                except Exception as e:
+                    attempt += 1
+                    wait_time = random.uniform(*backoff_range)
+                    logging.warning(
+                        f"[Chunk {idx+1}/{len(chunks)}] Retry {attempt}/{max_retries} after error: {e}. Sleeping {wait_time:.2f}s...")
+                    time.sleep(wait_time)
+            else:
+                logging.error(
+                    f"[Chunk {idx+1}/{len(chunks)}] Failed after {max_retries} retries. Chunk skipped.")
 
 class WriteObjectDetectionListNumber(APIClient):
     def __init__(self,
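
Both classes now wrap get_response in the same bounded-retry loop. As a self-contained sketch of the construct (the function name and arguments here are illustrative, not part of the client):

import logging
import random
import time

def fetch_with_retries(get_chunk, max_retries=3, backoff_range=(1, 5)):
    """Call get_chunk(); on failure, back off for a random interval and retry."""
    result = None
    attempt = 0
    while attempt < max_retries:
        try:
            result = get_chunk()
            break  # success: the while-else below is skipped
        except Exception as e:
            attempt += 1
            wait = random.uniform(*backoff_range)
            logging.warning(f"Retry {attempt}/{max_retries} after error: {e}; backing off {wait:.2f}s")
            time.sleep(wait)
    else:
        # A while-else runs only when the loop exits without break,
        # i.e. when every attempt has failed.
        logging.error(f"Giving up after {max_retries} retries.")
    return result

The random (rather than exponential) backoff and the while-else exhaustion branch mirror the loop added in this commit: a chunk that fails all its retries is logged and skipped, so one bad chunk cannot abort a long run.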
