-
-
Notifications
You must be signed in to change notification settings - Fork 91
Adopt the array api #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mwcraig
wants to merge
111
commits into
astropy:main
Choose a base branch
from
mwcraig:explore-array-api
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Adopt the array api #885
Changes from 83 commits
Commits
Show all changes
111 commits
Select commit
Hold shift + click to select a range
6dd9166
Initial attempt to adopt array API
mwcraig d3b6203
A couple more changes
mwcraig f0bc6b8
Add function for check if something is an array
mwcraig 424ebdc
Remove numpy from core
mwcraig 7f5050d
WIP rewrite of a boolean index
mwcraig 6a6bfa2
WIP testing updates
mwcraig 1d2d171
Change more almost_equal to allclose
mwcraig 22bc04d
Use the random number generator from numpy
mwcraig 245d641
Remove more in-place array operations
mwcraig 55b90be
Set a reasonable tolerance value in float comparisons
mwcraig 23de261
Continue to use numpy arrays in a few places
mwcraig dffadce
Refactor warning handling for numerical warnings
mwcraig 12ff8b0
Avoid explicit use of numpy masked arrays
mwcraig 63f81ba
Rewrite test to not modify array in-place
mwcraig 4ae3fc6
Use assert_allclose instead of older alternatives
mwcraig 7b2ed18
Initial attempt to adopt array API in combiner
mwcraig 1cbb7bc
Write a mask-aware sum function that is array API compatible
mwcraig da90350
Eliminate most use of numpy in tests
mwcraig bddb2de
Fix a couple calls to numpy masked array
mwcraig 8ac481a
One more workaround for immutable arrays
mwcraig 220a96e
Fix up dependencies and test environment setup
mwcraig 7378297
Update minimum python to 3.10
mwcraig 4485839
Fix mask access
mwcraig 30e9e55
Ignore warnings about negative values in square root
mwcraig a45d2cd
Update several minimum dependencies
mwcraig e4afebc
Fix linting errors
mwcraig 3724f0a
Drop unnecessary import
mwcraig c0928c8
Drop unneeded test
mwcraig c04f842
Skip memory tests if jax is installed
mwcraig 9fcc655
Explain why numpy is still used in image_collection
mwcraig 1fe1180
Drop numpy import in combiner
mwcraig 282eaee
Use array_api_extra to handle immutable arrays
mwcraig 5f27cd4
Use a consistent namespace for arrays
mwcraig 5f7f3e0
Clean up a couple more cases to use array_api_extra
mwcraig b98f8f5
Change where bottleneck is test on GitHub Actions
mwcraig 6ce2309
Skip coverage of one function
mwcraig b66abbc
Remove unused argument and logic
mwcraig e67ac8b
Add a test
mwcraig cdb9c45
Use tox environment to handle testing of different array libraries
mwcraig ef30230
Convert combiner tests to use Array API
mwcraig fda5ad4
Remove unnecessary copy argument
mwcraig 03d413a
Use the array_api_compt numpy namespace instead of numpy
mwcraig dfc98cf
Add test against dask and fix bugs uncovered by tests
mwcraig 191da4e
Add dask test to CI
mwcraig 8fc85a5
Fix some errors introduced when changing the tests for dask
mwcraig 73cc0b0
Allow cupy for testing
mwcraig 53fb79a
Suppress square root warning generated in some array libraries
mwcraig 4bef390
Apply suggestions from code review
mwcraig efeac76
Undo suggested edit
mwcraig 69a79d5
Shorten up a loop with a comprehension
mwcraig f7ea0a9
Add minimum pin for dependency
mwcraig d3a732b
Update black target versions
mwcraig 309bbf4
Store array namespace when Combiner is created
mwcraig bd23ac4
Add optional namespace argument to several functions
mwcraig fcac48f
Use array API in all cosmic ray tests
mwcraig d0bcdad
cast number to float to avoid multiple namespaces
mwcraig baae84f
Change internal data and mask to private properties
mwcraig 655e844
Add properties for accessing the data and mask to be used in combination
mwcraig 17c0dca
Apply suggestions from code review
mwcraig a4e6d4a
Choose performance over style
mwcraig 5843955
Changes to testing for cupy
mwcraig bbc1249
Make sure to use CCDData.data instead of CCDData in comparisons
mwcraig 6af1e1e
Add more robust handling of open file test
mwcraig 5ec5759
Add several workarounds for non-compliance of CCData with Array API
mwcraig 41e17e6
Minor changes
mwcraig 7664bc0
Add and use wrapper classes to ensure array API use
mwcraig fb6e7d6
Fix for immutable array types
mwcraig c4c1277
Ensure array namespace is used throughout core tests
mwcraig 048f079
Include more uncertainty types in tests
mwcraig 359126a
Use array API correctly in combiner tests
mwcraig bb42d5b
Add array namespace input to combine function
mwcraig ba83726
Add an array namespace conversion in one more place
mwcraig 698c6fa
Add argument for desired array namespace
mwcraig 7c7ce31
Specify namespace in a couple of tests where data is read from disk
mwcraig 9a5ccd8
Fix broken links in docstrings
mwcraig 4503ddf
Switch more tests away from np.testing or explain why not switching
mwcraig 9afb5a6
Minimal array API docs
mwcraig 5289ddc
Improve test coverage
mwcraig 802d9a4
Replace .array with .asarray
mwcraig 126d174
Apply suggestions from code review
mwcraig ad9b9b5
Remove reference to list that is not used
mwcraig fd9f03e
Point to documentation for list of supported libraries
mwcraig 749761f
Avoid more numpy arrays in tests
mwcraig f5bea45
Fix formatting issues
mwcraig bb8b81b
Apply suggestions from code review
mwcraig 53dc6bb
Apply suggestions from code review
mwcraig 12f5ab5
Do masking operation in-place if possible
mwcraig f657dfa
Avoid converting data to numpy array
mwcraig 6d8edb7
Change mask setting in tests to avoid a numpy conversion
mwcraig b17566b
Use CCDData compatibility wrapper in trim_image
mwcraig fdec811
Cast number to array namespace
mwcraig 7531eee
Remove another instance of conversion to a numpy array
mwcraig c196b93
More fixes for flat_correct
mwcraig f9ad753
One more flat fix
mwcraig a0f738a
Do not use pytest.approx because it forces a numpy conversion
mwcraig 65ef951
Add cast to array namespace
mwcraig 52d6d45
Add cast to array namespace in gain_correct
mwcraig 9f59fb5
Fix handling of gain value/unit
mwcraig c3ffbdf
Avoid mask setter in more places
mwcraig 9de6a92
Fix typo
mwcraig 558ee79
Wrap the CCDData object inside transform_image
mwcraig df71b4d
Wrap before cop because copy sets mask
mwcraig 4845ec2
Fix typo
mwcraig 406ded9
Ensure uncertainty in test uses array namespace
mwcraig b062491
Set mask properly in test
mwcraig 42bb716
Make error from data, not ccd object
mwcraig c204f4d
Fix several array API issues in lacosmic
mwcraig ebf9127
Ensure uncertainty is wrapped even if CCDData is already wrapped
mwcraig aedb3b2
Add numpy conversion for now
mwcraig 2a877a9
Fix error in test
mwcraig 40ffbae
Cast mask to array API object
mwcraig File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
# This file is a rough draft of the changes that will be needed | ||
# in astropy.nddata to adopt the array API. This does not cover all | ||
# of the changes that will be needed, but it is a start. | ||
|
||
import array_api_compat | ||
import numpy as np | ||
from astropy import units as u | ||
from astropy.nddata import ( | ||
CCDData, | ||
StdDevUncertainty, | ||
) | ||
from astropy.nddata.compat import NDDataArray | ||
from astropy.units import UnitsError | ||
|
||
|
||
class _NDDataArray(NDDataArray): | ||
@NDDataArray.mask.setter | ||
def mask(self, value): | ||
xp = array_api_compat.array_namespace(self.data) | ||
# Check that value is not either type of null mask. | ||
if (value is not None) and (value is not np.ma.nomask): | ||
mask = xp.asarray(value, dtype=bool) | ||
if mask.shape != self.data.shape: | ||
raise ValueError( | ||
f"dimensions of mask {mask.shape} and data " | ||
f"{self.data.shape} do not match" | ||
) | ||
else: | ||
self._mask = mask | ||
else: | ||
# internal representation should be one numpy understands | ||
self._mask = np.ma.nomask | ||
|
||
|
||
class _CCDDataWrapperForArrayAPI(CCDData): | ||
""" | ||
Thin wrapper around CCDData to allow arithmetic operations with | ||
arbitray array API backends. | ||
""" | ||
|
||
def _arithmetic_wrapper(self, operation, operand, result_unit, **kwargs): | ||
""" " | ||
Use NDDataArray for arthmetic because that does not force conversion | ||
to Quantity (and hence numpy array). If there are units on the operands | ||
then NDArithmeticMixin will convert to Quantity. | ||
""" | ||
# Take the units off to make sure the arithmetic operation | ||
# does not try to convert to Quantity. | ||
if hasattr(self, "unit"): | ||
self_unit = self.unit | ||
self._unit = None | ||
else: | ||
self_unit = None | ||
|
||
if hasattr(operand, "unit"): | ||
operand_unit = operand.unit | ||
operand._unit = None | ||
else: | ||
operand_unit = None | ||
|
||
# Also take the units off of the uncertainty | ||
if self_unit is not None and hasattr(self.uncertainty, "unit"): | ||
self.uncertainty._unit = None | ||
|
||
if ( | ||
operand_unit is not None | ||
and hasattr(operand, "uncertainty") | ||
and hasattr(operand.uncertainty, "unit") | ||
): | ||
operand.uncertainty._unit = None | ||
|
||
_result = _NDDataArray._prepare_then_do_arithmetic( | ||
operation, self, operand, **kwargs | ||
) | ||
if self_unit: | ||
self._unit = self_unit | ||
if operand_unit: | ||
operand._unit = operand_unit | ||
# Also take the units off of the uncertainty | ||
if hasattr(self, "uncertainty") and self.uncertainty is not None: | ||
self.uncertainty._unit = self_unit | ||
|
||
if hasattr(operand, "uncertainty") and operand.uncertainty is not None: | ||
operand.uncertainty._unit = operand_unit | ||
|
||
# We need to handle the mask separately if we want to return a | ||
# genuine CCDDatta object and CCDData does not understand the | ||
# array API. | ||
result_mask = None | ||
if _result.mask is not None: | ||
result_mask = _result._mask | ||
_result._mask = None | ||
result = CCDData(_result, unit=result_unit) | ||
result._mask = result_mask | ||
return result | ||
|
||
def subtract(self, operand, xp=None, **kwargs): | ||
""" | ||
Determine the right operation to use and figure out | ||
the units of the result. | ||
""" | ||
xp = xp or array_api_compat.array_namespace(self.data) | ||
if not self.unit.is_equivalent(operand.unit): | ||
raise UnitsError("Units must be equivalent for subtraction.") | ||
result_unit = self.unit | ||
handle_mask = kwargs.pop("handle_mask", xp.logical_or) | ||
return self._arithmetic_wrapper( | ||
xp.subtract, operand, result_unit, handle_mask=handle_mask, **kwargs | ||
) | ||
|
||
def add(self, operand, xp=None, **kwargs): | ||
""" | ||
Determine the right operation to use and figure out | ||
the units of the result. | ||
""" | ||
xp = xp or array_api_compat.array_namespace(self.data) | ||
if not self.unit.is_equivalent(operand.unit): | ||
raise UnitsError("Units must be equivalent for addition.") | ||
result_unit = self.unit | ||
handle_mask = kwargs.pop("handle_mask", xp.logical_or) | ||
return self._arithmetic_wrapper( | ||
xp.add, operand, result_unit, handle_mask=handle_mask, **kwargs | ||
) | ||
|
||
def multiply(self, operand, xp=None, **kwargs): | ||
""" | ||
Determine the right operation to use and figure out | ||
the units of the result. | ||
""" | ||
xp = xp or array_api_compat.array_namespace(self.data) | ||
# The "1 *" below is because quantities do arithmetic properly | ||
# but units do not necessarily. | ||
if not hasattr(operand, "unit"): | ||
operand_unit = 1 * u.dimensionless_unscaled | ||
else: | ||
operand_unit = operand.unit | ||
result_unit = (1 * self.unit) * (1 * operand_unit) | ||
handle_mask = kwargs.pop("handle_mask", xp.logical_or) | ||
return self._arithmetic_wrapper( | ||
xp.multiply, operand, result_unit, handle_mask=handle_mask, **kwargs | ||
) | ||
|
||
def divide(self, operand, xp=None, **kwargs): | ||
""" | ||
Determine the right operation to use and figure out | ||
the units of the result. | ||
""" | ||
xp = xp or array_api_compat.array_namespace(self.data) | ||
if not hasattr(operand, "unit"): | ||
operand_unit = 1 * u.dimensionless_unscaled | ||
else: | ||
operand_unit = operand.unit | ||
result_unit = (1 * self.unit) / (1 * operand_unit) | ||
handle_mask = kwargs.pop("handle_mask", xp.logical_or) | ||
return self._arithmetic_wrapper( | ||
xp.divide, operand, result_unit, handle_mask=handle_mask, **kwargs | ||
) | ||
|
||
@NDDataArray.mask.setter | ||
def mask(self, value): | ||
xp = array_api_compat.array_namespace(self.data) | ||
# Check that value is not either type of null mask. | ||
if (value is not None) and (value is not np.ma.nomask): | ||
mask = xp.asarray(value, dtype=bool) | ||
if mask.shape != self.data.shape: | ||
raise ValueError( | ||
f"dimensions of mask {mask.shape} and data " | ||
f"{self.data.shape} do not match" | ||
) | ||
else: | ||
self._mask = mask | ||
else: | ||
# internal representation should be one numpy understands | ||
self._mask = np.ma.nomask | ||
|
||
|
||
class _StdDevUncertaintyWrapper(StdDevUncertainty): | ||
""" | ||
Override propagate methods to make sure they use the array API. | ||
""" | ||
|
||
def _propagate_add(self, other_uncert, result_data, correlation): | ||
xp = array_api_compat.array_namespace(self.array, other_uncert.array) | ||
return super()._propagate_add_sub( | ||
other_uncert, | ||
result_data, | ||
correlation, | ||
subtract=False, | ||
to_variance=xp.square, | ||
from_variance=xp.sqrt, | ||
) | ||
|
||
def _propagate_subtract(self, other_uncert, result_data, correlation): | ||
xp = array_api_compat.array_namespace(self.array, other_uncert.array) | ||
return super()._propagate_add_sub( | ||
other_uncert, | ||
result_data, | ||
correlation, | ||
subtract=True, | ||
to_variance=xp.square, | ||
from_variance=xp.sqrt, | ||
) | ||
|
||
def _propagate_multiply(self, other_uncert, result_data, correlation): | ||
xp = array_api_compat.array_namespace(self.array, other_uncert.array) | ||
return super()._propagate_multiply_divide( | ||
other_uncert, | ||
result_data, | ||
correlation, | ||
divide=False, | ||
to_variance=xp.square, | ||
from_variance=xp.sqrt, | ||
) | ||
|
||
def _propagate_divide(self, other_uncert, result_data, correlation): | ||
xp = array_api_compat.array_namespace(self.array, other_uncert.array) | ||
return super()._propagate_multiply_divide( | ||
other_uncert, | ||
result_data, | ||
correlation, | ||
divide=True, | ||
to_variance=xp.square, | ||
from_variance=xp.sqrt, | ||
) | ||
|
||
|
||
def _wrap_ccddata_for_array_api(ccd): | ||
""" | ||
Wrap a CCDData object for use with array API backends. | ||
""" | ||
if isinstance(ccd, _CCDDataWrapperForArrayAPI): | ||
return ccd | ||
|
||
_ccd = _CCDDataWrapperForArrayAPI(ccd) | ||
if isinstance(_ccd.uncertainty, StdDevUncertainty): | ||
_ccd.uncertainty = _StdDevUncertaintyWrapper(_ccd.uncertainty) | ||
return _ccd | ||
|
||
|
||
def _unwrap_ccddata_for_array_api(ccd): | ||
""" | ||
Unwrap a CCDData object from array API backends to the original CCDData. | ||
""" | ||
|
||
if isinstance(ccd.uncertainty, _StdDevUncertaintyWrapper): | ||
ccd.uncertainty = StdDevUncertainty(ccd.uncertainty.array) | ||
|
||
if isinstance(ccd, CCDData): | ||
return ccd | ||
|
||
if not isinstance(ccd, _CCDDataWrapperForArrayAPI): | ||
raise TypeError( | ||
"Input must be a CCDData or _CCDDataWrapperForArrayAPI instance." | ||
) | ||
|
||
# Convert back to CCDData | ||
return CCDData(ccd) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.