diff --git a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py index e2f1a7af6..d491d938f 100644 --- a/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py +++ b/src/graphnet/data/extractors/icecube/i3highesteparticleextractor.py @@ -1,6 +1,6 @@ """Extract the highest energy particle in the event.""" -from typing import Dict, Any, TYPE_CHECKING, Tuple +from typing import Dict, Any, TYPE_CHECKING, Tuple, Union, List from .i3extractor import I3Extractor @@ -183,13 +183,15 @@ def get_tracks( MMCTrackList = frame[self.mmctracklist] if self.daughters: - MMCTrackList = [ - track - for track in MMCTrackList - if frame[self.mctree].get_primary(track.GetI3Particle()) - in primaries - ] - MMCTrackList = simclasses.I3MMCTrackList(MMCTrackList) + temp_MMCTrackList = [] + for track in MMCTrackList: + for p in primaries: + if frame[self.mctree].is_in_subtree( + p.id, track.GetI3Particle().id + ): + temp_MMCTrackList.append(track) + break + MMCTrackList = simclasses.I3MMCTrackList(temp_MMCTrackList) MuonGun_tracks = np.array( MuonGun.Track.harvest(frame[self.mctree], MMCTrackList) @@ -222,14 +224,16 @@ def get_pos_dir_length( lengths[np.isnan(lengths)] = 0 return pos, direc, lengths - def get_bundle_HEP(self, particles: np.array) -> "dataclasses.I3Particle": + def get_bundle_HEP( + self, particles: np.array + ) -> Tuple["dataclasses.I3Particle", np.ndarray, bool]: """Get the energy averaged particle of a list of particles. Args: particles: List of I3Particles """ if len(particles) == 0: - return dataclasses.I3Particle() + return dataclasses.I3Particle(), np.array([]), True energies, lengths = np.array( [[p.energy, p.length] for p in particles] @@ -319,7 +323,28 @@ def highest_energy_track( if not np.isnan(intersections.first) & ( intersections.first < length ): - if MGtrack.get_energy(intersections.first) > EonEntrance: + try: + tmp_EonEntrance = MGtrack.get_energy( + max(intersections.first, 0) + ) + # Catch MuonGun errors + except RuntimeError as e: + if ( + "sum of losses is smaller than " + "energy at last checkpoint" in str(e) + ): + hdr = frame["I3EventHeader"] + e.add_note( + f"Error in MuonGun track in event {hdr}" + ) + self.warning( + f"Detected corrupt track in {hdr}: {e}" + ) + continue + else: + raise # re-raise unexpected errors + + if tmp_EonEntrance > EonEntrance: particle = track_particle closest_pos = np.array( @@ -330,12 +355,18 @@ def highest_energy_track( ] ) - EonEntrance = MGtrack.get_energy( - max(intersections.first, 0) - ) + EonEntrance = tmp_EonEntrance + visible_length = intersections.second - max( intersections.first, 0 ) + + # It can happen that both intersections are negative + # in this case the particle never reaches the detector + # and therefore should not be considered for the HEP + if visible_length < 0: + continue + e_mask = energies > EonEntrance energies = energies[e_mask] MMCTrackList = MMCTrackList[e_mask] @@ -392,26 +423,15 @@ def highest_energy_starting( primaries = [ self.check_primary_energy(frame, p) for p in primaries ] - particles = dataclasses.ListI3Particle() - particles = frame[self.mctree] + particles = self.get_descendants(frame, primaries) - e_p = np.array( - [ - np.array([p.energy, p]) - for p in particles - if ( - ( - (p.energy > min_e) - & ( - frame[self.mctree].get_primary(p.id) - in primaries - ) - ) - & (not p.is_track) - ) - ] - ).T + e_p = [] + for part in particles: + if (part.energy > min_e) & (~part.is_track): + e_p.append(np.array([part.energy, part])) + + e_p = np.array(e_p).T else: @@ -434,9 +454,9 @@ def highest_energy_starting( visible_length, -1, ) - else: - energies = e_p[0] - particles = e_p[1] + + energies = e_p[0] + particles = e_p[1] pos, direc, lengths = self.get_pos_dir_length(particles) pos = pos + direc * lengths @@ -493,10 +513,17 @@ def highest_energy_starting( intersections = self.hull.surface.intersection( track.pos, track.dir ) - visible_length = max( - visible_length, - intersections.second - max(intersections.first, 0), + + visible_length = intersections.second - max( + intersections.first, 0 ) + + # It can happen that both intersections are negative + # in this case the particle never reaches the detector + # and therefore should not be considered for the HEP + if visible_length < 0: + continue + # Check if we have a single topologically "real" track if not real_track: if not dataclasses.I3MCTree.parent( @@ -697,7 +724,32 @@ def highest_energy_bundle( MGtrack.pos, MGtrack.dir ) - track_energy = MGtrack.get_energy(intersections.first) + # Particles that passed the sphere check but + # do not actually intersect the hull + if intersections.second < 0: + continue + + try: + track_energy = MGtrack.get_energy(intersections.first) + # Catch MuonGun errors + except RuntimeError as e: + if ( + "sum of losses is smaller than " + "energy at last checkpoint" in str(e) + ): + hdr = frame["I3EventHeader"] + e.add_note(f"Error in MuonGun track in event {hdr}") + self.warning(f"Detected corrupt track in {hdr}: {e}") + return ( + dataclasses.I3Particle(), + 0.0, + -1.0, + -1, + GN_containment_types.no_intersect.value, + ) + else: + raise # re-raise unexpected errors + EonEntrance += track_energy closest_pos.append( @@ -724,6 +776,23 @@ def highest_energy_bundle( GN_containment_types.throughgoing_bundle.value ) + # If no intersection.second is every positive + # the visible_length can still be negative here + # this means that all particles that passed the + # sphere check do not actually make it to the real hull + if visible_length < 0: + return ( + dataclasses.I3Particle(), + 0.0, + -1.0, + -1, + GN_containment_types.no_intersect.value, + ) + + assert ( + visible_length >= 0 + ), f"Visible length is negative for particle {frame['I3EventHeader']}" + closest_pos = np.sum(closest_pos, axis=0) / EonEntrance bundle.pos = dataclasses.I3Position( @@ -770,3 +839,32 @@ def get_visible_produced_particles( else: visible_particles.append(daughter) return visible_particles + + def get_descendants( + self, + frame: "icetray.I3Frame", + particle: Union[ + "dataclasses.I3Particle", List["dataclasses.I3Particle"] + ], + ) -> "dataclasses.ListI3Particle": + """Get the descendants of a particle and the particle as a list. + + Args: + frame: I3Frame object + particle: I3Particle object + """ + if isinstance(particle, list): + ret = [] + for p in particle: + ret.extend(self.get_descendants(frame, p)) + return ret + else: + daughters = frame[self.mctree].get_daughters(particle) + if len(daughters) == 0: + return [particle] + else: + ret = [] + ret.append(particle) + for p in daughters: + ret.extend(self.get_descendants(frame, p)) + return ret