diff --git a/astroquery/mast/core.py b/astroquery/mast/core.py index 712579a383..9fe63909cf 100644 --- a/astroquery/mast/core.py +++ b/astroquery/mast/core.py @@ -25,9 +25,11 @@ def __init__(self, mast_token=None): super().__init__() + self.name = "Mast" + # Initializing API connections - self._portal_api_connection = PortalAPI(self._session) - self._service_api_connection = ServiceAPI(self._session) + self._portal_api_connection = PortalAPI(self._session, self.name) + self._service_api_connection = ServiceAPI(self._session, self.name) if mast_token: self._authenticated = self._auth_obj = MastAuth(self._session, mast_token) @@ -59,6 +61,23 @@ def _login(self, token=None, store_token=False, reenter_token=False): return self._auth_obj.login(token, store_token, reenter_token) + + @property + def cache_location(self): + return super().cache_location + + @cache_location.setter + def cache_location(self, loc): + self._cache_location = Path(loc) + self._portal_api_connection.cache_location = loc + self._service_api_connection.cache_location = loc + + def reset_cache_location(self): + """Resets the cache location to the default astropy cache""" + self._cache_location = None + self._portal_api_connection.reset_cache_location() + self._service_api_connection.reset_cache_location() + def session_info(self, verbose=True): """ Displays information about current MAST user, and returns user info dictionary. diff --git a/astroquery/mast/discovery_portal.py b/astroquery/mast/discovery_portal.py index 024d17cc96..512bdb3080 100644 --- a/astroquery/mast/discovery_portal.py +++ b/astroquery/mast/discovery_portal.py @@ -10,6 +10,7 @@ import uuid import json import time +import re import numpy as np @@ -17,7 +18,8 @@ from astropy.table import Table, vstack, MaskedColumn -from ..query import BaseQuery +from astroquery import cache_conf +from ..query import BaseQuery, AstroQuery, to_cache from ..utils import async_to_sync from ..utils.class_or_instance import class_or_instance from ..exceptions import InputWarning, NoResultsWarning, RemoteServiceError @@ -126,14 +128,17 @@ class PortalAPI(BaseQuery): _column_configs = dict() _current_service = None - def __init__(self, session=None): + def __init__(self, session=None, name=None): super().__init__() if session: self._session = session + if name: + self.name = name + def _request(self, method, url, params=None, data=None, headers=None, - files=None, stream=False, auth=None, retrieve_all=True): + files=None, cache=None, stream=False, auth=None, retrieve_all=True): """ Override of the parent method: A generic HTTP request method, similar to `~requests.Session.request` @@ -160,6 +165,8 @@ def _request(self, method, url, params=None, data=None, headers=None, files : None or dict stream : bool See `~requests.request` + cache : bool + Optional, if specified, overrides global cache settings. retrieve_all : bool Default True. Retrieve all pages of data or just the one indicated in the params value. @@ -169,6 +176,18 @@ def _request(self, method, url, params=None, data=None, headers=None, The response from the server. """ + if cache is None: # Global caching not overridden + cache = cache_conf.cache_active + + + if cache: # cache active, look for cached file + cacher = self._get_cacher(method, url, data, headers, retrieve_all) + response = cacher.from_cache(self.cache_location, cache_conf.cache_timeout) + if response: + return response + + + # Either cache is not active or a cached file was not found, proceed with query start_time = time.time() all_responses = [] total_pages = 1 @@ -208,8 +227,30 @@ def _request(self, method, url, params=None, data=None, headers=None, data = data.replace("page%22%3A%20"+str(cur_page)+"%2C", "page%22%3A%20"+str(cur_page+1)+"%2C") + if cache: # cache is active, so cache response before returning + to_cache(all_responses, cacher.request_file(self.cache_location)) + return all_responses + + def _get_cacher(self, method, url, data, headers, retrieve_all): + """ + Return an object that can cache the HTTP request based on the supplied arguments + """ + + # cacheBreaker parameter (to underlying MAST service) is not relevant (and breaks) local caching + # remove it from part of the cache key + data_no_cache_breaker = re.sub(r'^(.+)cacheBreaker%22%3A%20%22.+%22', r'\1', data) + + # include retrieve_all as part of the cache key by appending it to data + # it cannot be added as part of req_kwargs dict, as it will be rejected by AstroQuery + data_w_retrieve_all = data_no_cache_breaker + " retrieve_all={}".format(retrieve_all) + req_kwargs = dict( + data=data_no_cache_breaker, + headers=headers + ) + return AstroQuery(method, url, **req_kwargs) + def _get_col_config(self, service, fetch_name=None): """ Gets the columnsConfig entry for given service and stores it in `self._column_configs`. @@ -320,6 +361,10 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa Default None. Can be used to override the default behavior of all results being returned to obtain a specific page of results. + cache : Boolean, optional + try to use cached the query result if set to True + cache_opts : dict, optional + cache options, details TBD, e.g., cache expiration policy, etc. **kwargs : See MashupRequest properties `here `__ @@ -333,7 +378,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa # setting self._current_service if service not in self._column_configs.keys(): fetch_name = kwargs.pop('fetch_name', None) - self._get_col_config(service, fetch_name) + self._get_col_config(service, fetch_name, cache) self._current_service = service # setting up pagination @@ -364,7 +409,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa return response - def build_filter_set(self, column_config_name, service_name=None, **filters): + def build_filter_set(self, column_config_name, service_name=None, cache=False, **filters): """ Takes user input dictionary of filters and returns a filterlist that the Mashup can understand. @@ -392,7 +437,7 @@ def build_filter_set(self, column_config_name, service_name=None, **filters): service_name = column_config_name if not self._column_configs.get(service_name): - self._get_col_config(service_name, fetch_name=column_config_name) + self._get_col_config(service_name, fetch_name=column_config_name, cache=cache) caom_col_config = self._column_configs[service_name] diff --git a/astroquery/mast/missions.py b/astroquery/mast/missions.py index caa264aca7..4dba038eec 100644 --- a/astroquery/mast/missions.py +++ b/astroquery/mast/missions.py @@ -40,7 +40,7 @@ def __init__(self, *, mission='hst', service='search'): self.service = service self.mission = mission self.limit = 5000 - + service_dict = {self.service: {'path': self.service, 'args': {}}} self._service_api_connection.set_service_params(service_dict, f"{self.service}/{self.mission}") diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index 65d2186727..f593558681 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -246,7 +246,7 @@ def query_object_async(self, objectname, *, radius=0.2*u.deg, pagesize=None, pag return self.query_region_async(coordinates, radius=radius, pagesize=pagesize, page=page) @class_or_instance - def query_criteria_async(self, *, pagesize=None, page=None, **criteria): + def query_criteria_async(self, *, pagesize=None, page=None, cache=False, cache_opts=None, **criteria): """ Given an set of criteria, returns a list of MAST observations. Valid criteria are returned by ``get_metadata("observations")`` @@ -291,7 +291,7 @@ def query_criteria_async(self, *, pagesize=None, page=None, **criteria): params = {"columns": "*", "filters": mashup_filters} - return self._portal_api_connection.service_request_async(service, params) + return self._portal_api_connection.service_request_async(service, params, cache=cache, cache_opts=cache_opts) def query_region_count(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None): """ diff --git a/astroquery/mast/services.py b/astroquery/mast/services.py index b82bc36cba..44ea048871 100644 --- a/astroquery/mast/services.py +++ b/astroquery/mast/services.py @@ -109,12 +109,15 @@ class ServiceAPI(BaseQuery): REQUEST_URL = conf.server + "/api/v0.1/" SERVICES = {} - def __init__(self, session=None): + def __init__(self, session=None, name=None): super().__init__() if session: self._session = session + if name: + self.name = name + self.TIMEOUT = conf.timeout def set_service_params(self, service_dict, service_name="", server_prefix=False): @@ -143,7 +146,7 @@ def set_service_params(self, service_dict, service_name="", server_prefix=False) self.SERVICES = service_dict def _request(self, method, url, params=None, data=None, headers=None, - files=None, stream=False, auth=None, cache=False, use_json=False): + files=None, stream=False, auth=None, cache=None, use_json=False): """ Override of the parent method: A generic HTTP request method, similar to `~requests.Session.request` @@ -168,7 +171,7 @@ def _request(self, method, url, params=None, data=None, headers=None, stream : bool See `~requests.request` cache : bool - Default False. Use of built in caching + Optional, if specified, overrides global cache settings. use_json: bool Default False. if True then data is already in json format. diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 39d962ced5..2dbb845b31 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -76,6 +76,25 @@ def patch_post(request): return mp +_num_mockreturn = 0 + + +def _get_num_mockreturn(): + global _num_mockreturn + return _num_mockreturn + + +def _reset_mockreturn_counter(): + global _num_mockreturn + _num_mockreturn = 0 + + +def _inc_num_mockreturn(): + global _num_mockreturn + _num_mockreturn += 1 + return _num_mockreturn + + def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs): if "columnsconfig" in url: if "Mast.Catalogs.Tess.Cone" in data: @@ -102,6 +121,9 @@ def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwar with open(filename, 'rb') as infile: content = infile.read() + # For cache tests + _inc_num_mockreturn() + # returning as list because this is what the mast _request function does return [MockResponse(content)] @@ -365,6 +387,34 @@ def test_query_observations_criteria_async(patch_post): assert isinstance(responses, list) +def test_query_observations_criteria_async_cache(patch_post): + _reset_mockreturn_counter() + assert 0 == _get_num_mockreturn(), "Mock HTTP call counter reset to 0" + + responses_cache_miss = mast.Observations.query_criteria_async(dataproduct_type=["image"], + proposal_pi="Ost*", + s_dec=[43.5, 45.5], cache=True) + assert isinstance(responses_cache_miss, list) + num_mockreturn_after_first_call = _get_num_mockreturn() + assert num_mockreturn_after_first_call > 0, "Cache miss, some underlying HTTP call" + + responses_cache_hit = mast.Observations.query_criteria_async(dataproduct_type=["image"], + proposal_pi="Ost*", + s_dec=[43.5, 45.5], cache=True) + # assert the cached response is the same + assert len(responses_cache_hit) == len(responses_cache_miss) + assert responses_cache_hit[0].text == responses_cache_miss[0].text + # ensure the response really comes from the cache + assert num_mockreturn_after_first_call == _get_num_mockreturn(), \ + 'Cache hit: should reach cache only, i.e., no HTTP call' + + responses_no_cache = mast.Observations.query_criteria_async(dataproduct_type=["image"], + proposal_pi="Ost*", + s_dec=[43.5, 45.5], cache=False) + assert isinstance(responses_no_cache, list) + assert _get_num_mockreturn() > num_mockreturn_after_first_call, "Cache off , some underlying HTTP call" + + def test_observations_query_criteria(patch_post): # without position result = mast.Observations.query_criteria(dataproduct_type=["image"], diff --git a/astroquery/query.py b/astroquery/query.py index 998d1a1cdd..772067e819 100644 --- a/astroquery/query.py +++ b/astroquery/query.py @@ -119,10 +119,12 @@ def from_cache(self, cache_location, cache_timeout): current_time = datetime.utcnow() cache_time = datetime.utcfromtimestamp(request_file.stat().st_mtime) expired = current_time-cache_time > timedelta(seconds=cache_timeout) + if not expired: with open(request_file, "rb") as f: response = pickle.load(f) - if not isinstance(response, requests.Response): + if not isinstance(response, requests.Response)and not isinstance(response, list): + # MAST query response is a list of Response response = None else: log.debug(f"Cache expired for {request_file}...") @@ -228,6 +230,20 @@ def _response_hook(self, response, *args, **kwargs): f"{response.text}\n" f"-----------------------------------------", '\t') log.log(5, f"HTTP response\n{response_log}") + + def clear_cache(): + """Removes all cache files.""" + + cache_files = [x for x in os.listdir(self.cache_location) if x.endswidth("pickle")] + for fle in cache_files: + os.remove(fle) + + def reset_cache_preferences(): + """Resets cache preferences to default values""" + + self.reset_cache_location() + self.cache_active = conf.default_cache_active + self.cache_timeout = conf.default_cache_timeout @property def cache_location(self): @@ -336,12 +352,14 @@ def _request(self, method, url, files=files, timeout=timeout, json=json) if not cache: with cache_conf.set_temp("cache_active", False): + response = query.request(self._session, stream=stream, auth=auth, verify=verify, allow_redirects=allow_redirects, json=json) else: response = query.from_cache(self.cache_location, cache_conf.cache_timeout) + if not response: response = query.request(self._session, self.cache_location, @@ -495,6 +513,7 @@ def __exit__(self, exc_type, exc_value, traceback): return False + class QueryWithLogin(BaseQuery): """ This is the base class for all the query classes which are required to diff --git a/astroquery/tests/test_cache.py b/astroquery/tests/test_cache.py index c74110eb51..29ad675ebe 100644 --- a/astroquery/tests/test_cache.py +++ b/astroquery/tests/test_cache.py @@ -10,6 +10,7 @@ from astroquery.query import QueryWithLogin from astroquery import cache_conf + URL1 = "http://fakeurl.edu" URL2 = "http://fakeurl.ac.uk"