Skip to content

Commit ff03c67

Browse files
author
Isaac Peterson
committed
2 parents 2149b48 + 53607a2 commit ff03c67

File tree

1 file changed

+87
-11
lines changed

1 file changed

+87
-11
lines changed

matsimAI/scripts/analysis.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bokeh.layouts import row, column
2121
from bokeh.io import output_notebook
2222
from bokeh.palettes import RdYlGn11
23+
from bokeh.transform import linear_cmap
2324

2425
from matsimAI.flowsim_dataset import FlowSimDataset
2526

@@ -189,20 +190,23 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
189190

190191
# Sensor edges (colored)
191192
s_x0s, s_y0s, s_x1s, s_y1s = [], [], [], []
193+
mid_x, mid_y = [], []
192194
for (u, v, data) in sensor_edges:
193195
x0, y0 = pos[u]
194196
x1, y1 = pos[v]
195197
s_x0s.append(x0)
196198
s_y0s.append(y0)
197199
s_x1s.append(x1)
198200
s_y1s.append(y1)
199-
201+
mid_x.append((x0 + x1) / 2)
202+
mid_y.append((y0 + y1) / 2)
200203
weight_all_hours.append(data["Absolute Difference"])
201204
pred_all_hours.append(data["Predicted Flow"])
202205
target_all_hours.append(data["Target Flow"])
203206

204207
sensors_edge_source = ColumnDataSource(data=dict(
205208
x0=s_x0s, y0=s_y0s, x1=s_x1s, y1=s_y1s,
209+
mid_x=mid_x, mid_y=mid_y,
206210
weight=[w[0] for w in weight_all_hours],
207211
predicted=[p[0] for p in pred_all_hours],
208212
target=[t[0] for t in target_all_hours],
@@ -231,7 +235,8 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
231235
# Loop through each hour to create a color mapper per hour
232236
for hour_idx in range(hour_count):
233237
# Color mapper for each hour
234-
color_mapper = LinearColorMapper(palette=RdYlGn11, low=0, high=max_abs_diff_per_hour[hour_idx])
238+
# color_mapper = LinearColorMapper(palette=RdYlGn11, low=0, high=max_abs_diff_per_hour[hour_idx])
239+
color_mapper = LinearColorMapper(palette=RdYlGn11, low=0, high=max(max_abs_diff_per_hour))
235240

236241
# Draw sensor edges with color
237242
plot.segment('x0', 'y0', 'x1', 'y1', source=sensors_edge_source,
@@ -244,7 +249,7 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
244249
# Draw nodes
245250
node_x = [pos[i][0] for i in pos]
246251
node_y = [pos[i][1] for i in pos]
247-
plot.scatter(node_x, node_y, size=.001, color="black", alpha=0.2)
252+
plot.scatter(node_x, node_y, size=.01, color="black", alpha=0.1)
248253

249254
# Hover tool
250255
tooltips = [("Abs Diff", "@weight"), ("Predicted", "@predicted"), ("Target", "@target")]
@@ -255,21 +260,90 @@ def create_html_plot(dataset, link_flows, hour_count, title, output_html_path="s
255260
color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
256261
plot.add_layout(color_bar, 'right')
257262

263+
264+
# === Histogram Setup ===
265+
bin_edges = np.linspace(0, max(max_abs_diff_per_hour), 21)
266+
bin_centers = [(bin_edges[i] + bin_edges[i+1])/2 for i in range(len(bin_edges)-1)]
267+
hist_source = ColumnDataSource(data=dict(
268+
top=[0]*20,
269+
bin_center=bin_centers
270+
))
271+
272+
hist_plot = figure(title="Visible Edge Flow Difference Histogram", width=300, height=600, y_range=(0, max(max_abs_diff_per_hour)),
273+
toolbar_location=None, tools="")
274+
hist_plot.hbar(y='bin_center', height=(bin_edges[1] - bin_edges[0]) * 0.8, right='top', left=0, source=hist_source,
275+
fill_color=linear_cmap('bin_center', RdYlGn11, low=0, high=max(max_abs_diff_per_hour)), line_color="white")
276+
hist_plot.xaxis.axis_label = "Count"
277+
hist_plot.yaxis.axis_label = "Abs Diff"
278+
hist_plot.ygrid.grid_line_color = None
279+
258280
# Dropdown for hour selection
259281
hour_selector = Select(title="Select Hour", value="0", options=[str(i) for i in range(hour_count)])
260-
callback = CustomJS(args=dict(source=sensors_edge_source, hour_selector=hour_selector), code="""
282+
283+
# === JS Callback ===
284+
callback = CustomJS(args=dict(
285+
source=sensors_edge_source,
286+
hist_source=hist_source,
287+
x_range=plot.x_range,
288+
y_range=plot.y_range,
289+
hour_selector=hour_selector
290+
), code="""
261291
const hour = parseInt(hour_selector.value);
262-
const data = source.data;
263-
for (let i = 0; i < data['weight_all_hours'].length; i++) {
264-
data['weight'][i] = data['weight_all_hours'][i][hour];
265-
data['predicted'][i] = data['predicted_all_hours'][i][hour];
266-
data['target'][i] = data['target_all_hours'][i][hour];
292+
const x0 = x_range.start;
293+
const x1 = x_range.end;
294+
const y0 = y_range.start;
295+
const y1 = y_range.end;
296+
297+
const mids_x = source.data['mid_x'];
298+
const mids_y = source.data['mid_y'];
299+
const weights_all = source.data['weight_all_hours'];
300+
const preds_all = source.data['predicted_all_hours'];
301+
const targets_all = source.data['target_all_hours'];
302+
303+
const weights = [];
304+
for (let i = 0; i < mids_x.length; i++) {
305+
if (mids_x[i] >= x0 && mids_x[i] <= x1 && mids_y[i] >= y0 && mids_y[i] <= y1) {
306+
weights.push(weights_all[i][hour]);
307+
}
308+
source.data['weight'][i] = weights_all[i][hour];
309+
source.data['predicted'][i] = preds_all[i][hour];
310+
source.data['target'][i] = targets_all[i][hour];
311+
}
312+
313+
const bin_count = 20;
314+
const bin_min = 0;
315+
const bin_max = Math.max(...weights_all.flat());
316+
const bin_size = (bin_max - bin_min) / bin_count;
317+
318+
const hist = Array(bin_count).fill(0);
319+
const bin_centers = [];
320+
for (let i = 0; i < bin_count; i++) {
321+
bin_centers.push(bin_min + bin_size * (i + 0.5));
322+
}
323+
324+
for (let i = 0; i < weights.length; i++) {
325+
const w = weights[i];
326+
const bin = Math.floor((w - bin_min) / bin_size);
327+
if (bin >= 0 && bin < bin_count) {
328+
hist[bin]++;
329+
}
267330
}
331+
332+
hist_source.data['top'] = hist;
333+
hist_source.data['bin_center'] = bin_centers;
334+
268335
source.change.emit();
336+
hist_source.change.emit();
269337
""")
338+
339+
# Link all events to callback
340+
plot.x_range.js_on_change("start", callback)
341+
plot.x_range.js_on_change("end", callback)
342+
plot.y_range.js_on_change("start", callback)
343+
plot.y_range.js_on_change("end", callback)
270344
hour_selector.js_on_change('value', callback)
271345

272-
layout = column(row(plot, hour_selector))
346+
layout = row(plot, hist_plot, column(hour_selector))
273347
save(layout, filename=output_html_path, title=title)
274348

275349
def build_abs_diff_graph(dataset, link_flows, sensor_idxs, target_flows, hour, title, save_path):
@@ -374,7 +448,9 @@ def main(args):
374448
target_graph = dataset.target_graph
375449
sensor_idxs = dataset.sensor_idxs
376450

377-
flows = torch.load(results_path / "best_flows.pt")
451+
# flows = torch.load(results_path / "best_flows.pt")
452+
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
453+
flows = torch.load(results_path / "best_flows.pt", map_location=device)
378454
link_flows = flows["LinkFlows"].to("cpu")
379455
target_flows = target_graph.edge_attr
380456

0 commit comments

Comments
 (0)