
Add cnn model #813


Open · wants to merge 27 commits into main

Conversation

@sevmag (Collaborator) commented Jul 18, 2025

This is the big PR adding CNN support to GraphNeT, enabling direct comparisons (see #771).

The CNN support consists of:

  • ImageDefinition to represent data as an image
  • CNN architectures to train
  • Unit tests
  • Example script for CNN training

An ImageDefinition consists of two parts:

  1. A NodeDefinition that preprocesses the raw data and ensures that the pulses are aggregated at the optical modules (e.g. ClusterSummaryFeatures or PercentileClusters)
  2. A PixelMapping, which is responsible for creating the images and mapping the nodes to their correct locations in the image

There are two CNN architectures implemented (a minimal wiring sketch follows the list):

  1. LCSC, from Alexander Harnisch
  2. TheosMuonEUpgoing, the muon energy reconstruction architecture from Theo Glauch used in IceCube
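
For orientation, here is a minimal wiring sketch of how these pieces could fit together in a training script. It is a sketch only: the import path for IC86DNNImage and the assumption that an ImageDefinition can be passed to StandardModel wherever a GraphDefinition normally goes are mine and may differ from the final example script.

import numpy as np

from graphnet.models import StandardModel
from graphnet.models.cnn import TheosMuonEUpgoing  # added by this PR
from graphnet.models.graphs.nodes import PercentileClusters
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.loss_functions import LogCoshLoss

# Assumed import path for the ImageDefinition added by this PR.
from graphnet.models.graphs import IC86DNNImage

input_feature_names = ['string', 'dom_number', 'dom_time', 'charge']
node_def = PercentileClusters(
    input_feature_names=input_feature_names,
    cluster_on=['string', 'dom_number'],
    percentiles=np.linspace(0.2, 1.0, 5),
)

# ImageDefinition = NodeDefinition (pulse aggregation) + PixelMapping (image layout).
image_definition = IC86DNNImage(
    node_definition=node_def,
    input_feature_names=input_feature_names,
    include_lower_dc=True,
    include_upper_dc=True,
)

# nb_inputs must match the number of pixel features produced by the
# ImageDefinition; the defaults below are the ones from this PR.
backbone = TheosMuonEUpgoing(nb_inputs=15, nb_outputs=16)

task = EnergyReconstruction(
    hidden_size=backbone.nb_outputs,
    target_labels='energy',
    loss_function=LogCoshLoss(),
)
model = StandardModel(
    graph_definition=image_definition,
    backbone=backbone,
    tasks=[task],
)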

Timing of the ImageDefinition in Comparison to Other Data Representations

At a low number of pulses, the bottleneck of the ImageDefinition is the initialisation of zero tensors. (A sketch of the timing setup follows the module definitions below.)

Timed Modules

input_feature_names = ['string', 'dom_number', 'dom_time', 'charge']
node_def = PercentileClusters(
    input_feature_names=input_feature_names,
    cluster_on=['string', 'dom_number'],
    percentiles=np.linspace(0.2, 1.0, 5),
)
data_rep = {
    'edgeless': EdgelessGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
    ),
    'knn_graph_8NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=8,
    ),
    'knn_graph_16NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=16,
    ),
    'knn_graph_64NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=64,
    ),
    'ic86_dnn': IC86DNNImage(
        node_definition=node_def,
        input_feature_names=input_feature_names,
        include_lower_dc=True,
        include_upper_dc=True,
    ),
}
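
A rough sketch of how such a benchmark could be run over the data_rep dictionary above is shown here. The mock-pulse generator is purely illustrative, and the call signature assumes each data representation can be invoked directly with input_features and input_feature_names, as the existing GraphDefinition API allows.

import time

import numpy as np

def mock_pulses(n_pulses):
    """Generate illustrative mock pulses: string, dom_number, dom_time, charge."""
    return np.column_stack([
        np.random.randint(1, 87, n_pulses),    # string
        np.random.randint(1, 61, n_pulses),    # dom_number
        np.random.uniform(0, 1e4, n_pulses),   # dom_time
        np.random.exponential(1.0, n_pulses),  # charge
    ]).astype(np.float64)

for n_pulses in [10, 100, 1_000, 10_000, 100_000]:
    features = mock_pulses(n_pulses)
    for name, rep in data_rep.items():
        start = time.perf_counter()
        rep(input_features=features, input_feature_names=input_feature_names)
        elapsed = time.perf_counter() - start
        print(f'{name}: {n_pulses} pulses -> {elapsed * 1e3:.2f} ms')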

5000-200000 Mock Pulses (log scale): [timing plot screenshot]

1-5000 Mock Pulses: [timing plot screenshot]

1-500 Mock Pulses: [timing plot screenshot]

@sevmag requested a review from RasmusOrsoe on July 18, 2025, 15:45
num_conv_layers=5,
conv_filters=[5, 10, 20, 40, 60],
kernel_size=3,
image_size=(8, 9, 22), # dimensions of the example image
Collaborator:

It would be neat to add a property to the ImageDefinition that contains the resulting image dimensions, e.g. ImageDefinition.shape.
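
A hedged sketch of what such a property could look like; _pixel_mapping and image_shape are hypothetical attribute names, not the PR's actual internals:

from typing import Tuple

class ImageDefinition:  # simplified stand-in for the PR's ImageDefinition
    @property
    def shape(self) -> Tuple[int, ...]:
        """Spatial dimensions of the produced image, e.g. (8, 9, 22)."""
        # Hypothetical: delegate to whatever the PixelMapping already
        # knows about the target image dimensions.
        return self._pixel_mapping.image_shape

With something like this, the snippet above could pass image_size=image_definition.shape instead of hardcoding (8, 9, 22).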

"""CNN-specific modules, for performing the main learnable operations."""

from .cnn import CNN
from .theos_muonE_upgoing import TheosMuonEUpgoing
@RasmusOrsoe (Collaborator) commented Aug 11, 2025:

.theos_muonE_upgoing breaks with the snake-case convention. Do we need "theos" in there? It's very jargony. Credit can be given in the associated docstring instead of the module name.

"""Initialize the Lightning CNN signal classifier (LCSC).

Args:
num_input_features (int): Number of input features.
@RasmusOrsoe (Collaborator) commented Aug 11, 2025:

Great with the detailed argument descriptions, but they break the existing conventions used in the library. The types should not be repeated within the docstring itself, as our documentation build adds them automatically based on the type hints in the code.

I.e.
num_input_features (int): Number of input features.

should be

num_input_features: Number of input features.

You can see the docstring for DynEdge here and the resulting documentation here

"""
super().__init__(nb_inputs=num_input_features, nb_outputs=out_put_dim)

# Check input parameters
Collaborator:

There's quite a bit of parsing in the init here. Looks like you're doing two things: checking incompatible arguments (raising errors) and parsing the acceptable arguments for subsequent use in the layer building. You could instead move this logic into one or more private methods that are used in the init function - this will improve the readability greatly. For example:

def __init__(self, param1: type, param2: type):
    """Docstring."""

    # Check and parse input parameters
    filters, kernel_sizes, padding, .. = self._parse_conv_arguments(param1=param1, param2=param2)
    pooling_sizes, pooling_strides, .. = self._parse_pooling_arguments(param1=param1, param2=param2)

    # Set convolution layers
    self._set_conv_layers(filters=filters,
                          kernel_sizes=kernel_sizes,
                          ....,
                          pooling_sizes=pooling_sizes)

    # Set linear layers
    self.flatten = torch.nn.Flatten()
    self.fc1 = torch.nn.Linear(latent_dim, num_fc_neurons)
    self.fc2 = torch.nn.Linear(num_fc_neurons, out_put_dim)


def forward(self, data: Data) -> torch.Tensor:
    """Forward pass of the LCSC."""
    assert len(data.x) == 1, "Only Main Array image is supported for LCSC"
Collaborator:

This assertion checks that a single image is produced by the image representation (as opposed to multiple images), not that a specific image representation such as the main array is used.
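
As a concrete (hypothetical) rewording, the quoted assertion could say what is actually being checked:

assert len(data.x) == 1, (
    f'LCSC expects exactly one image per event, got {len(data.x)}; '
    'multi-image representations are not supported.'
)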

https://github.com/AlexHarn)

Intended to be used with the IceCube 86 image containing
only the Main Array image.
Collaborator:

Is it correctly understood that this method will work with any single-image representation, but that the method and default parameters were tested and selected based on IceCube simulation and a particular representation that utilizes the main array only? If so, I think adjusting this sentence would be wise.

@@ -502,6 +502,8 @@ class ClusterSummaryFeatures(NodeDefinition):
For more details on some of the features see
Theo Glauch's thesis (chapter 5.3):
https://mediatum.ub.tum.de/node?id=1584755

NOTE: number of pulses per cluster is not mentioned/used in the thesis
Collaborator:

How should this be understood? Do you mean that introducing this feature within the method is your own addition?

@@ -0,0 +1,411 @@
"""CNN used for muon energy reconstruction in IceCube.

Collaborator:

src/graphnet/models/cnn/theos_muonE_upgoing.py breaks with the snake-case convention. Do we strictly need "theos" in the module name? Proper credits can be given in the module docstring.



class Conv3dBN(LightningModule):
    """The Conv3dBN module from Theos CNN model."""
Collaborator:

Theos -> Theo Glauch

Collaborator:

Consider adding a bit more detail to inform the reader of what this module is. E.g.

"""Implementation of the Conv3dBN image convolution module from Theo Glauch."""



class InceptionBlock4(LightningModule):
    """The inception_block4 module from Theos CNN model."""
Collaborator:

Comments above apply here too.



class InceptionResnet(LightningModule):
    """The inception_resnet module from Theos CNN model."""
Collaborator:

Comments from above apply here, too.

return x + self._scale * tmp


class TheosMuonEUpgoing(CNN):
Collaborator:

I don't think this is the official name of the method, and to my knowledge, nothing within the method restricts it to upgoing events only. I would strongly suggest finding a more accessible name for the method. I believe it's more commonly known as the "DNN" within IceCube, no? You can use the docstring to provide further details on its origin. I.e. proper credits to Theo and his use of the method.

class TheosMuonEUpgoing(CNN):
    """The TheosMuonEUpgoing module."""

    def __init__(self, nb_inputs: int = 15, nb_outputs: int = 16) -> None:
Collaborator:

Is there a good reason why the hyperparameters of the method are hardcoded?

If not, let's please make them arguments, as that will greatly increase the reusability of the method.
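
For illustration only, a hedged sketch of what promoting them to arguments could look like; the extra parameter names below are hypothetical placeholders for whatever is currently hardcoded in the body:

from graphnet.models.cnn import CNN  # base class added by this PR

class TheosMuonEUpgoing(CNN):
    """Muon energy reconstruction CNN, as used by Theo Glauch in IceCube."""

    def __init__(
        self,
        nb_inputs: int = 15,
        nb_outputs: int = 16,
        # Hypothetical examples of values to promote from the body; the
        # defaults should reproduce the current hardcoded configuration.
        conv_filters: int = 32,
        n_inception_blocks: int = 3,
        dropout: float = 0.0,
    ) -> None:
        """Construct the model."""
        super().__init__(nb_inputs=nb_inputs, nb_outputs=nb_outputs)
        self._conv_filters = conv_filters
        self._n_inception_blocks = n_inception_blocks
        self._dropout = dropout
        # ... layer construction would then use these attributes
        # instead of hardcoded literals.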

Args:
dtype: data type used for node features. e.g. ´torch.float´
string_label: Name of the feature corresponding
to the DOM string number. Values Integers betweem 1 - 86
Collaborator:

betweem -> between

self._include_upper_dc = include_upper_dc

# read mapping from parquet file
df = pd.read_parquet(IC86_CNN_MAPPING)
Collaborator:

Would there be a way for us to compile the mapping at instantiation without relying on a file?

self._mapping = df
super().__init__(pixel_feature_names=pixel_feature_names)

def _set_indeces(
Collaborator:

_set_indices

self._sensor_number_label = sensor_number_label
self._pixel_feature_names = pixel_feature_names

self._set_indeces(
Collaborator:

_set_indices

self._dom_number_label = dom_number_label
self._pixel_feature_names = pixel_feature_names

self._set_indeces(pixel_feature_names, dom_number_label, string_label)
Collaborator:

_set_indices

self._mapping = df
super().__init__(pixel_feature_names=pixel_feature_names)

def _set_indeces(
Collaborator:

_set_indices

row[3], # mat_ax1
] = batch_row_features[i]

# unqueeze to add dimension for batching
Collaborator:

unqueeze -> unsqueeze

)

# data.x is expected to be a tensor with shape (N, F)
# where N is the number of nodes and F is the number of features.
@RasmusOrsoe (Collaborator) commented Aug 12, 2025:

Rows represent pixels, right?

match_indices = self._mapping.loc[
zip(*string_dom_number.t().tolist())
][
["string", "dom_number", "mat_ax0", "mat_ax1", "mat_ax2"]
Collaborator:

Looks like your method relies on a very specific set of column names in this file.
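
One way to make that dependence explicit is to validate the columns where the mapping is read (a minimal sketch; IC86_CNN_MAPPING is the constant already used above):

import pandas as pd

REQUIRED_MAPPING_COLUMNS = ('string', 'dom_number', 'mat_ax0', 'mat_ax1', 'mat_ax2')

df = pd.read_parquet(IC86_CNN_MAPPING)
missing = [col for col in REQUIRED_MAPPING_COLUMNS if col not in df.columns]
if missing:
    raise ValueError(
        f'CNN pixel mapping at {IC86_CNN_MAPPING} is missing required '
        f'columns: {missing}'
    )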
