
Commit ed74864

fix: some bugs
1 parent 2c430dc commit ed74864

6 files changed, +1639 -19 lines changed

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+zip_file_paths: # support loading multiple networks
+  - "G:/My Drive/Dataset/huy_v3/simgen_Anytown_20241118_1026.zip"
+  - "G:/My Drive/Dataset/huy_v3/simgen_epanet2_20241004_1246.zip"
+node_attrs:
+  - pressure
+  -
+  - reservoir_base_head
+  - junction_elevation
+  - tank_elevation
+edge_attrs: []
+label_attrs: []
+edge_label_attrs: []
+num_records: 10_000 # this number will be divided into 60% train / 20% val / 20% test
+selected_snapshots: null
+verbose: false # turn on for more debug info
+split_type: scene # two ways to split data: along the scenario or the temporal axis
+split_set: all # take a subset only. 4 options: train, val, test, and all
+skip_nodes_list: [] # put the names of nodes to skip here. By default, we ADD skip nodes w.r.t. the generation config.
+skip_types_list: [] # a faster way to ignore a component type (e.g., reservoir, tank).
+unstackable_pad_value: -1.0
+bypass_skip_names_in_config: false
+do_lazy: false
+overwatch: false
+batch_axis_choice: snapshot
+do_cache: false # set to true if your RAM can handle the whole array.
+subset_shuffle: true # if true, we shuffle the subset and STORE the shuffle ids. Otherwise, we sample with a dedicated step size w.r.t. num_records.
+split_per_network: true # Assume num_records is 10_000, so the training subset has 6_000 samples. With 2 networks, we sample 3_000 from each.
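The split_per_network comment above describes a two-stage split: num_records is first cut 60/20/20 into train/val/test, and each subset is then divided evenly across the loaded networks. A minimal Python sketch of that arithmetic (not part of this commit; the function name is illustrative only):

def per_network_subset_sizes(num_records: int, num_networks: int) -> dict:
    # 60/20/20 split expressed in integer percentages to avoid float rounding
    percents = {"train": 60, "val": 20, "test": 20}
    return {
        subset: (num_records * pct // 100) // num_networks  # e.g. 10_000 -> 6_000 train -> 3_000 per network
        for subset, pct in percents.items()
    }

print(per_network_subset_sizes(10_000, 2))  # {'train': 3000, 'val': 1000, 'test': 1000}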
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+wdn_names: # support loading multiple networks
+  - Anytown_7GB_1Y
+node_attrs:
+  - pressure
+  -
+  - reservoir_base_head
+  - junction_elevation
+  - tank_elevation
+edge_attrs: []
+label_attrs: []
+edge_label_attrs: []
+num_records: 100 # this number will be divided into 60% train / 20% val / 20% test
+selected_snapshots: null
+verbose: false # turn on for more debug info
+split_type: scene # two ways to split data: along the scenario or the temporal axis
+split_set: all # take a subset only. 4 options: train, val, test, and all
+skip_nodes_list: [] # put the names of nodes to skip here. By default, we ADD skip nodes w.r.t. the generation config.
+skip_types_list: [] # a faster way to ignore a component type (e.g., reservoir, tank).
+unstackable_pad_value: -1.0
+bypass_skip_names_in_config: false
+do_lazy: false
+overwatch: false
+batch_axis_choice: snapshot
+do_cache: false # set to true if your RAM can handle the whole array.
+subset_shuffle: false # if true, we shuffle the subset and STORE the shuffle ids. Otherwise, we sample with a dedicated step size w.r.t. num_records.
+split_per_network: true # Assume num_records is 10_000, so the training subset has 6_000 samples. With 2 networks, we sample 3_000 from each.
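Both files above are plain YAML, so they can be inspected with a generic loader before being handed to the data interface. A minimal sketch using standard pyyaml (the file path is hypothetical, and the repository may provide its own config loader):

import yaml

with open("gida_config.yaml") as f:  # hypothetical path to the second config above
    cfg = yaml.safe_load(f)

print(cfg["wdn_names"])          # ['Anytown_7GB_1Y']
print(cfg["num_records"])        # 100
print(cfg["batch_axis_choice"])  # 'snapshot'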

ditec_wdn_dataset/core/datasets_large.py

Lines changed: 36 additions & 18 deletions
@@ -2,7 +2,11 @@
 # Created on Thu May 16 2024
 # Copyright (c) 2024 Huy Truong
 # ------------------------------
-# Purpose: (Dec 12 2024) BACK UP OF DATASETS_LARGE.PY
+# Purpose: The data interface of DiTEC-WDN.
+# If you use Hugging Face (.parquet), this data interface is optional since HF provides a general one.
+# If you still want to try OUR interface, please try GiDaV7.
+# (Jun 05 2025) several bugs are fixed in this version.
+# (Dec 12 2024) BACK UP OF DATASETS_LARGE.PY
 # ------------------------------
 
 from collections import OrderedDict, defaultdict
@@ -442,15 +446,15 @@ def compute_indices(
         for network_index, root in enumerate(self._roots):
             if self.batch_axis_choice == "scene":
                 # arr WILL have shape <merged>(#scenes, #nodes_or_#links, #statics + time_dims * #dynamics)
-                num_samples = root.compute_first_size()
+                num_samples = root.compute_first_size() if self.num_records is None else min(self.num_records, root.compute_first_size())
                 relative_scene_ids = np.arange(num_samples)
                 tuples = (relative_scene_ids, None)
             elif self.batch_axis_choice == "temporal":
                 num_samples = root.time_dim
                 relative_time_ids = np.arange(num_samples)
                 tuples = (None, relative_time_ids)
             elif self.batch_axis_choice == "snapshot":
-                num_scenes = root.compute_first_size()
+                num_scenes = root.compute_first_size() if self.num_records is None else min(self.num_records, root.compute_first_size())
                 time_dim = root.time_dim
                 relative_scene_ids = np.arange(num_scenes).repeat(time_dim)  # .reshape([-1, 1])
                 relative_time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # .reshape([-1, 1])
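With the snapshot axis, the repeat/tile pairing above enumerates every (scene, time) combination exactly once, and the new min(...) guard caps the number of scenes at num_records. A standalone illustration with made-up sizes (not repository code):

import numpy as np

num_records = 3
num_scenes = min(num_records, 5)  # cap the 5 available scenes at num_records
time_dim = 4

relative_scene_ids = np.arange(num_scenes).repeat(time_dim)        # [0 0 0 0 1 1 1 1 2 2 2 2]
relative_time_ids = np.tile(np.arange(time_dim), reps=num_scenes)  # [0 1 2 3 0 1 2 3 0 1 2 3]
assert len(relative_scene_ids) == num_scenes * time_dim            # 12 snapshot samples in total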
@@ -1034,7 +1038,9 @@ def __getitem__(
 
     def __getitems__(self, idx: Union[int, np.integer, IndexType]) -> list[BaseData]:
         # return self.get(idx)  # type:ignore
-        batch: list[BaseData] = self.get(idx[0])  # type:ignore
+        # batch: list[BaseData] = self.get(idx[0])  # type:ignore
+        batch: list[BaseData] = self.get(idx)  # type:ignore
+
         batch = batch if self.transform is None else [self.transform(dat) for dat in batch]
         return batch
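The __getitems__ change matters because, when a dataset exposes __getitems__, a recent torch DataLoader passes the whole list of batch indices to it; calling self.get(idx[0]) therefore fetched only the first sample of every batch. A minimal sketch of the expected contract (toy class, not repository code, assuming a torch version whose DataLoader fetcher dispatches to __getitems__):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor([idx])

    def __getitems__(self, indices):       # receives e.g. [0, 1, 2, 3] for one batch
        return [self[i] for i in indices]  # one item per requested index, not just indices[0]

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=torch.stack)
print(next(iter(loader)).shape)  # torch.Size([4, 1])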

@@ -1070,7 +1076,7 @@ def gather_statistic(
             "edge_label": getattr(self._roots[0], "sorted_edge_label_attrs"),
         }
 
-        time_dim = self._roots[0].attrs["duration"] // self._roots[0].attrs["time_step"]
+        time_dim = self._roots[0].time_dim  # self._roots[0].attrs["duration"] // self._roots[0].attrs["time_step"]
         param_attrs = which_array_attrs_map[which_array]
         assert param_attrs is not None and len(param_attrs) > 0, f"ERROR! No found paramattrs from which_array=({which_array}): ({param_attrs})"
         channel_splitters = [
@@ -1103,14 +1109,6 @@ def gather_statistic(
         )
 
         for i in range(len(cat_arrays)):
-            # if do_group_norm:
-            #     # reshape arr from (scenes, #nodes_or_edges, t+1+t+...) -> (scenes * #nodes_or_edges, t+1+t+...)
-            #     arr = cat_arrays[i]
-            #     cat_arrays[i] = arr.reshape([-1, arr.shape[-1]], limit=self.chunk_limit)
-            # else:
-            #     # everything goes flatten?
-
-            #     cat_arrays.append(arr.reshape([-1, arr.shape[-1]], limit=self.chunk_limit))
             arr = cat_arrays[i]
             cat_arrays[i] = arr.reshape([-1, arr.shape[-1]], limit=self.chunk_limit)

@@ -1123,18 +1121,38 @@ def gather_statistic(
         for i in range(len(channel_splitters)):
             num_channels = channel_splitters[i]
             t = flatten_array[:, current_idx : current_idx + num_channels]
-            t = t.flatten()
+
+            # t = t.flatten()
             current_idx += num_channels
 
             t_std_val, t_mean_val = t.std(axis=norm_dim), t.mean(axis=norm_dim)
             # torch.std_mean(t, dim=norm_dim)
             t_min_val, t_max_val = t.min(axis=norm_dim), t.max(axis=norm_dim)
             # torch.min(t, dim=norm_dim).values, torch.max(t, dim=norm_dim).values
 
-            std_vals.append(t_std_val.reshape([-1]).repeat(num_channels))
-            mean_vals.append(t_mean_val.reshape([-1]).repeat(num_channels))
-            min_vals.append(t_min_val.reshape([-1]).repeat(num_channels))
-            max_vals.append(t_max_val.reshape([-1]).repeat(num_channels))
+            # std_vals.append(t_std_val.reshape([-1]).repeat(num_channels))
+            # mean_vals.append(t_mean_val.reshape([-1]).repeat(num_channels))
+            # min_vals.append(t_min_val.reshape([-1]).repeat(num_channels))
+            # max_vals.append(t_max_val.reshape([-1]).repeat(num_channels))
+
+            if norm_dim is not None:
+                t_std_val = np.expand_dims(t_std_val, axis=norm_dim)
+
+                t_mean_val = np.expand_dims(t_mean_val, axis=norm_dim)
+
+                t_min_val = np.expand_dims(t_min_val, axis=norm_dim)
+
+                t_max_val = np.expand_dims(t_max_val, axis=norm_dim)
+            else:
+                t_std_val = t_std_val.reshape([-1]).repeat(num_channels)
+                t_mean_val = t_mean_val.reshape([-1]).repeat(num_channels)
+                t_min_val = t_min_val.reshape([-1]).repeat(num_channels)
+                t_max_val = t_max_val.reshape([-1]).repeat(num_channels)
+
+            std_vals.append(t_std_val)
+            mean_vals.append(t_mean_val)
+            min_vals.append(t_min_val)
+            max_vals.append(t_max_val)
 
         std_val = dac.concatenate(std_vals, axis=channel_dim)
         mean_val = dac.concatenate(mean_vals, axis=channel_dim)
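The new norm_dim branch keeps the reduced axis (via np.expand_dims) instead of flattening and repeating, so the per-channel statistics can broadcast back against the array they were computed from. A small standalone illustration with made-up shapes (not repository code):

import numpy as np

t = np.random.rand(6, 3)  # e.g. 6 flattened samples x 3 channels
norm_dim = 0

mean = np.expand_dims(t.mean(axis=norm_dim), axis=norm_dim)  # shape (1, 3)
std = np.expand_dims(t.std(axis=norm_dim), axis=norm_dim)    # shape (1, 3)

normalized = (t - mean) / std  # broadcasts over the sample axis
print(normalized.shape)        # (6, 3)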

0 commit comments
