1
1
"""Example of reading events from Dataset class."""
2
2
3
- from timer import timer
4
-
5
- import awkward
6
3
import sqlite3
7
4
import time
8
- import torch .multiprocessing
9
- import torch .utils .data
5
+ import torch
10
6
from torch_geometric .data .batch import Batch
11
7
from tqdm import tqdm
8
+ import pandas as pd
9
+ import os
12
10
13
11
from graphnet .constants import TEST_PARQUET_DATA , TEST_SQLITE_DATA
14
12
from graphnet .data .constants import FEATURES , TRUTH
@@ -47,8 +45,10 @@ def main(backend: str) -> None:
47
45
num_workers = 30
48
46
wait_time = 0.00 # sec.
49
47
50
- # Define graph representation
51
- graph_definition = KNNGraph (detector = IceCubeDeepCore ())
48
+ # Define data representation
49
+ data_representation = KNNGraph (
50
+ detector = IceCubeDeepCore (), input_feature_names = features
51
+ )
52
52
53
53
for table in [pulsemap , truth_table ]:
54
54
# Get column names from backend
@@ -57,9 +57,11 @@ def main(backend: str) -> None:
57
57
cursor = conn .execute (f"SELECT * FROM { table } LIMIT 1" )
58
58
names = list (map (lambda x : x [0 ], cursor .description ))
59
59
else :
60
- ak = awkward .from_parquet (path , lazy = True )
61
- names = ak [table ].fields
62
- del ak
60
+ df = pd .DataFrame (os .path .join (path , f"{ table } *.parquet" ))
61
+ names = df .columns .tolist ()
62
+ # ak = awkward.from_parquet(path, lazy=True)
63
+ # names = ak[table].fields
64
+ # del ak
63
65
64
66
# Print
65
67
logger .info (f"Available columns in { table } " )
@@ -73,7 +75,7 @@ def main(backend: str) -> None:
73
75
features = features ,
74
76
truth = truth ,
75
77
truth_table = truth_table ,
76
- graph_definition = graph_definition ,
78
+ data_representation = data_representation ,
77
79
)
78
80
assert isinstance (dataset , Dataset )
79
81
@@ -91,13 +93,11 @@ def main(backend: str) -> None:
91
93
shuffle = True ,
92
94
num_workers = num_workers ,
93
95
collate_fn = Batch .from_data_list ,
94
- # persistent_workers=True,
95
96
prefetch_factor = 2 ,
96
97
)
97
98
98
- with timer ("torch dataloader" ):
99
- for batch in tqdm (dataloader , unit = " batches" , colour = "green" ):
100
- time .sleep (wait_time )
99
+ for batch in tqdm (dataloader , unit = " batches" , colour = "green" ):
100
+ time .sleep (wait_time )
101
101
102
102
logger .info (str (batch ))
103
103
logger .info (batch .size ())
0 commit comments