Skip to content

Commit f668f97

Browse files
committed
Fix bug in NDCube.crop.
The bug was revealed by trying to crop a 1-D cube and was found to be caused by creating a SlicedLowLevelWCS object with a slice(None) slice item.
1 parent 58c8075 commit f668f97

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

ndcube/tests/test_ndcube_slice_and_crop.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,21 @@ def test_crop_rotated_celestial(ndcube_4d_ln_lt_l_t):
526526
small = cube.crop(bottom_left, bottom_right, top_left, top_right)
527527

528528
assert small.data.shape == (1652, 1652)
529+
530+
531+
def test_crop_1d():
532+
# This use case revealed a bug so has been added as a test.
533+
# Create NDCube.
534+
wcs = astropy.wcs.WCS(naxis=1)
535+
wcs.wcs.ctype = 'WAVE',
536+
wcs.wcs.cunit = 'nm',
537+
wcs.wcs.cdelt = 4,
538+
wcs.wcs.crpix = 1,
539+
wcs.wcs.crval = 3,
540+
cube = NDCube(np.arange(200), wcs=wcs)
541+
542+
expected = cube[1:4]
543+
544+
output = cube.crop((7*u.nm,), (15*u.nm,))
545+
546+
helpers.assert_cubes_equal(output, expected)

ndcube/utils/cube.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,17 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
141141
# where each inner list gives the index of all points for that array axis.
142142
combined_points_array_idx = [[]] * wcs.pixel_n_dim
143143
high_level_wcs = HighLevelWCSWrapper(wcs) if isinstance(wcs, BaseLowLevelWCS) else wcs
144-
wcs = high_level_wcs.low_level_wcs
144+
low_level_wcs = high_level_wcs.low_level_wcs
145145
# For each point compute the corresponding array indices.
146146
for point in points:
147147
# Get the arrays axes associated with each element in point.
148148
if crop_by_values:
149149
point_inputs_array_axes = []
150-
for i in range(wcs.world_n_dim):
150+
for i in range(low_level_wcs.world_n_dim):
151151
pix_axes = np.array(
152-
wcs_utils.world_axis_to_pixel_axes(i, wcs.axis_correlation_matrix))
152+
wcs_utils.world_axis_to_pixel_axes(i, low_level_wcs.axis_correlation_matrix))
153153
point_inputs_array_axes.append(tuple(
154-
wcs_utils.convert_between_array_and_pixel_axes(pix_axes, wcs.pixel_n_dim)))
154+
wcs_utils.convert_between_array_and_pixel_axes(pix_axes, low_level_wcs.pixel_n_dim)))
155155
point_inputs_array_axes = tuple(point_inputs_array_axes)
156156
else:
157157
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(high_level_wcs)
@@ -164,14 +164,16 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
164164
point_indices_with_inputs.append(i)
165165
array_axes_with_input.append(point_inputs_array_axes[i])
166166
array_axes_with_input = set(chain.from_iterable(array_axes_with_input))
167-
array_axes_without_input = set(range(wcs.pixel_n_dim)) - array_axes_with_input
167+
array_axes_without_input = set(range(low_level_wcs.pixel_n_dim)) - array_axes_with_input
168168
# Slice out the axes that do not correspond to a coord
169169
# from the WCS and the input point.
170-
wcs_slice = np.array([slice(None)] * wcs.pixel_n_dim)
171-
if len(array_axes_without_input):
170+
if len(array_axes_without_input) > 0:
171+
wcs_slice = np.array([slice(None)] * low_level_wcs.pixel_n_dim)
172172
wcs_slice[np.array(list(array_axes_without_input))] = 0
173-
sliced_wcs = SlicedLowLevelWCS(wcs, slices=tuple(wcs_slice))
174-
sliced_point = np.array(point, dtype=object)[np.array(point_indices_with_inputs)]
173+
sliced_wcs = SlicedLowLevelWCS(low_level_wcs, slices=tuple(wcs_slice))
174+
sliced_point = np.array(point, dtype=object)[np.array(point_indices_with_inputs)]
175+
else:
176+
sliced_wcs, sliced_point = low_level_wcs, np.array(point, dtype=object)
175177
# Derive the array indices of the input point and place each index
176178
# in the list corresponding to its axis.
177179
if crop_by_values:

0 commit comments

Comments
 (0)