Skip to content

Commit 42f4d93

Browse files
committed
Added a new _get_prof_prod method, which can do the fuzzy matching of radii that will help us sort out issue #1401
1 parent 54e7cdd commit 42f4d93

File tree

1 file changed

+38
-26
lines changed

1 file changed

+38
-26
lines changed

xga/sources/base.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
2-
# Last modified by David J Turner (turne540@msu.edu) 16/07/2025, 13:37. Copyright (c) The Contributors
2+
# Last modified by David J Turner (turne540@msu.edu) 18/07/2025, 09:11. Copyright (c) The Contributors
33

44
import gc
55
import os
@@ -546,6 +546,7 @@ def __init__(self, ra: float, dec: float, redshift: float = None, name: str = No
546546
self._load_fits = load_fits
547547
self._load_products = load_products
548548
self._load_spectra = load_spectra
549+
self._load_profiles = load_profiles
549550

550551
# Firstly, we have all the properties
551552
@property
@@ -2275,33 +2276,44 @@ def _get_prof_prod(self, search_key: str, obs_id: str = None, inst: str = None,
22752276
were multiple matching products).
22762277
:rtype: Union[BaseProfile1D, List[BaseProfile1D]]
22772278
"""
2278-
if all([lo_en is None, hi_en is None]):
2279-
energy_key = "_"
2280-
elif all([lo_en is not None, hi_en is not None]):
2281-
energy_key = "bound_{l}-{h}_".format(l=lo_en.to('keV').value, h=hi_en.to('keV').value)
2282-
else:
2283-
raise ValueError("lo_en and hi_en must be either BOTH None or BOTH an Astropy quantity.")
2279+
# Fetch all the matching profiles for the specified telescope
2280+
matched_prods = self.get_products(search_key, obs_id, inst, just_obj=True, telescope=telescope)
22842281

2285-
if central_coord is None:
2286-
central_coord = self.default_coord
2287-
cen_chunk = "ra{r}_dec{d}_".format(r=central_coord[0].value, d=central_coord[1].value)
2282+
matched_prods: List[BaseProfile1D]
22882283

2284+
# Matching the radii is going to take a maybe slightly (but practically not really) dangerous approach. The
2285+
# radii passed here can either mean annulus bounds, or centres of those annuli. So we compare the passed
2286+
# radii to both of those pieces of information for each profile.
22892287
if radii is not None:
2288+
# Makes sure the radii are in degrees, as this is the base distance unit used in XGA
22902289
radii = self.convert_radius(radii, 'deg')
2291-
rad_chunk = "r" + "_".join(radii.value.astype(str))
2292-
rad_info = True
2293-
else:
2294-
rad_info = False
22952290

2296-
broad_prods = self.get_products(search_key, obs_id, inst, just_obj=False, telescope=telescope)
2297-
matched_prods = []
2298-
for p in broad_prods:
2299-
rad_str = p[-2].split("_st")[0].split(cen_chunk)[-1]
2300-
2301-
if cen_chunk in p[-2] and energy_key in p[-2] and rad_info and rad_str == rad_chunk:
2302-
matched_prods.append(p[-1])
2303-
elif cen_chunk in p[-2] and energy_key in p[-2] and not rad_info:
2304-
matched_prods.append(p[-1])
2291+
# First we'll check which profiles have the same number of radii as those that have
2292+
# been passed in by the user
2293+
matched_prods = [m_prod for m_prod in matched_prods if len(radii) == len(m_prod.radii) or
2294+
len(radii) == len(m_prod.annulus_bounds)]
2295+
# Then look for actual radii matches - we use the allclose() method here to check that the
2296+
# radii of the annuli are all within a very small tolerance of the passed radii. This is
2297+
# to head off problems we've had with float precision, the last digit of the float gets flipped
2298+
# and then exact comparisons no longer work
2299+
matched_prods = [m_prod for m_prod in matched_prods
2300+
if np.allclose(radii, m_prod.deg_radii, rtol=0, atol=RAD_MATCH_PRECISION)
2301+
or np.allclose(radii, self.convert_radius(m_prod.annulus_bounds, 'deg'), rtol=0,
2302+
atol=RAD_MATCH_PRECISION)]
2303+
2304+
# Now onto matching to some of the other information that may have been passed to this method
2305+
# First the energy bounds, making sure we convert the input energy to keV
2306+
if lo_en is not None:
2307+
lo_en = lo_en.to('keV')
2308+
matched_prods = [m_prod for m_prod in matched_prods if m_prod.energy_bounds[0] == lo_en]
2309+
if hi_en is not None:
2310+
hi_en = hi_en.to('keV')
2311+
matched_prods = [m_prod for m_prod in matched_prods if m_prod.energy_bounds[1] == hi_en]
2312+
2313+
# The central coordinate is also checked against the current default coordinate if the
2314+
# user didn't pass anything else in to override that
2315+
check_coord = self.default_coord if central_coord is None else central_coord
2316+
matched_prods = [m_prod for m_prod in matched_prods if (m_prod.centre == check_coord).all()]
23052317

23062318
return matched_prods
23072319

@@ -2927,8 +2939,8 @@ def update_products(self, prod_obj: Union[BaseProduct, BaseAggregateProduct, Bas
29272939
inven.drop_duplicates(subset=None, keep='first', inplace=True)
29282940
inven.to_csv(OUTPUT + "{t}/profiles/{n}/inventory.csv".format(t=tel, n=self.name), index=False)
29292941

2930-
def get_products(self, p_type: str, obs_id: str = None, inst: str = None,
2931-
extra_key: str = None, just_obj: bool = True, telescope: str = None) -> List[BaseProduct]:
2942+
def get_products(self, p_type: str, obs_id: str = None, inst: str = None, extra_key: str = None,
2943+
just_obj: bool = True, telescope: str = None) -> Union[List[BaseProduct], List[BaseProfile1D]]:
29322944
"""
29332945
This is the getter for the products data structure of Source objects. Passing a product type
29342946
such as 'events' or 'images' will return every matching entry in the products data structure.
@@ -2941,7 +2953,7 @@ def get_products(self, p_type: str, obs_id: str = None, inst: str = None,
29412953
or the other information that goes with it like ObsID and instrument.
29422954
:param str telescope: Optionally, a specific telescope to search can be supplied.
29432955
:return: List of matching products.
2944-
:rtype: List[BaseProduct]
2956+
:rtype: Union[List[BaseProduct], List[BaseProfile1D]]
29452957
"""
29462958
def unpack_list(to_unpack: list):
29472959
"""

0 commit comments

Comments
 (0)