Skip to content

Commit 4bee06e

Browse files
authored
Merge pull request #795 from Aske-Rosted/converter_extractor_changes
Converter extractor changes
2 parents 2f81275 + 216cd02 commit 4bee06e

15 files changed

+82
-35
lines changed

src/graphnet/data/dataconverter.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,16 +215,17 @@ def _assign_event_no(
215215
data[k][extractor_name],
216216
index=[0] if n_rows == 1 else None,
217217
)
218-
if extractor_name in dataframe_dict.keys():
219-
dataframe_dict[extractor_name].append(df)
220-
else:
221-
dataframe_dict[extractor_name] = [df]
218+
if not df.empty:
219+
if extractor_name in dataframe_dict.keys():
220+
dataframe_dict[extractor_name].append(df)
221+
else:
222+
dataframe_dict[extractor_name] = [df]
222223

223224
# Merge each list of dataframes if wanted by writer
224225
if self._save_method.expects_merged_dataframes:
225226
for key in dataframe_dict.keys():
226227
dataframe_dict[key] = pd.concat(
227-
dataframe_dict[key], axis=0
228+
[df for df in dataframe_dict[key] if not df.empty], axis=0
228229
).reset_index(drop=True)
229230
return dataframe_dict
230231

@@ -275,10 +276,11 @@ def get_map_function(
275276
"""Identify map function to use (pure python or multiprocess)."""
276277
# Choose relevant map-function given the requested number of workers.
277278
n_workers = min(self._num_workers, nb_files)
279+
self._num_workers = n_workers
278280
if n_workers > 1:
279281
self.info(
280282
f"Starting pool of {n_workers} workers to process"
281-
" {nb_files} {unit}"
283+
f"{nb_files} {unit}"
282284
)
283285

284286
manager = Manager()
@@ -321,7 +323,10 @@ def _update_shared_variables(
321323

322324
@final
323325
def merge_files(
324-
self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
326+
self,
327+
files: Optional[Union[List[str], str]] = None,
328+
output_dir: Optional[str] = None,
329+
**kwargs: Any,
325330
) -> None:
326331
"""Merge converted files.
327332
@@ -330,6 +335,9 @@ def merge_files(
330335
331336
Args:
332337
files: Intermediate files to be merged.
338+
output_dir: Directory to save the merged files in.
339+
**kwargs: Additional keyword arguments to be passed to the
340+
`GraphNeTWriter.merge_files` method.
333341
"""
334342
if (files is None) & (len(self._output_files) > 0):
335343
# If no input files are given, but output files from conversion
@@ -349,9 +357,10 @@ def merge_files(
349357
"and you must therefore specify argument `files`."
350358
)
351359
assert files is not None
352-
360+
if output_dir is None:
361+
output_dir = self._output_dir
353362
# Merge files
354-
merge_path = os.path.join(self._output_dir, "merged")
363+
merge_path = os.path.join(output_dir, "merged")
355364
self.info(f"Merging files to {merge_path}")
356365
self._save_method.merge_files(
357366
files=files_to_merge, output_dir=merge_path, **kwargs

src/graphnet/data/extractors/combine_extractors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Module for combining multiple extractors into a single extractor."""
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Optional
44

55
from graphnet.utilities.imports import has_icecube_package
66
from graphnet.data.extractors.icecube.i3extractor import I3Extractor
@@ -31,6 +31,11 @@ def __init__(self, extractors: List[I3Extractor], extractor_name: str):
3131
super().__init__(extractor_name=extractor_name)
3232
self._extractors = extractors
3333

34+
def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
35+
"""Set the GCD file for all extractors."""
36+
for extractor in self._extractors:
37+
extractor.set_gcd(i3_file, gcd_file)
38+
3439
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
3540
"""Extract data from frame using all extractors.
3641

src/graphnet/data/extractors/extractor.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Base I3Extractor class(es)."""
22

3-
from typing import Any, Union
3+
from typing import Any, Union, Callable
44
from abc import ABC, abstractmethod
55
import pandas as pd
66

@@ -23,21 +23,42 @@ class Extractor(ABC, Logger):
2323
An extractor is used in conjunction with a specific `FileReader`.
2424
"""
2525

26-
def __init__(self, extractor_name: str):
26+
def __init__(self, extractor_name: str, exclude: list = [None]):
2727
"""Construct Extractor.
2828
2929
Args:
3030
extractor_name: Name of the `Extractor` instance.
3131
Used to keep track of the provenance of different
3232
data, and to name tables to which this data is
3333
saved. E.g. "mc_truth".
34+
exclude: List of keys to exclude from the extracted data.
3435
"""
3536
# Member variable(s)
3637
self._extractor_name: str = extractor_name
38+
self._exclude = exclude
3739

3840
# Base class constructor
3941
super().__init__(name=__name__, class_name=self.__class__.__name__)
4042

43+
def exclude(func: Callable) -> Callable:
44+
"""Exclude specified keys from the extracted data."""
45+
46+
def wrapper(
47+
self: "Extractor", *args: Any
48+
) -> Union[dict, pd.DataFrame]:
49+
result = func(self, *args)
50+
if isinstance(result, dict):
51+
for key in self._exclude:
52+
if key in result:
53+
del result[key]
54+
elif isinstance(result, pd.DataFrame):
55+
for key in self._exclude:
56+
if key in result.columns:
57+
result = result.drop(columns=[key])
58+
return result
59+
60+
return wrapper
61+
4162
@abstractmethod
4263
def __call__(self, data: Any) -> Union[dict, pd.DataFrame]:
4364
"""Extract information from data."""
@@ -47,3 +68,8 @@ def __call__(self, data: Any) -> Union[dict, pd.DataFrame]:
4768
def name(self) -> str:
4869
"""Get the name of the `Extractor` instance."""
4970
return self._extractor_name
71+
72+
def __init_subclass__(cls) -> None:
73+
"""Initialize subclass and apply the exclude decorator to __call__."""
74+
super().__init_subclass__()
75+
cls.__call__ = cls.exclude(cls.__call__) # type: ignore

src/graphnet/data/extractors/icecube/i3extractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ class I3Extractor(Extractor):
2020
method.
2121
"""
2222

23-
def __init__(self, extractor_name: str):
23+
def __init__(self, extractor_name: str, exclude: list = [None]):
2424
"""Construct I3Extractor.
2525
2626
Args:
2727
extractor_name: Name of the `I3Extractor` instance. Used to keep
2828
track of the provenance of different data, and to name tables
2929
to which this data is saved.
30+
exclude: List of features to exclude from the extractor.
3031
"""
3132
# Member variable(s)
3233
self._i3_file: str = ""
@@ -35,7 +36,7 @@ def __init__(self, extractor_name: str):
3536
self._calibration: Optional["icetray.I3Frame.Calibration"] = None
3637

3738
# Base class constructor
38-
super().__init__(extractor_name=extractor_name)
39+
super().__init__(extractor_name=extractor_name, exclude=exclude)
3940

4041
def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None:
4142
"""Extract GFrame and CFrame from i3/gcd-file pair.

src/graphnet/data/extractors/icecube/i3featureextractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
class I3FeatureExtractor(I3Extractor):
1515
"""Base class for extracting specific, reconstructed features."""
1616

17-
def __init__(self, pulsemap: str):
17+
def __init__(self, pulsemap: str, exclude: list = [None]):
1818
"""Construct I3FeatureExtractor.
1919
2020
Args:
2121
pulsemap: Name of the pulse (series) map for which to extract
2222
reconstructed features.
23+
exclude: List of keys to exclude from the extracted data.
2324
"""
2425
# Member variable(s)
2526
self._pulsemap = pulsemap
2627

2728
# Base class constructor
28-
super().__init__(pulsemap)
29+
super().__init__(pulsemap, exclude=exclude)
2930

3031

3132
class I3FeatureExtractorIceCube86(I3FeatureExtractor):

src/graphnet/data/extractors/icecube/i3genericextractor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
keys: Optional[Union[str, List[str]]] = None,
4747
exclude_keys: Optional[Union[str, List[str]]] = None,
4848
extractor_name: str = GENERIC_EXTRACTOR_NAME,
49+
exclude: list = [None],
4950
):
5051
"""Construct I3GenericExtractor.
5152
@@ -73,7 +74,7 @@ def __init__(
7374
self._exclude_keys: Optional[List[str]] = exclude_keys
7475

7576
# Base class constructor
76-
super().__init__(extractor_name)
77+
super().__init__(extractor_name, exclude=exclude)
7778

7879
def _get_keys(self, frame: "icetray.I3Frame") -> List[str]:
7980
"""Get the list of keys to be queried from `frame`.

src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
class I3GalacticPlaneHybridRecoExtractor(I3Extractor):
1212
"""Class for extracting galatictic plane hybrid reconstruction."""
1313

14-
def __init__(self, name: str = "dnn_hybrid"):
14+
def __init__(self, name: str = "dnn_hybrid", exclude: list = [None]):
1515
"""Construct I3GalacticPlaneHybridRecoExtractor."""
1616
# Base class constructor
17-
super().__init__(name)
17+
super().__init__(name, exclude=exclude)
1818

1919
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]:
2020
"""Extract TUMs DNN reconcstructions and associated variables."""

src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ def __init__(
1515
self,
1616
name: str = "northeren_tracks_muon_labels",
1717
padding_value: int = -1,
18+
exclude: list = [None],
1819
):
1920
"""Construct I3NTMuonLabelExtractor."""
2021
# Base class constructor
21-
super().__init__(name)
22+
super().__init__(name, exclude=exclude)
2223
self._padding_value = padding_value
2324

2425
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:

src/graphnet/data/extractors/icecube/i3particleextractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ class I3ParticleExtractor(I3Extractor):
1515
with GraphNeT.
1616
"""
1717

18-
def __init__(self, extractor_name: str):
18+
def __init__(self, extractor_name: str, exclude: list = [None]):
1919
"""Construct I3ParticleExtractor."""
2020
# Base class constructor
21-
super().__init__(extractor_name=extractor_name)
21+
super().__init__(extractor_name=extractor_name, exclude=exclude)
2222

2323
def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
2424
"""Extract I3Particle properties from I3Particle in frame."""

src/graphnet/data/extractors/icecube/i3pisaextractor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
class I3PISAExtractor(I3Extractor):
1212
"""Class for extracting quantities required by PISA."""
1313

14-
def __init__(self, name: str = "pisa_dependencies"):
14+
def __init__(
15+
self, name: str = "pisa_dependencies", exclude: list = [None]
16+
):
1517
"""Construct `I3PISAExtractor`."""
1618
# Base class constructor
17-
super().__init__(name)
19+
super().__init__(name, exclude=exclude)
1820

1921
def __call__(
2022
self, frame: "icetray.I3Frame", padding_value: float = -1.0

0 commit comments

Comments
 (0)