Skip to content

Commit 92b150a

Browse files
authored
Merge pull request #770 from Aske-Rosted/clustering_utilities
Clustering utilities
2 parents c826a2d + 75b3260 commit 92b150a

File tree

2 files changed

+326
-11
lines changed

2 files changed

+326
-11
lines changed

src/graphnet/models/graphs/nodes/nodes.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from graphnet.utilities.decorators import final
1010
from graphnet.models import Model
1111
from graphnet.models.graphs.utils import (
12-
cluster_summarize_with_percentiles,
12+
cluster_and_pad,
1313
identify_indices,
1414
lex_sort,
1515
ice_transparency,
@@ -169,17 +169,14 @@ def _define_output_feature_names(
169169
cluster_idx,
170170
summ_idx,
171171
new_feature_names,
172-
) = self._get_indices_and_feature_names(
173-
input_feature_names, self._add_counts
174-
)
172+
) = self._get_indices_and_feature_names(input_feature_names)
175173
self._cluster_indices = cluster_idx
176174
self._summarization_indices = summ_idx
177175
return new_feature_names
178176

179177
def _get_indices_and_feature_names(
180178
self,
181179
feature_names: List[str],
182-
add_counts: bool,
183180
) -> Tuple[List[int], List[int], List[str]]:
184181
cluster_idx, summ_idx, summ_names = identify_indices(
185182
feature_names, self._cluster_on
@@ -188,7 +185,7 @@ def _get_indices_and_feature_names(
188185
for feature in summ_names:
189186
for pct in self._percentiles:
190187
new_feature_names.append(f"{feature}_pct{pct}")
191-
if add_counts:
188+
if self._add_counts:
192189
# add "counts" as the last feature
193190
new_feature_names.append("counts")
194191
return cluster_idx, summ_idx, new_feature_names
@@ -198,13 +195,16 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
198195
x = x.numpy()
199196
# Construct clusters with percentile-summarized features
200197
if hasattr(self, "_summarization_indices"):
201-
array = cluster_summarize_with_percentiles(
202-
x=x,
198+
cluster_class = cluster_and_pad(
199+
x=x, cluster_columns=self._cluster_indices
200+
)
201+
cluster_class.add_percentile_summary(
203202
summarization_indices=self._summarization_indices,
204-
cluster_indices=self._cluster_indices,
205203
percentiles=self._percentiles,
206-
add_counts=self._add_counts,
207204
)
205+
if self._add_counts:
206+
cluster_class.add_counts()
207+
array = cluster_class.clustered_x
208208
else:
209209
self.error(
210210
f"""{self.__class__.__name__} was not instatiated with

src/graphnet/models/graphs/utils.py

Lines changed: 316 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Utility functions for construction of graphs."""
22

3-
from typing import List, Tuple, Optional
3+
from typing import List, Tuple, Optional, Union
44
import os
55
import numpy as np
66
import pandas as pd
@@ -113,6 +113,8 @@ def identify_indices(
113113
return cluster_indices, summarization_indices, features_for_summarization
114114

115115

116+
# TODO Remove this function as it is superseded by
117+
# cluster_and_pad wich has the same functionality
116118
def cluster_summarize_with_percentiles(
117119
x: np.ndarray,
118120
summarization_indices: List[int],
@@ -172,6 +174,319 @@ def cluster_summarize_with_percentiles(
172174
return array
173175

174176

177+
class cluster_and_pad:
178+
"""Cluster and pad the data for further summarization.
179+
180+
Clusters the inptut data according to the specified columns
181+
and computes aggregate statistics on the clusters.
182+
The clustering will happen only ones creating a cluster matrix
183+
which will hold all the aggregated statistics and a padded matrix which
184+
will hold the padded data for quick calculation of aggregate statistics.
185+
186+
Example:
187+
cluster_and_pad(x = single_event_as_array,
188+
cluster_columns = [0,1,2])
189+
# Creates a cluster matrix and a padded matrix,
190+
# the cluster matrix will contain the unique values of the cluster columns,
191+
# no additional aggregate statistics are added yet.
192+
193+
cluster_class.add_percentile_summary(summarization_indices = [3,4,5],
194+
percentiles = [10,50,90])
195+
# Adds the 10th, 50th and 90th percentile of columns 3,4
196+
# and 5 in the input data to the cluster matrix.
197+
198+
cluster_class.add_std(column = 4)
199+
# Adds the standard deviation of column 4 in the input data
200+
# to the cluster matrix.
201+
x = cluster_class.clustered_x
202+
# Gets the clustered matrix with all the aggregate statistics.
203+
"""
204+
205+
def __init__(
206+
self,
207+
x: np.ndarray,
208+
cluster_columns: List[int],
209+
input_names: Optional[List[str]] = None,
210+
) -> None:
211+
"""Initialize the class with the data and cluster columns.
212+
213+
Args:
214+
x: Array to be clustered
215+
cluster_columns: List of column indices on which the clusters
216+
are constructed.
217+
input_names: Names of the columns in the input data for automatic
218+
generation of names.
219+
Adds:
220+
clustered_x: Added to the class
221+
_counts: Added to the class
222+
_padded_x: Added to the class
223+
"""
224+
x = lex_sort(x=x, cluster_columns=cluster_columns)
225+
226+
unique_sensors, self._counts = np.unique(
227+
x[:, cluster_columns], axis=0, return_counts=True
228+
)
229+
230+
contingency_table = np.concatenate(
231+
[unique_sensors, self._counts.reshape(-1, 1)], axis=1
232+
)
233+
234+
contingency_table = lex_sort(
235+
x=contingency_table, cluster_columns=cluster_columns
236+
)
237+
238+
self.clustered_x = contingency_table[:, 0 : unique_sensors.shape[1]]
239+
self._counts = (
240+
contingency_table[:, self.clustered_x.shape[1] :]
241+
.flatten()
242+
.astype(int)
243+
)
244+
245+
self._padded_x = np.empty(
246+
(len(self._counts), max(self._counts), x.shape[1])
247+
)
248+
self._padded_x.fill(np.nan)
249+
250+
for i in range(len(self._counts)):
251+
self._padded_x[i, : self._counts[i]] = x[: self._counts[i]]
252+
x = x[self._counts[i] :]
253+
254+
self._input_names = input_names
255+
if self._input_names is not None:
256+
assert (
257+
len(self._input_names) == x.shape[1]
258+
), "The input names must have the same length as the input data"
259+
260+
self._cluster_names = np.array(input_names)[cluster_columns]
261+
262+
def _add_column(
263+
self, column: np.ndarray, location: Optional[int] = None
264+
) -> None:
265+
"""Add a column to the clustered tensor.
266+
267+
Args:
268+
column: Column to be added to the tensor
269+
location: Location to insert the column in the clustered tensor.
270+
Altered:
271+
clustered_x: The column is added at the end of the tenor or
272+
inserted at the specified location
273+
"""
274+
if location is None:
275+
self.clustered_x = np.column_stack([self.clustered_x, column])
276+
else:
277+
self.clustered_x = np.insert(
278+
self.clustered_x, location, column, axis=1
279+
)
280+
281+
def _add_column_names(
282+
self, names: List[str], location: Optional[int] = None
283+
) -> None:
284+
"""Add names to the columns of the clustered tensor.
285+
286+
Args:
287+
names: Names to be added to the columns of the tensor
288+
location: Location to insert the names in the clustered tensor
289+
Altered:
290+
_cluster_names: The names are added at the end of the tensor
291+
or inserted at the specified location
292+
"""
293+
if location is None:
294+
self._cluster_names = np.append(self._cluster_names, names)
295+
else:
296+
self._cluster_names = np.insert(
297+
self._cluster_names, location, names
298+
)
299+
300+
def _calculate_charge_sum(self, charge_index: int) -> np.ndarray:
301+
"""Calculate the sum of the charge."""
302+
assert not hasattr(
303+
self, "_charge_sum"
304+
), "Charge sum has already been calculated, \
305+
re-calculation is not allowed"
306+
self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1)
307+
308+
def _calculate_charge_weights(self, charge_index: int) -> np.ndarray:
309+
"""Calculate the weights of the charge."""
310+
assert not hasattr(
311+
self, "_charge_weights"
312+
), "Charge weights have already been calculated, \
313+
re-calculation is not allowed"
314+
assert hasattr(
315+
self, "_charge_sum"
316+
), "Charge sum has not been calculated, \
317+
please run calculate_charge_sum"
318+
self._charge_weights = (
319+
self._padded_x[:, :, charge_index]
320+
/ self._charge_sum[:, np.newaxis]
321+
)
322+
323+
def add_charge_threshold_summary(
324+
self,
325+
summarization_indices: List[int],
326+
percentiles: List[int],
327+
charge_index: int,
328+
location: Optional[int] = None,
329+
) -> np.ndarray:
330+
"""Summarize features through percentiles on charge of sensor.
331+
332+
Args:
333+
summarization_indices: List of column indices that defines features
334+
that will be summarized with percentiles.
335+
percentiles: percentiles used to summarize `x`. E.g. [10,50,90].
336+
charge_index: index of the charge column in the padded tensor
337+
location: Location to insert the summarization indices in the
338+
clustered tensor defaults to adding at the end
339+
Adds:
340+
_charge_sum: Added to the class
341+
_charge_weights: Added to the class
342+
Altered:
343+
_padded_x: Charge is altered to be the cumulative sum
344+
of the charge divided by the total charge
345+
clustered_x: The summarization indices are added at the end
346+
of the tensor or inserted at the specified location.
347+
_cluster_names: The names are added at the end of the tensor
348+
or inserted at the specified location
349+
"""
350+
# convert the charge to the cumulative sum of the charge divided
351+
# by the total charge
352+
self._calculate_charge_sum(charge_index)
353+
self._calculate_charge_weights(charge_index)
354+
355+
self._padded_x[:, :, charge_index] = (
356+
self._padded_x[:, :, charge_index]
357+
/ self._charge_sum[:, np.newaxis]
358+
)
359+
360+
# Summarize the charge at different percentiles
361+
selections = np.argmax(
362+
self._padded_x[:, :, charge_index][:, :, np.newaxis]
363+
>= (np.array(percentiles) / 100),
364+
axis=1,
365+
)
366+
367+
selections += (np.arange(len(self._counts)) * self._padded_x.shape[1])[
368+
:, np.newaxis
369+
]
370+
371+
selections = self._padded_x[:, :, summarization_indices].reshape(
372+
-1, len(summarization_indices)
373+
)[selections]
374+
selections = selections.transpose(0, 2, 1).reshape(
375+
len(self.clustered_x), -1
376+
)
377+
self._add_column(selections, location)
378+
379+
# update the cluster names
380+
if self._input_names is not None:
381+
new_names = [
382+
self._input_names[i] + "_charge_threshold_" + str(p)
383+
for i in summarization_indices
384+
for p in percentiles
385+
]
386+
self._add_column_names(new_names, location)
387+
388+
def add_percentile_summary(
389+
self,
390+
summarization_indices: List[int],
391+
percentiles: List[int],
392+
method: str = "linear",
393+
location: Optional[int] = None,
394+
) -> np.ndarray:
395+
"""Summarize the features of the sensors using percentiles.
396+
397+
Args:
398+
summarization_indices: List of column indices that defines features
399+
that will be summarized with percentiles.
400+
percentiles: percentiles used to summarize `x`. E.g. [10,50,90].
401+
method: Method to summarize the features. E.g. "linear"
402+
location: Location to insert the summarization indices in the
403+
clustered tensor defaults to adding at the end
404+
Altered:
405+
clustered_x: The summarization indices are added at the end of
406+
the tensor or inserted at the specified location
407+
_cluster_names: The names are added at the end of the tensor
408+
or inserted at the specified location
409+
"""
410+
percentiles_x = np.nanpercentile(
411+
self._padded_x[:, :, summarization_indices],
412+
percentiles,
413+
axis=1,
414+
method=method,
415+
)
416+
417+
percentiles_x = percentiles_x.transpose(1, 2, 0).reshape(
418+
len(self.clustered_x), -1
419+
)
420+
self._add_column(percentiles_x, location)
421+
422+
# update the cluster names
423+
if self._input_names is not None:
424+
new_names = [
425+
self._input_names[i] + "_percentile_" + str(p)
426+
for i in summarization_indices
427+
for p in percentiles
428+
]
429+
self._add_column_names(new_names, location)
430+
431+
def add_counts(self, location: Optional[int] = None) -> np.ndarray:
432+
"""Add the counts of the sensor to the summarization features."""
433+
self._add_column(np.log10(self._counts), location)
434+
if self._input_names is not None:
435+
new_name = ["counts"]
436+
self._add_column_names(new_name, location)
437+
438+
def add_sum_charge(
439+
self, charge_index: int, location: Optional[int] = None
440+
) -> np.ndarray:
441+
"""Add the sum of the charge to the summarization features."""
442+
if not hasattr(self, "_charge_sum"):
443+
self._calculate_charge_sum(charge_index)
444+
self._add_column(self._charge_sum, location)
445+
# update the cluster names
446+
if self._input_names is not None:
447+
new_name = [self._input_names[charge_index] + "_sum"]
448+
self._add_column_names(new_name, location)
449+
450+
def add_std(
451+
self,
452+
columns: List[int],
453+
location: Optional[int] = None,
454+
weights: Union[np.ndarray, int] = 1,
455+
) -> np.ndarray:
456+
"""Add the standard deviation of the column.
457+
458+
Args:
459+
columns: Index of the columns from which to calculate the standard
460+
deviation.
461+
location: Location to insert the standard deviation in the
462+
clustered tensor defaults to adding at the end
463+
weights: Optional weights to be applied to the standard deviation
464+
"""
465+
self._add_column(
466+
np.nanstd(self._padded_x[:, :, columns] * weights, axis=1),
467+
location,
468+
)
469+
if self._input_names is not None:
470+
new_names = [self._input_names[i] + "_std" for i in columns]
471+
self._add_column_names(new_names, location)
472+
473+
def add_mean(
474+
self,
475+
columns: List[int],
476+
location: Optional[int] = None,
477+
weights: Union[np.ndarray, int] = 1,
478+
) -> np.ndarray:
479+
"""Add the mean of the column."""
480+
self._add_column(
481+
np.nanmean(self._padded_x[:, :, columns] * weights, axis=1),
482+
location,
483+
)
484+
# update the cluster names
485+
if self._input_names is not None:
486+
new_names = [self._input_names[i] + "_mean" for i in columns]
487+
self._add_column_names(new_names, location)
488+
489+
175490
def ice_transparency(
176491
z_offset: Optional[float] = None, z_scaling: Optional[float] = None
177492
) -> Tuple[interp1d, interp1d]:

0 commit comments

Comments
 (0)