20
20
from bokeh .layouts import row , column
21
21
from bokeh .io import output_notebook
22
22
from bokeh .palettes import RdYlGn11
23
+ from bokeh .transform import linear_cmap
23
24
24
25
from matsimAI .flowsim_dataset import FlowSimDataset
25
26
@@ -189,20 +190,23 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
189
190
190
191
# Sensor edges (colored)
191
192
s_x0s , s_y0s , s_x1s , s_y1s = [], [], [], []
193
+ mid_x , mid_y = [], []
192
194
for (u , v , data ) in sensor_edges :
193
195
x0 , y0 = pos [u ]
194
196
x1 , y1 = pos [v ]
195
197
s_x0s .append (x0 )
196
198
s_y0s .append (y0 )
197
199
s_x1s .append (x1 )
198
200
s_y1s .append (y1 )
199
-
201
+ mid_x .append ((x0 + x1 ) / 2 )
202
+ mid_y .append ((y0 + y1 ) / 2 )
200
203
weight_all_hours .append (data ["Absolute Difference" ])
201
204
pred_all_hours .append (data ["Predicted Flow" ])
202
205
target_all_hours .append (data ["Target Flow" ])
203
206
204
207
sensors_edge_source = ColumnDataSource (data = dict (
205
208
x0 = s_x0s , y0 = s_y0s , x1 = s_x1s , y1 = s_y1s ,
209
+ mid_x = mid_x , mid_y = mid_y ,
206
210
weight = [w [0 ] for w in weight_all_hours ],
207
211
predicted = [p [0 ] for p in pred_all_hours ],
208
212
target = [t [0 ] for t in target_all_hours ],
@@ -231,7 +235,8 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
231
235
# Loop through each hour to create a color mapper per hour
232
236
for hour_idx in range (hour_count ):
233
237
# Color mapper for each hour
234
- color_mapper = LinearColorMapper (palette = RdYlGn11 , low = 0 , high = max_abs_diff_per_hour [hour_idx ])
238
+ # color_mapper = LinearColorMapper(palette=RdYlGn11, low=0, high=max_abs_diff_per_hour[hour_idx])
239
+ color_mapper = LinearColorMapper (palette = RdYlGn11 , low = 0 , high = max (max_abs_diff_per_hour ))
235
240
236
241
# Draw sensor edges with color
237
242
plot .segment ('x0' , 'y0' , 'x1' , 'y1' , source = sensors_edge_source ,
@@ -244,7 +249,7 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
244
249
# Draw nodes
245
250
node_x = [pos [i ][0 ] for i in pos ]
246
251
node_y = [pos [i ][1 ] for i in pos ]
247
- plot .scatter (node_x , node_y , size = .001 , color = "black" , alpha = 0.2 )
252
+ plot .scatter (node_x , node_y , size = .01 , color = "black" , alpha = 0.1 )
248
253
249
254
# Hover tool
250
255
tooltips = [("Abs Diff" , "@weight" ), ("Predicted" , "@predicted" ), ("Target" , "@target" )]
@@ -255,21 +260,90 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
255
260
color_bar = ColorBar (color_mapper = color_mapper , location = (0 , 0 ))
256
261
plot .add_layout (color_bar , 'right' )
257
262
263
+
264
+ # === Histogram Setup ===
265
+ bin_edges = np .linspace (0 , max (max_abs_diff_per_hour ), 21 )
266
+ bin_centers = [(bin_edges [i ] + bin_edges [i + 1 ])/ 2 for i in range (len (bin_edges )- 1 )]
267
+ hist_source = ColumnDataSource (data = dict (
268
+ top = [0 ]* 20 ,
269
+ bin_center = bin_centers
270
+ ))
271
+
272
+ hist_plot = figure (title = "Visible Edge Flow Difference Histogram" , width = 300 , height = 600 , y_range = (0 , max (max_abs_diff_per_hour )),
273
+ toolbar_location = None , tools = "" )
274
+ hist_plot .hbar (y = 'bin_center' , height = (bin_edges [1 ] - bin_edges [0 ]) * 0.8 , right = 'top' , left = 0 , source = hist_source ,
275
+ fill_color = linear_cmap ('bin_center' , RdYlGn11 , low = 0 , high = max (max_abs_diff_per_hour )), line_color = "white" )
276
+ hist_plot .xaxis .axis_label = "Count"
277
+ hist_plot .yaxis .axis_label = "Abs Diff"
278
+ hist_plot .ygrid .grid_line_color = None
279
+
258
280
# Dropdown for hour selection
259
281
hour_selector = Select (title = "Select Hour" , value = "0" , options = [str (i ) for i in range (hour_count )])
260
- callback = CustomJS (args = dict (source = sensors_edge_source , hour_selector = hour_selector ), code = """
282
+
283
+ # === JS Callback ===
284
+ callback = CustomJS (args = dict (
285
+ source = sensors_edge_source ,
286
+ hist_source = hist_source ,
287
+ x_range = plot .x_range ,
288
+ y_range = plot .y_range ,
289
+ hour_selector = hour_selector
290
+ ), code = """
261
291
const hour = parseInt(hour_selector.value);
262
- const data = source.data;
263
- for (let i = 0; i < data['weight_all_hours'].length; i++) {
264
- data['weight'][i] = data['weight_all_hours'][i][hour];
265
- data['predicted'][i] = data['predicted_all_hours'][i][hour];
266
- data['target'][i] = data['target_all_hours'][i][hour];
292
+ const x0 = x_range.start;
293
+ const x1 = x_range.end;
294
+ const y0 = y_range.start;
295
+ const y1 = y_range.end;
296
+
297
+ const mids_x = source.data['mid_x'];
298
+ const mids_y = source.data['mid_y'];
299
+ const weights_all = source.data['weight_all_hours'];
300
+ const preds_all = source.data['predicted_all_hours'];
301
+ const targets_all = source.data['target_all_hours'];
302
+
303
+ const weights = [];
304
+ for (let i = 0; i < mids_x.length; i++) {
305
+ if (mids_x[i] >= x0 && mids_x[i] <= x1 && mids_y[i] >= y0 && mids_y[i] <= y1) {
306
+ weights.push(weights_all[i][hour]);
307
+ }
308
+ source.data['weight'][i] = weights_all[i][hour];
309
+ source.data['predicted'][i] = preds_all[i][hour];
310
+ source.data['target'][i] = targets_all[i][hour];
311
+ }
312
+
313
+ const bin_count = 20;
314
+ const bin_min = 0;
315
+ const bin_max = Math.max(...weights_all.flat());
316
+ const bin_size = (bin_max - bin_min) / bin_count;
317
+
318
+ const hist = Array(bin_count).fill(0);
319
+ const bin_centers = [];
320
+ for (let i = 0; i < bin_count; i++) {
321
+ bin_centers.push(bin_min + bin_size * (i + 0.5));
322
+ }
323
+
324
+ for (let i = 0; i < weights.length; i++) {
325
+ const w = weights[i];
326
+ const bin = Math.floor((w - bin_min) / bin_size);
327
+ if (bin >= 0 && bin < bin_count) {
328
+ hist[bin]++;
329
+ }
267
330
}
331
+
332
+ hist_source.data['top'] = hist;
333
+ hist_source.data['bin_center'] = bin_centers;
334
+
268
335
source.change.emit();
336
+ hist_source.change.emit();
269
337
""" )
338
+
339
+ # Link all events to callback
340
+ plot .x_range .js_on_change ("start" , callback )
341
+ plot .x_range .js_on_change ("end" , callback )
342
+ plot .y_range .js_on_change ("start" , callback )
343
+ plot .y_range .js_on_change ("end" , callback )
270
344
hour_selector .js_on_change ('value' , callback )
271
345
272
- layout = column ( row (plot , hour_selector ))
346
+ layout = row (plot , hist_plot , column ( hour_selector ))
273
347
save (layout , filename = output_html_path , title = title )
274
348
275
349
def build_abs_diff_graph (dataset , link_flows , sensor_idxs , target_flows , hour , title , save_path ):
@@ -374,7 +448,9 @@ def main(args):
374
448
target_graph = dataset .target_graph
375
449
sensor_idxs = dataset .sensor_idxs
376
450
377
- flows = torch .load (results_path / "best_flows.pt" )
451
+ # flows = torch.load(results_path / "best_flows.pt")
452
+ device = torch .device ("mps" ) if torch .backends .mps .is_available () else torch .device ("cpu" )
453
+ flows = torch .load (results_path / "best_flows.pt" , map_location = device )
378
454
link_flows = flows ["LinkFlows" ].to ("cpu" )
379
455
target_flows = target_graph .edge_attr
380
456
0 commit comments