Skip to content

Commit 9cb1b84

Browse files
orionleeceb8
authored and committed
MAST query result cache: Observations.query_criteria()
1 parent 7ab4992 commit 9cb1b84

File tree

4 files changed

+99
-11
lines changed

4 files changed

+99
-11
lines changed

astroquery/mast/discovery_portal.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import uuid
1111
import json
1212
import time
13+
import re
1314

1415
import numpy as np
1516

@@ -210,7 +211,39 @@ def _request(self, method, url, params=None, data=None, headers=None,
210211

211212
return all_responses
212213

213-
def _get_col_config(self, service, fetch_name=None):
214+
def _request_w_cache(self, method, url, data=None, headers=None, retrieve_all=True,
                     cache=False, cache_opts=None):
    """
    Perform a MAST request, optionally going through the local query cache.

    Wraps ``self._request()``. When ``cache`` is true, a cacher object is
    obtained from ``self._get_cacher()`` and asked for a previously stored
    response; on a miss the request is performed and the result is stored
    under ``self.cache_location``.

    Parameters
    ----------
    method : str
        HTTP method, forwarded to ``self._request()``.
    url : str
        Request URL, forwarded to ``self._request()``.
    data : str, optional
        Request payload, forwarded to ``self._request()``.
    headers : dict, optional
        HTTP headers, forwarded to ``self._request()``.
    retrieve_all : bool, optional
        Forwarded to ``self._request()``.
    cache : bool, optional
        If True, try the cache first and store the response on a miss.
        Default False (always perform the request).
    cache_opts : dict, optional
        Accepted for forward compatibility; currently not used by this method.

    Returns
    -------
    response
        Whatever ``self._request()`` returns, or the object loaded from the cache.
    """
    # Note: the method only exposes a subset of the parameters of the underlying
    # _request() function, to play nice with existing mocks.
    if not cache:
        return self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)

    # Caching: follow BaseQuery._request()'s pattern, which uses an AstroQuery object.
    # NOTE(review): assumes the package-level ``cache_conf.cache_timeout`` setting is the
    # timeout BaseQuery itself hands to from_cache() -- confirm against astroquery/query.py.
    from .. import cache_conf

    cacher = self._get_cacher(method, url, data, headers, retrieve_all)
    # from_cache() takes (cache_location, cache_timeout); omitting the timeout raises TypeError.
    response = cacher.from_cache(self.cache_location, cache_conf.cache_timeout)
    if not response:
        # Cache miss (or unusable entry): hit the service and remember the result.
        response = self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)
        to_cache(response, cacher.request_file(self.cache_location))
    return response
228+
229+
def _get_cacher(self, method, url, data, headers, retrieve_all):
    """
    Return an object that can cache the HTTP request based on the supplied arguments.

    The returned ``AstroQuery`` is built from a normalized copy of ``data``:
    the ``cacheBreaker`` value is stripped and ``retrieve_all`` is appended,
    so that two requests share a cache entry only when they differ in
    nothing but the cache breaker.

    Parameters
    ----------
    method : str
        HTTP method of the request.
    url : str
        Request URL.
    data : str or None
        URL-encoded request payload; ``None`` is treated as an empty payload.
    headers : dict or None
        HTTP headers of the request.
    retrieve_all : bool
        The ``retrieve_all`` flag of the request; made part of the cache key.

    Returns
    -------
    AstroQuery
        Object constructed from the method, URL, normalized data and headers.
    """
    # data is optional on the caller's side; re.sub() would raise TypeError on None
    payload = "" if data is None else data

    # cacheBreaker parameter (to underlying MAST service) is not relevant (and breaks) local caching
    # remove it from part of the cache key
    data_no_cache_breaker = re.sub(r'^(.+)cacheBreaker%22%3A%20%22.+%22', r'\1', payload)

    # include retrieve_all as part of the cache key by appending it to data
    # it cannot be added as part of req_kwargs dict, as it will be rejected by AstroQuery
    data_w_retrieve_all = data_no_cache_breaker + " retrieve_all={}".format(retrieve_all)

    # Use the string that carries retrieve_all; passing data_no_cache_breaker here
    # would silently drop retrieve_all from the key.
    return AstroQuery(method, url, data=data_w_retrieve_all, headers=headers)
245+
246+
def _get_col_config(self, service, fetch_name=None, cache=False):
214247
"""
215248
Gets the columnsConfig entry for given service and stores it in `self._column_configs`.
216249
@@ -246,7 +279,7 @@ def _get_col_config(self, service, fetch_name=None):
246279
if more:
247280
mashup_request = {'service': all_name, 'params': {}, 'format': 'extjs'}
248281
req_string = _prepare_service_request_string(mashup_request)
249-
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers)
282+
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers, cache=cache)
250283
json_response = response[0].json()
251284

252285
self._column_configs[service].update(json_response['data']['Tables'][0]
@@ -300,7 +333,7 @@ def _parse_result(self, responses, verbose=False):
300333
return all_results
301334

302335
@class_or_instance
303-
def service_request_async(self, service, params, pagesize=None, page=None, **kwargs):
336+
def service_request_async(self, service, params, pagesize=None, page=None, cache=False, cache_opts=None, **kwargs):
304337
"""
305338
Given a Mashup service and parameters, builds and executes a Mashup query.
306339
See documentation `here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
@@ -320,6 +353,10 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
320353
Default None.
321354
Can be used to override the default behavior of all results being returned to obtain
322355
a specific page of results.
356+
cache : Boolean, optional
357+
If True, try to use the cached query result.
358+
cache_opts : dict, optional
359+
cache options, details TBD, e.g., cache expiration policy, etc.
323360
**kwargs :
324361
See MashupRequest properties
325362
`here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
@@ -333,7 +370,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
333370
# setting self._current_service
334371
if service not in self._column_configs.keys():
335372
fetch_name = kwargs.pop('fetch_name', None)
336-
self._get_col_config(service, fetch_name)
373+
self._get_col_config(service, fetch_name, cache)
337374
self._current_service = service
338375

339376
# setting up pagination
@@ -359,12 +396,12 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
359396
mashup_request[prop] = value
360397

361398
req_string = _prepare_service_request_string(mashup_request)
362-
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
363-
retrieve_all=retrieve_all)
399+
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
400+
retrieve_all=retrieve_all, cache=cache, cache_opts=cache_opts)
364401

365402
return response
366403

367-
def build_filter_set(self, column_config_name, service_name=None, **filters):
404+
def build_filter_set(self, column_config_name, service_name=None, cache=False, **filters):
368405
"""
369406
Takes user input dictionary of filters and returns a filterlist that the Mashup can understand.
370407
@@ -392,7 +429,7 @@ def build_filter_set(self, column_config_name, service_name=None, **filters):
392429
service_name = column_config_name
393430

394431
if not self._column_configs.get(service_name):
395-
self._get_col_config(service_name, fetch_name=column_config_name)
432+
self._get_col_config(service_name, fetch_name=column_config_name, cache=cache)
396433

397434
caom_col_config = self._column_configs[service_name]
398435

astroquery/mast/observations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def query_object_async(self, objectname, *, radius=0.2*u.deg, pagesize=None, pag
246246
return self.query_region_async(coordinates, radius=radius, pagesize=pagesize, page=page)
247247

248248
@class_or_instance
249-
def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
249+
def query_criteria_async(self, *, pagesize=None, page=None, cache=False, cache_opts=None, **criteria):
250250
"""
251251
Given a set of criteria, returns a list of MAST observations.
252252
Valid criteria are returned by ``get_metadata("observations")``
@@ -291,7 +291,7 @@ def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
291291
params = {"columns": "*",
292292
"filters": mashup_filters}
293293

294-
return self._portal_api_connection.service_request_async(service, params)
294+
return self._portal_api_connection.service_request_async(service, params, cache=cache, cache_opts=cache_opts)
295295

296296
def query_region_count(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None):
297297
"""

astroquery/mast/tests/test_mast.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ def patch_post(request):
7676
return mp
7777

7878

79+
# Number of mocked responses served so far. post_mockreturn() bumps it on every
# call so the cache tests can tell whether a query reached the (mocked) HTTP layer.
_num_mockreturn = 0


def _get_num_mockreturn():
    """Return the current value of the mocked-response counter."""
    # Read-only access: no ``global`` declaration is needed to read a module variable.
    return _num_mockreturn


def _reset_mockreturn_counter():
    """Reset the mocked-response counter to zero."""
    global _num_mockreturn
    _num_mockreturn = 0


def _inc_num_mockreturn():
    """Increment the mocked-response counter by one and return the new value."""
    global _num_mockreturn
    _num_mockreturn += 1
    return _num_mockreturn
96+
97+
7998
def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs):
8099
if "columnsconfig" in url:
81100
if "Mast.Catalogs.Tess.Cone" in data:
@@ -102,6 +121,9 @@ def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwar
102121
with open(filename, 'rb') as infile:
103122
content = infile.read()
104123

124+
# For cache tests
125+
_inc_num_mockreturn()
126+
105127
# returning as list because this is what the mast _request function does
106128
return [MockResponse(content)]
107129

@@ -365,6 +387,34 @@ def test_query_observations_criteria_async(patch_post):
365387
assert isinstance(responses, list)
366388

367389

390+
def test_query_observations_criteria_async_cache(patch_post):
    """
    Check that ``query_criteria_async(cache=True)`` is served from the cache the second time.

    The mocked-response counter is used to observe whether a query reached
    the (mocked) HTTP layer: it must grow on the first cached query, stay put on
    the second one, and grow again once caching is switched off.
    """
    def run_query(use_cache):
        # Fresh argument objects on every invocation, same criteria each time
        return mast.Observations.query_criteria_async(dataproduct_type=["image"],
                                                      proposal_pi="Ost*",
                                                      s_dec=[43.5, 45.5],
                                                      cache=use_cache)

    _reset_mockreturn_counter()
    assert _get_num_mockreturn() == 0, "Mock HTTP call counter reset to 0"

    # 1st query with caching on: nothing stored yet, so HTTP calls are expected
    first = run_query(True)
    assert isinstance(first, list)
    calls_after_miss = _get_num_mockreturn()
    assert calls_after_miss > 0, "Cache miss, some underlying HTTP call"

    # 2nd identical query: same payload back ...
    second = run_query(True)
    assert len(second) == len(first)
    assert second[0].text == first[0].text
    # ... and the counter must not have moved
    assert _get_num_mockreturn() == calls_after_miss, \
        'Cache hit: should reach cache only, i.e., no HTTP call'

    # 3rd query with caching off: must go through the HTTP layer again
    uncached = run_query(False)
    assert isinstance(uncached, list)
    assert _get_num_mockreturn() > calls_after_miss, "Cache off , some underlying HTTP call"
416+
417+
368418
def test_observations_query_criteria(patch_post):
369419
# without position
370420
result = mast.Observations.query_criteria(dataproduct_type=["image"],

astroquery/query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def from_cache(self, cache_location, cache_timeout):
123123
if not expired:
124124
with open(request_file, "rb") as f:
125125
response = pickle.load(f)
126-
if not isinstance(response, requests.Response):
126+
if not isinstance(response, requests.Response)and not isinstance(response, list):
127+
# MAST query response is a list of Response
127128
response = None
128129
else:
129130
log.debug(f"Cache expired for {request_file}...")

0 commit comments

Comments
 (0)