@@ -19,6 +19,8 @@
 import warnings
 from abc import ABC
 import logging
+import time
+import random
 
 import requests
 import numpy as np
@@ -495,86 +497,55 @@ def save_response_to_json(self, output_dir=None):
 class RequestMultipleSourceData(APIClient):
     def __init__(self,
                  array_ids: np.array = None,
-                 mjdthreshold = None,
+                 mjdthreshold=None,
                  api_config_file: str = None,
+                 chunk_size: int = 100,
                  **kwargs,
                  ):
-        """READ - Request data for multiple sources. Contains a convenience
-        method to chunk the request in groups of 100 so don't get timed out by the server.
-
-        Parameters
-        ----------
-        array_ids:
-            The ATLAS_IDs. Can be an array a list or a tuple.
-        mjdthreshold;
-            The Lower MJD threshold (we don't have a higher one yet)
-        api_config_file:
-            By default will use you api_config_MINE.yaml file.
-        """
         super().__init__(api_config_file, **kwargs)
 
-        # ATLAS_ID ARRAY - CHECK VALIDITY AND ASSIGN
-        assert array_ids is not None, "You need to provide an array of object IDs"  # Check not None
-        # TODO: Checking that the array is a numpy array might not be necessary,
-        # we can convert pretty easily beforehand. Also we don't seem to be using
-        # any numpy features here, the ints are all converted to strings anyway?
-        assert isinstance(array_ids, np.ndarray), "array_ids must be a numpy array"  # check is a numpy array
-        assert len(array_ids) > 0, "array_ids must not be empty"  # check is not empty
-        self.array_ids = np.array([self.parse_atlas_id(str(x)) for x in array_ids])  # check each element is valid
+        assert array_ids is not None, "You need to provide an array of object IDs"
+        assert isinstance(array_ids, np.ndarray), "array_ids must be a numpy array"
+        assert len(array_ids) > 0, "array_ids must not be empty"
 
-        # MJD THRESHOLD AND URL - ASSIGN
+        self.array_ids = np.array([self.parse_atlas_id(str(x)) for x in array_ids])
         self.mjdthreshold = mjdthreshold
-        self.url= "objects/"
-        self.url = self.apiURL + self.url
-
-        # INITIALIZE RESPONSE AS A LIST
+        self.url = self.apiURL + "objects/"
         self.response_data = []
+        self.chunk_size = chunk_size  # maximum number of IDs per request
 
-    def chunk_get_response_quiet(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server.
-        Does not print out a progress bar.
-
-        Notes
-        ------
-        This is typically used in production to avoid spamming the logs.
+    def _run_chunked_requests(self, chunks, show_progress=False, max_retries=3, backoff_range=(1, 5)):
+        iterable = tqdm(chunks) if show_progress else chunks
 
-        Returns
-        -------
-        None
-        """
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
-
-        # Iterate over each chunk and make separate requests
-        for chunk in chunks:
+        for idx, chunk in enumerate(iterable):
             array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objects': array_ids_str,
-                            'mjd': self.mjdthreshold
-                            }
-
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
-
-    def chunk_get_response(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server.
-        Prints out a progress bar.
-
-        Notes
-        ------
-        This is typically used in human scripts and notebooks to see how long it'll take.
-        """
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
+            self.payload = {'objects': array_ids_str, 'mjd': self.mjdthreshold}
+
+            attempt = 0
+            while attempt < max_retries:
+                try:
+                    _response = self.get_response(inplace=False)
+                    self.response_data.extend(_response)
+                    break  # success
+                except Exception as e:
+                    attempt += 1
+                    wait = random.uniform(*backoff_range)
+                    logging.warning(
+                        f"[Chunk {idx+1}/{len(chunks)}] Retry {attempt}/{max_retries} due to error: {e}. Backing off {wait:.2f}s"
+                    )
+                    time.sleep(wait)
+            else:  # while/else: runs only if the loop exhausted without a break
+                logging.error(f"[Chunk {idx+1}/{len(chunks)}] Failed after {max_retries} retries. Skipping.")
 
-        # Iterate over each chunk and make separate requests
-        for chunk in tqdm(chunks):
-            array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objects': array_ids_str,
-                            'mjd': self.mjdthreshold
-                            }
+    def chunk_get_response_quiet(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size] for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
+        self._run_chunked_requests(chunks, show_progress=False, max_retries=max_retries, backoff_range=backoff_range)
 
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
+    def chunk_get_response(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size] for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
+        self._run_chunked_requests(chunks, show_progress=True, max_retries=max_retries, backoff_range=backoff_range)
 
     def save_response_to_json(self, output_dir=None):
         """Saves the response to INDIVIDUAL text files with the name [ATLAS_ID].json.
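A minimal usage sketch of the refactored reader, for reviewers. This is a sketch only: the IDs, the MJD value, and the output path are placeholders, and config discovery via `api_config_MINE.yaml` is assumed to behave as before.

```python
import numpy as np

# Placeholder ATLAS_IDs purely for illustration -- not real objects.
ids = np.array([1000240741236800000, 1000240741236800001])

client = RequestMultipleSourceData(array_ids=ids,
                                   mjdthreshold=59000,  # assumed lower MJD cut
                                   chunk_size=100)
client.chunk_get_response()  # tqdm progress bar; each chunk retried up to 3 times
client.save_response_to_json(output_dir="lightcurves/")  # hypothetical output path
```

Note that both `chunk_get_response*` methods now reset `self.response_data` before running, so a second call no longer appends to the results of the first.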
@@ -776,50 +747,64 @@ def __init__(self,
                  list_name: str = None,
                  get_response: bool = False,
                  api_config_file: str = None,
+                 chunk_size: int = 100,
                  **kwargs,
                  ):
-        """WRITE - Remove ATLAS_IDs from a custom list
+        """
+        WRITE - Remove ATLAS_IDs from a custom list.
 
         Parameters
-        ------------
+        ----------
         array_ids: np.array, list, tuple
             The ATLAS IDs. Can be an array a list or a tuple.
         list_name: str
-            The name of the list you want (NOT THE NUMBER). e.g. 'mookodi'.
+            The name of the list you want (NOT THE NUMBER). e.g. 'mookodi'.
         get_response: bool
-            If True, will get the response on instanciation
+            If True, will get the response on instantiation.
         api_config_file: str
-            By default will use you api_config_MINE.yaml file.
+            Optional path to your API config file.
+        chunk_size: int
+            Max number of IDs per chunk. Default is 100.
         """
         super().__init__(api_config_file, **kwargs)
         self.url = self.apiURL + 'objectgroupsdelete/'
         self.array_ids = array_ids
-        self.object_group_id = self.dict_list_id[list_name][0]  # object group id is the number of the custom list
+        self.chunk_size = chunk_size  # maximum number of IDs per request
+        self.object_group_id = self.dict_list_id[list_name][0]
+        self.response_data = []
 
-        # if self.array_ids smaller than 100 we can just create the payload and use get_response from ATLASAPIBase
-        # if self.array_ids is larger than 100 we need to call chunk_get_response(_quiet)
-        if self.array_ids.shape[0] > 100:
+        if self.array_ids.shape[0] > self.chunk_size:
             self.chunk_get_response_quiet()
         else:
-            self.payload = {'objectid': ','.join(map(str, self.array_ids)),
-                            'objectgroupid': self.object_group_id
-                            }
-
-
-    def chunk_get_response_quiet(self):
-        """Chunks the request in groups of 100 so don't get timed out by the server. No progress bar."""
-        # Split array_ids into chunks of 100
-        chunks = [self.array_ids[i:i + 100] for i in range(0, len(self.array_ids), 100)]
+            self.payload = {
+                'objectid': ','.join(map(str, self.array_ids)),
+                'objectgroupid': self.object_group_id
+            }
+
+    def chunk_get_response_quiet(self, max_retries=3, backoff_range=(1, 5)):
+        chunks = [self.array_ids[i:i + self.chunk_size]
+                  for i in range(0, len(self.array_ids), self.chunk_size)]
+        self.response_data = []
 
-        # Iterate over each chunk and make separate requests
-        for chunk in chunks:
+        for idx, chunk in enumerate(chunks):
             array_ids_str = ','.join(map(str, chunk))
-            self.payload = {'objectid': array_ids_str,
-                            'objectgroupid': self.object_group_id
-                            }
-
-            _response = self.get_response(inplace=False)
-            self.response_data.extend(_response)
+            self.payload = {'objectid': array_ids_str, 'objectgroupid': self.object_group_id}
+
+            attempt = 0
+            while attempt < max_retries:
+                try:
+                    _response = self.get_response(inplace=False)
+                    self.response_data.extend(_response)
+                    break  # success
+                except Exception as e:
+                    attempt += 1
+                    wait_time = random.uniform(*backoff_range)
+                    logging.warning(
+                        f"[Chunk {idx+1}/{len(chunks)}] Retry {attempt}/{max_retries} after error: {e}. Sleeping {wait_time:.2f}s...")
+                    time.sleep(wait_time)
+            else:  # while/else: runs only if every retry failed
+                logging.error(
+                    f"[Chunk {idx+1}/{len(chunks)}] Failed after {max_retries} retries. Chunk skipped.")
 
 class WriteObjectDetectionListNumber(APIClient):
     def __init__(self,
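And a sketch for the list-removal client in the hunk above. Its class name is defined before the hunk starts and is not visible here, so `RemoveObjectsFromList` below is a stand-in; `list_name='mookodi'` is taken from the docstring example. Because the retry loop uses `while`/`else`, the error branch fires only when every one of the `max_retries` attempts has failed; a successful attempt exits via `break` and skips it.

```python
import numpy as np

# Stand-in class name -- the real one is defined above this hunk.
ids = np.array([1000240741236800000, 1000240741236800001])  # placeholder IDs

remover = RemoveObjectsFromList(array_ids=ids,
                                list_name='mookodi',  # list name from the docstring example
                                chunk_size=100)
# With <= chunk_size IDs, __init__ only builds the payload; larger inputs
# call chunk_get_response_quiet() automatically, with retry + jittered backoff.
```

Design-wise, this retry loop repeats `_run_chunked_requests` from `RequestMultipleSourceData` almost line for line; sharing that helper would keep the backoff behaviour in one place.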