Skip to content

Commit d6710aa

Browse files
Fix mask consistency in getvar functions across all data types
- Implement filtered_dataobject pattern to prevent dimension mismatches - Fix recursive getvar calls in getvar_hydro.jl (35+ functions affected) - Fix recursive getvar calls in getvar_particles.jl (20+ functions affected) - Fix recursive getvar calls in getvar_gravity.jl (10+ functions affected) - Maintains performance with O(masked_cells) complexity through early filtering
1 parent 44b920c commit d6710aa

File tree

3 files changed

+243
-223
lines changed

3 files changed

+243
-223
lines changed

src/functions/getvar/getvar_gravity.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@ function get_data(dataobject::GravDataType,
2121
# This gives true O(masked_cells) performance instead of O(total_cells)
2222
mask_indices = findall(mask)
2323
masked_data = dataobject.data[mask_indices]
24-
use_masked_data = false # No need to apply mask again since data is pre-filtered
24+
# Create a temporary dataobject with filtered data for recursive calls
25+
filtered_dataobject = deepcopy(dataobject)
26+
filtered_dataobject.data = masked_data
27+
use_mask_in_recursion = [false] # Don't apply mask in recursive calls since data is pre-filtered
2528
else
26-
use_masked_data = false
29+
filtered_dataobject = dataobject
2730
masked_data = dataobject.data
31+
use_mask_in_recursion = mask # Use original mask for recursive calls
2832
end
2933

3034

@@ -102,7 +106,7 @@ function get_data(dataobject::GravDataType,
102106
end
103107
elseif i == :volume
104108
selected_unit = getunit(dataobject, :volume, vars, units)
105-
vars_dict[:volume] = convert(Array{Float64,1}, getvar(dataobject, :cellsize, mask=mask) .^3 .* selected_unit)
109+
vars_dict[:volume] = convert(Array{Float64,1}, getvar(filtered_dataobject, :cellsize, mask=use_mask_in_recursion) .^3 .* selected_unit)
106110

107111

108112
elseif i == :x
@@ -124,7 +128,7 @@ function get_data(dataobject::GravDataType,
124128
if isamr
125129
vars_dict[:z] = (select(masked_data, cpos) .* boxlen ./ 2 .^select(masked_data, :level) .- boxlen * center[3] ) .* selected_unit
126130
else # if uniform grid
127-
vars_dict[:z] = (getvar(dataobject, cpos, mask=mask) .* boxlen ./ 2^lmax .- boxlen * center[3] ) .* selected_unit
131+
vars_dict[:z] = (getvar(filtered_dataobject, cpos, mask=use_mask_in_recursion) .* boxlen ./ 2^lmax .- boxlen * center[3] ) .* selected_unit
128132
end
129133

130134
# Gravitational acceleration magnitude - code units by default
@@ -165,8 +169,8 @@ function get_data(dataobject::GravDataType,
165169
# Cylindrical acceleration components - code units by default
166170
elseif i == :ar_cylinder
167171
selected_unit = getunit(dataobject, :ar_cylinder, vars, units)
168-
x = getvar(dataobject, :x, center=center, mask=mask)
169-
y = getvar(dataobject, :y, center=center, mask=mask)
172+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
173+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
170174
ax = select(masked_data, :ax)
171175
ay = select(masked_data, :ay)
172176

@@ -177,8 +181,8 @@ function get_data(dataobject::GravDataType,
177181

178182
elseif i == :aϕ_cylinder
179183
selected_unit = getunit(dataobject, :aϕ_cylinder, vars, units)
180-
x = getvar(dataobject, :x, center=center, mask=mask)
181-
y = getvar(dataobject, :y, center=center, mask=mask)
184+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
185+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
182186
ax = select(masked_data, :ax)
183187
ay = select(masked_data, :ay)
184188

@@ -190,9 +194,9 @@ function get_data(dataobject::GravDataType,
190194
# Spherical acceleration components - code units by default
191195
elseif i == :ar_sphere
192196
selected_unit = getunit(dataobject, :ar_sphere, vars, units)
193-
x = getvar(dataobject, :x, center=center, mask=mask)
194-
y = getvar(dataobject, :y, center=center, mask=mask)
195-
z = getvar(dataobject, :z, center=center, mask=mask)
197+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
198+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
199+
z = getvar(filtered_dataobject, :z, center=center, mask=use_mask_in_recursion)
196200
ax = select(masked_data, :ax)
197201
ay = select(masked_data, :ay)
198202
az = select(masked_data, :az)
@@ -204,9 +208,9 @@ function get_data(dataobject::GravDataType,
204208

205209
elseif i == :aθ_sphere
206210
selected_unit = getunit(dataobject, :aθ_sphere, vars, units)
207-
x = getvar(dataobject, :x, center=center, mask=mask)
208-
y = getvar(dataobject, :y, center=center, mask=mask)
209-
z = getvar(dataobject, :z, center=center, mask=mask)
211+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
212+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
213+
z = getvar(filtered_dataobject, :z, center=center, mask=use_mask_in_recursion)
210214
ax = select(masked_data, :ax)
211215
ay = select(masked_data, :ay)
212216
az = select(masked_data, :az)
@@ -221,8 +225,8 @@ function get_data(dataobject::GravDataType,
221225

222226
elseif i == :aϕ_sphere
223227
selected_unit = getunit(dataobject, :aϕ_sphere, vars, units)
224-
x = getvar(dataobject, :x, center=center, mask=mask)
225-
y = getvar(dataobject, :y, center=center, mask=mask)
228+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
229+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
226230
ax = select(masked_data, :ax)
227231
ay = select(masked_data, :ay)
228232

@@ -234,22 +238,22 @@ function get_data(dataobject::GravDataType,
234238
# Radial distances (for gravity analysis) - code units by default
235239
elseif i == :r_cylinder
236240
selected_unit = getunit(dataobject, :r_cylinder, vars, units)
237-
x = getvar(dataobject, :x, center=center, mask=mask)
238-
y = getvar(dataobject, :y, center=center, mask=mask)
241+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
242+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
239243
vars_dict[:r_cylinder] = @. sqrt(x^2 + y^2) * selected_unit
240244

241245
elseif i == :r_sphere
242246
selected_unit = getunit(dataobject, :r_sphere, vars, units)
243-
x = getvar(dataobject, :x, center=center, mask=mask)
244-
y = getvar(dataobject, :y, center=center, mask=mask)
245-
z = getvar(dataobject, :z, center=center, mask=mask)
247+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
248+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
249+
z = getvar(filtered_dataobject, :z, center=center, mask=use_mask_in_recursion)
246250
vars_dict[:r_sphere] = @. sqrt(x^2 + y^2 + z^2) * selected_unit
247251

248252
# Azimuthal angle - dimensionless/radians by default
249253
elseif i ==
250254
selected_unit = getunit(dataobject, , vars, units)
251-
x = getvar(dataobject, :x, center=center, mask=mask)
252-
y = getvar(dataobject, :y, center=center, mask=mask)
255+
x = getvar(filtered_dataobject, :x, center=center, mask=use_mask_in_recursion)
256+
y = getvar(filtered_dataobject, :y, center=center, mask=use_mask_in_recursion)
253257
vars_dict[] = @. atan(y, x) * selected_unit
254258

255259
# Fallback: if variable not found in gravity and hydro data is available, try hydro getvar

0 commit comments

Comments
 (0)