Skip to content
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extract the highest energy particle in the event."""

from typing import Dict, Any, TYPE_CHECKING, Tuple
from typing import Dict, Any, TYPE_CHECKING, Tuple, Union, List


from .i3extractor import I3Extractor
Expand Down Expand Up @@ -183,13 +183,15 @@ def get_tracks(

MMCTrackList = frame[self.mmctracklist]
if self.daughters:
MMCTrackList = [
track
for track in MMCTrackList
if frame[self.mctree].get_primary(track.GetI3Particle())
in primaries
]
MMCTrackList = simclasses.I3MMCTrackList(MMCTrackList)
temp_MMCTrackList = []
for track in MMCTrackList:
for p in primaries:
if frame[self.mctree].is_in_subtree(
p.id, track.GetI3Particle().id
):
temp_MMCTrackList.append(track)
break
MMCTrackList = simclasses.I3MMCTrackList(temp_MMCTrackList)

MuonGun_tracks = np.array(
MuonGun.Track.harvest(frame[self.mctree], MMCTrackList)
Expand Down Expand Up @@ -222,14 +224,16 @@ def get_pos_dir_length(
lengths[np.isnan(lengths)] = 0
return pos, direc, lengths

def get_bundle_HEP(self, particles: np.array) -> "dataclasses.I3Particle":
def get_bundle_HEP(
self, particles: np.array
) -> Tuple["dataclasses.I3Particle", np.ndarray, bool]:
"""Get the energy averaged particle of a list of particles.

Args:
particles: List of I3Particles
"""
if len(particles) == 0:
return dataclasses.I3Particle()
return dataclasses.I3Particle(), np.array([]), True

energies, lengths = np.array(
[[p.energy, p.length] for p in particles]
Expand Down Expand Up @@ -333,9 +337,19 @@ def highest_energy_track(
EonEntrance = MGtrack.get_energy(
max(intersections.first, 0)
)

# a skimming track can be outside the hull
# therefore it can have 0 visible length
visible_length = intersections.second - max(
intersections.first, 0
)

# It can happen that both intersections are negative
# in this case the particle never reaches the detector
# and therefore should not be considered for the HEP
if visible_length < 0:
continue

e_mask = energies > EonEntrance
energies = energies[e_mask]
MMCTrackList = MMCTrackList[e_mask]
Expand Down Expand Up @@ -392,26 +406,15 @@ def highest_energy_starting(
primaries = [
self.check_primary_energy(frame, p) for p in primaries
]
particles = dataclasses.ListI3Particle()

particles = frame[self.mctree]
particles = self.get_descendants(frame, primaries)

e_p = np.array(
[
np.array([p.energy, p])
for p in particles
if (
(
(p.energy > min_e)
& (
frame[self.mctree].get_primary(p.id)
in primaries
)
)
& (not p.is_track)
)
]
).T
e_p = []
for part in particles:
if (part.energy > min_e) & (~part.is_track):
e_p.append(np.array([part.energy, part]))

e_p = np.array(e_p).T

else:

Expand All @@ -434,9 +437,9 @@ def highest_energy_starting(
visible_length,
-1,
)
else:
energies = e_p[0]
particles = e_p[1]

energies = e_p[0]
particles = e_p[1]

pos, direc, lengths = self.get_pos_dir_length(particles)
pos = pos + direc * lengths
Expand Down Expand Up @@ -493,10 +496,17 @@ def highest_energy_starting(
intersections = self.hull.surface.intersection(
track.pos, track.dir
)
visible_length = max(
visible_length,
intersections.second - max(intersections.first, 0),

visible_length = intersections.second - max(
intersections.first, 0
)

# It can happen that both intersections are negative
# in this case the particle never reaches the detector
# and therefore should not be considered for the HEP
if visible_length < 0:
continue

# Check if we have a single topologically "real" track
if not real_track:
if not dataclasses.I3MCTree.parent(
Expand Down Expand Up @@ -724,6 +734,10 @@ def highest_energy_bundle(
GN_containment_types.throughgoing_bundle.value
)

assert (
visible_length >= 0
), f"Visible length is negative for particle {frame['I3EventHeader']}"

closest_pos = np.sum(closest_pos, axis=0) / EonEntrance

bundle.pos = dataclasses.I3Position(
Expand Down Expand Up @@ -770,3 +784,32 @@ def get_visible_produced_particles(
else:
visible_particles.append(daughter)
return visible_particles

def get_descendants(
self,
frame: "icetray.I3Frame",
particle: Union[
"dataclasses.I3Particle", List["dataclasses.I3Particle"]
],
) -> "dataclasses.ListI3Particle":
"""Get the descendants of a particle and the particle as a list.

Args:
frame: I3Frame object
particle: I3Particle object
"""
if isinstance(particle, list):
ret = []
for p in particle:
ret.extend(self.get_descendants(frame, p))
return ret
else:
daughters = frame[self.mctree].get_daughters(particle)
if len(daughters) == 0:
return [particle]
else:
ret = []
ret.append(particle)
for p in daughters:
ret.extend(self.get_descendants(frame, p))
return ret