From 1cbca10f9cd597deff96d92eeb816c9815f4247a Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 9 Apr 2025 14:14:52 +0900 Subject: [PATCH 1/5] sqlite_string_selection --- src/graphnet/data/dataset/dataset.py | 18 +++++++--- .../utilities/string_selection_resolver.py | 36 +++++++++++++------ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index db274d906..838ea40b7 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -233,6 +233,7 @@ def __init__( loss_weight_default_value: Optional[float] = None, seed: Optional[int] = None, labels: Optional[Dict[str, Any]] = None, + use_super_selection: bool = False, ): """Construct Dataset. @@ -280,6 +281,10 @@ def __init__( events ~ event_no % 5 > 0"`). graph_definition: Method that defines the graph representation. labels: Dictionary of labels to be added to the dataset. + use_super_selection: If True, the string selection is handled by + the query function of the dataset class, rather than + pd.DataFrame.query. Defaults to False and should + only be used with sqlite. """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -306,6 +311,7 @@ def __init__( self._graph_definition = deepcopy(graph_definition) self._labels = labels self._string_column = graph_definition._detector.string_index_name + self._use_super_selection = use_super_selection if node_truth is not None: assert isinstance(node_truth_table, str) @@ -356,6 +362,7 @@ def __init__( self, index_column=index_column, seed=seed, + use_super_selection=self._use_super_selection, ) if self._labels is not None: @@ -618,10 +625,13 @@ def _create_graph( """ # Convert truth to dict if len(truth.shape) == 1: - truth = truth.reshape(1, -1) - truth_dict = { - key: truth[:, index] for index, key in enumerate(self._truth) - } + truth_dict = { + key: truth[0][index] for index, key in enumerate(self._truth) + } + else: + truth_dict = { + key: truth[:, index] for index, key in enumerate(self._truth) + } # Define custom labels labels_dict = self._get_labels(truth_dict) diff --git a/src/graphnet/data/utilities/string_selection_resolver.py b/src/graphnet/data/utilities/string_selection_resolver.py index 8a1c61513..c19311bef 100644 --- a/src/graphnet/data/utilities/string_selection_resolver.py +++ b/src/graphnet/data/utilities/string_selection_resolver.py @@ -53,14 +53,17 @@ def __init__( index_column: str, seed: Optional[int] = None, use_cache: bool = True, + use_super_selection: bool = False, ): """Construct `StringSelectionResolver`.""" self._dataset = dataset self._index_column = index_column self._seed = seed self._use_cache = use_cache - + self._use_super_selection = use_super_selection # Base class constructor + if self._use_super_selection: + self._use_cache = False super().__init__(name=__name__, class_name=self.__class__.__name__) # Public method(s) @@ -214,19 +217,32 @@ def _query_selection_from_dataset(self, selection: str) -> pd.DataFrame: df_values = self._load_values_cache(values_cache_path) else: - df_values = pd.DataFrame( - data=self._dataset.query_table( - self._dataset.truth_table, - list(variables), - ), - columns=list(variables), - ) + if self._use_super_selection: + df_values = pd.DataFrame( + data=self._dataset.query_table( + self._dataset.truth_table, + list(variables), + selection=selection, + ).tolist(), + columns=list(variables), + ) + + else: + df_values = pd.DataFrame( + data=self._dataset.query_table( + self._dataset.truth_table, + list(variables), + ).tolist(), + columns=list(variables), + ) # (Opt.) Cache indices. if self._use_cache and not os.path.exists(values_cache_path): self._save_values_cache(df_values, values_cache_path) - - df_selection = df_values.query(selection) + if not self._use_super_selection: + df_selection = df_values.query(selection) + else: + df_selection = df_values return df_selection def _get_random_state(self, selection: str) -> Optional[int]: From db7e2368bd0c01bd2955f7b39a014b6313b2c5bb Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 9 Apr 2025 15:35:48 +0900 Subject: [PATCH 2/5] add_option_to_dataset_config.py --- src/graphnet/utilities/config/dataset_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 79e03c184..7afa75651 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -57,6 +57,7 @@ class DatasetConfig(BaseConfig): seed: Optional[int] = None graph_definition: Any = None labels: Optional[Dict[str, Any]] = None + use_super_selection: bool = False def __init__(self, **data: Any) -> None: """Construct `DataConfig`. From cf89c549469c1fdcda8ac15ee3819484dc95d993 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 9 Apr 2025 16:25:21 +0900 Subject: [PATCH 3/5] ensure cmake install --- .github/actions/install/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index 19e23be01..309e373a8 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -33,6 +33,7 @@ runs: run: | pip install --upgrade pip>=20 pip install wheel setuptools==59.5.0 + pip install cmake shell: bash - name: Install package run: | From b9eba42b36ade4cb500e4c00b826a8e720d338de Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 9 Apr 2025 16:28:10 +0900 Subject: [PATCH 4/5] fix formatting error --- .github/actions/install/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index 309e373a8..0750c03e9 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -33,7 +33,7 @@ runs: run: | pip install --upgrade pip>=20 pip install wheel setuptools==59.5.0 - pip install cmake + pip install cmake shell: bash - name: Install package run: | From 5858a592bceb16e7173d9e3a68485ccbe3ffa4cf Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 9 Apr 2025 16:32:04 +0900 Subject: [PATCH 5/5] revert --- .github/actions/install/action.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/actions/install/action.yml b/.github/actions/install/action.yml index 0750c03e9..19e23be01 100644 --- a/.github/actions/install/action.yml +++ b/.github/actions/install/action.yml @@ -33,7 +33,6 @@ runs: run: | pip install --upgrade pip>=20 pip install wheel setuptools==59.5.0 - pip install cmake shell: bash - name: Install package run: |