Add cnn model #813 (Open)

sevmag wants to merge 27 commits into graphnet-team:main from sevmag:add_cnn_model (base: main).
Changes from all commits (27 commits, all authored by sevmag):
44df8db  Adding Images
b0f7eff  fix init
f026e73  fix logic for detector in IC86 Image
908cb54  Fixing bugs TheoCNN
3b8ad84  more fixes for cnn
aaeaa9e  Fixing batching for images
7ed55a8  Adjusting imports
d9760c8  Fixing gitignore & mapping_table
816a6b3  fixing image num_nodes
9226db2  adding_counts to summary features
dcb9933  change mapping to faster version
8956614  Merge branch 'graphnet-team:main' into add_cnn_model
5a900bb  Faster Mapping & unit tests
b45dad7  Rename classes & more unit tests
00525af  Adding LCSC model
dc39d64  Changing annotations and docstrings
eaa7bfa  Merge branch 'graphnet-team:main' into add_cnn_model
20f7b4f  Adding example script
ee135ac  adding cnn example
adfc1c1  Merge branch 'graphnet-team:main' into add_cnn_model
ba4f012  Adjust docstring
cc42748  Fixing comments in example
7104df8  Add more to docstring in LCSC
7054144  adjust docstrings theos cnn
bbc9e53  docstring clean ups
a693a49  add info to docstring
8e94de5  add shape property
Five binary files changed (not shown).
New file, 343 lines added (@@ -0,0 +1,343 @@):

```python
"""Example of training a CNN Model."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import TRUTH
from graphnet.models import StandardModel
from graphnet.models.cnn import LCSC
from graphnet.models.data_representation import PercentileClusters
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.loss_functions import LogCoshLoss
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger
from graphnet.data.dataset import SQLiteDataset
from graphnet.data.dataset import ParquetDataset
from graphnet.models.detector import ORCA150
from torch_geometric.data import Batch
from graphnet.models.data_representation.images import ExamplePrometheusImage

# Constants
features = ["sensor_id", "sensor_string_id", "t"]
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
        "dataset_reference": (
            SQLiteDataset if path.endswith(".db") else ParquetDataset
        ),
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_cnn_model")
    run_name = "lcsc_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # First we need to define how the image is constructed.
    # This is done using an ImageDefinition.

    # An ImageDefinition combines two components:

    # 1. A pixel definition, which defines how the pixel data is
    #    represented. Since an image always has fixed dimensions, the
    #    pixel definition is also responsible for representing the data
    #    in a way that achieves these fixed dimensions. Typically, this
    #    means that light pulses arriving at the same optical module are
    #    aggregated into a fixed-dimensional vector.
    #    A pixel definition works exactly like a node definition in the
    #    graph scenario.

    # 2. A pixel mapping, which defines where each pixel is located
    #    in the final image. This is highly detector-specific, as it
    #    depends on the geometry of the detector.

    # An ImageDefinition can be used to create multiple images of
    # a single event. In the example of IceCube, you can e.g.
    # create three images: one for the so-called main array,
    # one for the upper DeepCore, and one for the lower DeepCore.
    # Essentially, these are just different areas in the detector.

    # Here we use the PercentileClusters pixel definition, which
    # aggregates the light pulses that arrive at the same optical
    # module with percentiles.
    print(features)
    pixel_definition = PercentileClusters(
        cluster_on=["sensor_id", "sensor_string_id"],
        percentiles=[10, 50, 90],
        add_counts=True,
        input_feature_names=features,
    )

    # The final image definition used here is the ExamplePrometheusImage,
    # which is a detector-specific pixel mapping. It maps optical modules
    # into the image using the sensor_string_id and sensor_id
    # (number of the optical module).
    # The detector class standardizes the input features,
    # so that the features are in an ML-friendly range.
    # For the mapping of the optical modules to the image it is
    # essential not to change the values of sensor_id and
    # sensor_string_id. Therefore we need to make sure that
    # these features are not standardized, which is done by the
    # `replace_with_identity` argument of the detector.
    image_definition = ExamplePrometheusImage(
        detector=ORCA150(
            replace_with_identity=[
                "sensor_id",
                "sensor_string_id",
            ],
        ),
        node_definition=pixel_definition,
        input_feature_names=features,
        string_label="sensor_string_id",
        dom_number_label="sensor_id",
    )

    # Use SQLiteDataset to load in data.
    # The input here depends on the dataset being used,
    # in this case the Prometheus dataset.
    dataset = SQLiteDataset(
        path=config["path"],
        pulsemaps=config["pulsemap"],
        truth_table=truth_table,
        features=features,
        truth=truth,
        data_representation=image_definition,
    )

    # Create the training and validation dataloaders.
    training_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    validation_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    # Building model

    # Define the architecture of the backbone; in this example
    # the LCSC architecture from Alexander Harnisch is used.
    backbone = LCSC(
        num_input_features=image_definition.nb_outputs,
        out_put_dim=2,
        input_norm=True,
        num_conv_layers=5,
        conv_filters=[5, 10, 20, 40, 60],
        kernel_size=3,
        image_size=(8, 9, 22),  # dimensions of the example image
        pooling_type=[
            "Avg",
            None,
            "Avg",
            None,
            "Avg",
        ],
        pooling_kernel_size=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        pooling_stride=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        num_fc_neurons=50,
        norm_list=True,
        norm_type="Batch",
    )
    # Define the task.
    # Here an energy reconstruction, with a LogCoshLoss function.
    # The target and prediction are transformed using the log10 function.
    # At inference time, the prediction is transformed back to the
    # original scale using 10^x.
    task = EnergyReconstruction(
        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )
    # Define the full model, which includes the backbone, task(s),
    # along with typical machine learning options such as
    # learning rate optimizers and schedulers.
    model = StandardModel(
        data_representation=image_definition,
        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    # Save results as .csv
    results.to_csv(f"{path}/cnn_results.csv")

    # Save model config and state dict - version-safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(f"{path}/cnn_state_dict.pth")
    model.save_config(f"{path}/cnn_model_config.yml")


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
            Train CNN model without the use of config files.
            """
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        ("num-workers", 2),
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
```
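The pooling configuration above determines how the `(8, 9, 22)` example image shrinks through the backbone. Assuming the convolutions preserve spatial size (e.g. "same" padding — an assumption about LCSC internals, not something this diff confirms), only the three pooling stages change the spatial dimensions, and the result can be checked with the standard pooling output formula:

```python
def pooled_shape(shape, kernel, stride):
    """Output size of a pooling layer: floor((d - k) / s) + 1 per axis."""
    return tuple((d - k) // s + 1 for d, k, s in zip(shape, kernel, stride))


# The three non-None pooling stages from the LCSC config above.
shape = (8, 9, 22)
for kernel, stride in [
    ([1, 1, 2], [1, 1, 2]),
    ([2, 2, 2], [2, 2, 2]),
    ([2, 2, 2], [2, 2, 2]),
]:
    shape = pooled_shape(shape, kernel, stride)

print(shape)  # (2, 2, 2) under the same-padding assumption
```

Under these assumptions the image is reduced from (8, 9, 22) to (8, 9, 11), then (4, 4, 5), then (2, 2, 2) before the fully connected layer.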
New file, 5 lines added (@@ -0,0 +1,5 @@):

```python
"""CNN-specific modules, for performing the main learnable operations."""

from .cnn import CNN
from .theos_muonE_upgoing import TheosMuonEUpgoing
from .lcsc import LCSC
```
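To make the pixel-definition idea from the example script concrete: a PercentileClusters-style summary reduces a variable number of pulses per optical module to one fixed-length vector (here, three time percentiles plus a hit count, mirroring `percentiles=[10, 50, 90]` and `add_counts=True`). The pulse data below is made up for illustration, and this sketch uses plain `torch.quantile` rather than the actual graphnet implementation:

```python
import torch

# Hypothetical pulses: arrival times and the sensor each pulse hit.
times = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0])
sensor_id = torch.tensor([0, 0, 0, 1, 1])
q = torch.tensor([0.10, 0.50, 0.90])  # percentiles, as in the example

rows = []
for sid in sensor_id.unique():
    t = times[sensor_id == sid]
    percentiles = torch.quantile(t, q)        # fixed-length time summary
    count = torch.tensor([float(t.numel())])  # analogue of add_counts=True
    rows.append(torch.cat([percentiles, count]))

summary = torch.stack(rows)  # one fixed-length row per optical module
```

However many pulses a module records, each module contributes exactly one row of four features, which is what makes a fixed-dimensional image possible.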
Review comment: Would be neat to add a property to the ImageDefinition that contains the resulting image dimension, e.g. `ImageDefinition.shape`.
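The suggested property could look roughly like the following. The class here is a minimal stand-in, since the real `ImageDefinition` internals are not shown in this diff (the commit "add shape property" suggests the suggestion was adopted in some form):

```python
from typing import Tuple


class ImageDefinitionSketch:
    """Minimal stand-in illustrating the suggested `shape` property."""

    def __init__(self, image_size: Tuple[int, int, int]) -> None:
        self._image_size = image_size

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Spatial dimensions of the image this definition produces."""
        return self._image_size


# With such a property, the example script could pass
# image_size=image_definition.shape to LCSC instead of
# hard-coding (8, 9, 22).
image = ImageDefinitionSketch((8, 9, 22))
```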