Skip to content

WIP: MAST query result cache support outline #1578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
23 changes: 21 additions & 2 deletions astroquery/mast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ def __init__(self, mast_token=None):

super().__init__()

self.name = "Mast"

# Initializing API connections
self._portal_api_connection = PortalAPI(self._session)
self._service_api_connection = ServiceAPI(self._session)
self._portal_api_connection = PortalAPI(self._session, self.name)
self._service_api_connection = ServiceAPI(self._session, self.name)

if mast_token:
self._authenticated = self._auth_obj = MastAuth(self._session, mast_token)
Expand Down Expand Up @@ -59,6 +61,23 @@ def _login(self, token=None, store_token=False, reenter_token=False):

return self._auth_obj.login(token, store_token, reenter_token)


@property
def cache_location(self):
    """Cache location as reported by the parent class."""
    return super().cache_location

@cache_location.setter
def cache_location(self, loc):
    # Keep a Path on this object, then hand the caller's original value to
    # both API connections so that all three agree on the location.
    self._cache_location = Path(loc)
    for connection in (self._portal_api_connection, self._service_api_connection):
        connection.cache_location = loc

def reset_cache_location(self):
    """Resets the cache location to the default astropy cache"""
    # Drop the override stored on this object first, then ask each API
    # connection (portal, then service) to drop its own.
    self._cache_location = None
    for connection in (self._portal_api_connection, self._service_api_connection):
        connection.reset_cache_location()

def session_info(self, verbose=True):
"""
Displays information about current MAST user, and returns user info dictionary.
Expand Down
57 changes: 51 additions & 6 deletions astroquery/mast/discovery_portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
import uuid
import json
import time
import re

import numpy as np

from urllib.parse import quote as urlencode

from astropy.table import Table, vstack, MaskedColumn

from ..query import BaseQuery
from astroquery import cache_conf
from ..query import BaseQuery, AstroQuery, to_cache
from ..utils import async_to_sync
from ..utils.class_or_instance import class_or_instance
from ..exceptions import InputWarning, NoResultsWarning, RemoteServiceError
Expand Down Expand Up @@ -126,14 +128,17 @@ class PortalAPI(BaseQuery):
_column_configs = dict()
_current_service = None

def __init__(self, session=None):
def __init__(self, session=None, name=None):

super().__init__()
if session:
self._session = session

if name:
self.name = name

def _request(self, method, url, params=None, data=None, headers=None,
files=None, stream=False, auth=None, retrieve_all=True):
files=None, cache=None, stream=False, auth=None, retrieve_all=True):
"""
Override of the parent method:
A generic HTTP request method, similar to `~requests.Session.request`
Expand All @@ -160,6 +165,8 @@ def _request(self, method, url, params=None, data=None, headers=None,
files : None or dict
stream : bool
See `~requests.request`
cache : bool
Optional, if specified, overrides global cache settings.
retrieve_all : bool
Default True. Retrieve all pages of data or just the one indicated in the params value.

Expand All @@ -169,6 +176,18 @@ def _request(self, method, url, params=None, data=None, headers=None,
The response from the server.
"""

if cache is None: # Global caching not overridden
cache = cache_conf.cache_active


if cache: # cache active, look for cached file
cacher = self._get_cacher(method, url, data, headers, retrieve_all)
response = cacher.from_cache(self.cache_location, cache_conf.cache_timeout)
if response:
return response


# Either cache is not active or a cached file was not found, proceed with query
start_time = time.time()
all_responses = []
total_pages = 1
Expand Down Expand Up @@ -208,8 +227,30 @@ def _request(self, method, url, params=None, data=None, headers=None,

data = data.replace("page%22%3A%20"+str(cur_page)+"%2C", "page%22%3A%20"+str(cur_page+1)+"%2C")

if cache: # cache is active, so cache response before returning
to_cache(all_responses, cacher.request_file(self.cache_location))

return all_responses


def _get_cacher(self, method, url, data, headers, retrieve_all):
    """
    Return an object that can cache the HTTP request based on the supplied arguments

    Parameters
    ----------
    method : str
        HTTP method, passed through to `AstroQuery`.
    url : str
        Request URL, passed through to `AstroQuery`.
    data : str or None
        Encoded request payload. ``None`` is treated as an empty payload.
    headers : dict or None
        Request headers, passed through to `AstroQuery`.
    retrieve_all : bool
        Appended to the payload so that it becomes part of the cache key.

    Returns
    -------
    cacher : `AstroQuery`
        Built from the normalized payload; intended for cache lookup/storage only.
    """

    payload = "" if data is None else data

    # cacheBreaker parameter (to underlying MAST service) is not relevant (and breaks) local caching
    # remove it from part of the cache key.
    # The value match is non-greedy so that only the cacheBreaker value itself is dropped;
    # a greedy match would also swallow any request fields that follow it.
    data_no_cache_breaker = re.sub(r'^(.+)cacheBreaker%22%3A%20%22.+?%22', r'\1', payload)

    # include retrieve_all as part of the cache key by appending it to data
    # it cannot be added as part of req_kwargs dict, as it will be rejected by AstroQuery
    data_w_retrieve_all = data_no_cache_breaker + " retrieve_all={}".format(retrieve_all)

    return AstroQuery(method, url, data=data_w_retrieve_all, headers=headers)

def _get_col_config(self, service, fetch_name=None):
"""
Gets the columnsConfig entry for given service and stores it in `self._column_configs`.
Expand Down Expand Up @@ -320,6 +361,10 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
Default None.
Can be used to override the default behavior of all results being returned to obtain
a specific page of results.
cache : Boolean, optional
try to use the cached query result if set to True
cache_opts : dict, optional
cache options, details TBD, e.g., cache expiration policy, etc.
**kwargs :
See MashupRequest properties
`here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
Expand All @@ -333,7 +378,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
# setting self._current_service
if service not in self._column_configs.keys():
fetch_name = kwargs.pop('fetch_name', None)
self._get_col_config(service, fetch_name)
self._get_col_config(service, fetch_name, cache)
self._current_service = service

# setting up pagination
Expand Down Expand Up @@ -364,7 +409,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa

return response

def build_filter_set(self, column_config_name, service_name=None, **filters):
def build_filter_set(self, column_config_name, service_name=None, cache=False, **filters):
"""
Takes user input dictionary of filters and returns a filterlist that the Mashup can understand.

Expand Down Expand Up @@ -392,7 +437,7 @@ def build_filter_set(self, column_config_name, service_name=None, **filters):
service_name = column_config_name

if not self._column_configs.get(service_name):
self._get_col_config(service_name, fetch_name=column_config_name)
self._get_col_config(service_name, fetch_name=column_config_name, cache=cache)

caom_col_config = self._column_configs[service_name]

Expand Down
2 changes: 1 addition & 1 deletion astroquery/mast/missions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, *, mission='hst', service='search'):
self.service = service
self.mission = mission
self.limit = 5000

service_dict = {self.service: {'path': self.service, 'args': {}}}
self._service_api_connection.set_service_params(service_dict, f"{self.service}/{self.mission}")

Expand Down
4 changes: 2 additions & 2 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def query_object_async(self, objectname, *, radius=0.2*u.deg, pagesize=None, pag
return self.query_region_async(coordinates, radius=radius, pagesize=pagesize, page=page)

@class_or_instance
def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
def query_criteria_async(self, *, pagesize=None, page=None, cache=False, cache_opts=None, **criteria):
"""
Given an set of criteria, returns a list of MAST observations.
Valid criteria are returned by ``get_metadata("observations")``
Expand Down Expand Up @@ -291,7 +291,7 @@ def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
params = {"columns": "*",
"filters": mashup_filters}

return self._portal_api_connection.service_request_async(service, params)
return self._portal_api_connection.service_request_async(service, params, cache=cache, cache_opts=cache_opts)

def query_region_count(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None):
"""
Expand Down
9 changes: 6 additions & 3 deletions astroquery/mast/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,15 @@ class ServiceAPI(BaseQuery):
REQUEST_URL = conf.server + "/api/v0.1/"
SERVICES = {}

def __init__(self, session=None):
def __init__(self, session=None, name=None):

super().__init__()
if session:
self._session = session

if name:
self.name = name

self.TIMEOUT = conf.timeout

def set_service_params(self, service_dict, service_name="", server_prefix=False):
Expand Down Expand Up @@ -143,7 +146,7 @@ def set_service_params(self, service_dict, service_name="", server_prefix=False)
self.SERVICES = service_dict

def _request(self, method, url, params=None, data=None, headers=None,
files=None, stream=False, auth=None, cache=False, use_json=False):
files=None, stream=False, auth=None, cache=None, use_json=False):
"""
Override of the parent method:
A generic HTTP request method, similar to `~requests.Session.request`
Expand All @@ -168,7 +171,7 @@ def _request(self, method, url, params=None, data=None, headers=None,
stream : bool
See `~requests.request`
cache : bool
Default False. Use of built in caching
Optional, if specified, overrides global cache settings.
use_json: bool
Default False. if True then data is already in json format.

Expand Down
50 changes: 50 additions & 0 deletions astroquery/mast/tests/test_mast.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ def patch_post(request):
return mp


# Module-level counter of mocked responses; bumped through _inc_num_mockreturn()
_num_mockreturn = 0


def _get_num_mockreturn():
    """Return the current value of the mock-call counter."""
    return _num_mockreturn


def _reset_mockreturn_counter():
    """Set the mock-call counter back to zero."""
    global _num_mockreturn
    _num_mockreturn = 0


def _inc_num_mockreturn():
    """Add one to the mock-call counter and return the new value."""
    global _num_mockreturn
    _num_mockreturn = _num_mockreturn + 1
    return _num_mockreturn


def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs):
if "columnsconfig" in url:
if "Mast.Catalogs.Tess.Cone" in data:
Expand All @@ -102,6 +121,9 @@ def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwar
with open(filename, 'rb') as infile:
content = infile.read()

# For cache tests
_inc_num_mockreturn()

# returning as list because this is what the mast _request function does
return [MockResponse(content)]

Expand Down Expand Up @@ -365,6 +387,34 @@ def test_query_observations_criteria_async(patch_post):
assert isinstance(responses, list)


def test_query_observations_criteria_async_cache(patch_post):
    """Check ``cache=True`` / ``cache=False`` on ``Observations.query_criteria_async``.

    The module-level mock-call counter is used to tell whether a call went
    through the mocked HTTP layer or was answered without it.
    """

    def run_query(use_cache):
        # Identical criteria on every call, built fresh each time
        return mast.Observations.query_criteria_async(dataproduct_type=["image"],
                                                      proposal_pi="Ost*",
                                                      s_dec=[43.5, 45.5], cache=use_cache)

    _reset_mockreturn_counter()
    assert 0 == _get_num_mockreturn(), "Mock HTTP call counter reset to 0"

    # First cached call: expected to miss and therefore hit the mocked HTTP layer
    first = run_query(True)
    assert isinstance(first, list)
    calls_after_miss = _get_num_mockreturn()
    assert calls_after_miss > 0, "Cache miss, some underlying HTTP call"

    # Second cached call: same content, and the counter must not move
    second = run_query(True)
    assert len(second) == len(first)
    assert second[0].text == first[0].text
    assert calls_after_miss == _get_num_mockreturn(), \
        'Cache hit: should reach cache only, i.e., no HTTP call'

    # Caching switched off: the mocked HTTP layer must be reached again
    uncached = run_query(False)
    assert isinstance(uncached, list)
    assert _get_num_mockreturn() > calls_after_miss, "Cache off , some underlying HTTP call"


def test_observations_query_criteria(patch_post):
# without position
result = mast.Observations.query_criteria(dataproduct_type=["image"],
Expand Down
21 changes: 20 additions & 1 deletion astroquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def from_cache(self, cache_location, cache_timeout):
current_time = datetime.utcnow()
cache_time = datetime.utcfromtimestamp(request_file.stat().st_mtime)
expired = current_time-cache_time > timedelta(seconds=cache_timeout)

if not expired:
with open(request_file, "rb") as f:
response = pickle.load(f)
if not isinstance(response, requests.Response):
if not isinstance(response, requests.Response)and not isinstance(response, list):
# MAST query response is a list of Response
response = None
else:
log.debug(f"Cache expired for {request_file}...")
Expand Down Expand Up @@ -228,6 +230,20 @@ def _response_hook(self, response, *args, **kwargs):
f"{response.text}\n"
f"-----------------------------------------", '\t')
log.log(5, f"HTTP response\n{response_log}")

def clear_cache(self):
    """Removes all cache files.

    Deletes every entry of ``self.cache_location`` whose name ends with
    ``"pickle"``; entries with other names are left in place.
    """

    cache_dir = self.cache_location
    for fle in os.listdir(cache_dir):
        if fle.endswith("pickle"):
            # os.listdir yields bare names, so build the full path before removing;
            # otherwise the file would be looked up in the current working directory.
            os.remove(os.path.join(cache_dir, fle))

def reset_cache_preferences(self):
    """Resets cache preferences to default values

    Restores the default cache location and then re-reads the default
    ``cache_active`` / ``cache_timeout`` values from ``conf``.
    """

    self.reset_cache_location()
    # NOTE(review): relies on ``conf`` exposing ``default_cache_active`` and
    # ``default_cache_timeout`` -- confirm these items exist.
    self.cache_active = conf.default_cache_active
    self.cache_timeout = conf.default_cache_timeout

@property
def cache_location(self):
Expand Down Expand Up @@ -336,12 +352,14 @@ def _request(self, method, url,
files=files, timeout=timeout, json=json)
if not cache:
with cache_conf.set_temp("cache_active", False):

response = query.request(self._session, stream=stream,
auth=auth, verify=verify,
allow_redirects=allow_redirects,
json=json)
else:
response = query.from_cache(self.cache_location, cache_conf.cache_timeout)

if not response:
response = query.request(self._session,
self.cache_location,
Expand Down Expand Up @@ -495,6 +513,7 @@ def __exit__(self, exc_type, exc_value, traceback):
return False



class QueryWithLogin(BaseQuery):
"""
This is the base class for all the query classes which are required to
Expand Down
1 change: 1 addition & 0 deletions astroquery/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from astroquery.query import QueryWithLogin
from astroquery import cache_conf


URL1 = "http://fakeurl.edu"
URL2 = "http://fakeurl.ac.uk"

Expand Down