Skip to content

Commit 315c568

Browse files
committed
Add function to create interactive HTML plot for sensor edges
1 parent 89e7da4 commit 315c568

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

matsimAI/scripts/analysis.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,131 @@ def load_clusters(cluster_path, dataset):
127127
clusters[key] = vals
128128
return clusters
129129

130+
def create_html_plot(G_sensor, G_normal, dataset, df, hour_count):
131+
132+
from bokeh.plotting import figure, output_file, save
133+
from bokeh.plotting import figure, show
134+
from bokeh.models import ColumnDataSource, LinearColorMapper, ColorBar, HoverTool, CustomJS, Select
135+
from bokeh.layouts import row, column
136+
from bokeh.io import output_notebook
137+
# output_notebook()
138+
139+
# Positions
140+
pos = {i: (dataset.target_graph.pos[i][0].item(), dataset.target_graph.pos[i][1].item()) for i in range(len(dataset.target_graph.pos))}
141+
142+
# Build Bokeh Data Source
143+
edge_start = []
144+
edge_end = []
145+
x0s, y0s, x1s, y1s = [], [], [], []
146+
weight_all_hours = []
147+
other_attr_all_hours = {col: [] for col in df.columns if col != 'Link Id'}
148+
149+
print("Loop Sensor Edges:", G_sensor.edges(data=True))
150+
for (u, v, data) in tqdm(G_sensor.edges(data=True), desc="Processing Sensor Edges", total=len(G_sensor.edges)):
151+
edge_start.append(u)
152+
edge_end.append(v)
153+
x0, y0 = pos[u]
154+
x1, y1 = pos[v]
155+
x0s.append(x0)
156+
y0s.append(y0)
157+
x1s.append(x1)
158+
y1s.append(y1)
159+
160+
# For each attribute across hours
161+
weights = []
162+
attr_columns = {col: [] for col in df.columns if col != 'Link Id'}
163+
if not 'all_attrs' in data:
164+
continue
165+
for hour_attr in data['all_attrs']:
166+
if hour_attr: # if exists
167+
weights.append(hour_attr['Normalized Relative Error'])
168+
for col in attr_columns:
169+
attr_columns[col].append(hour_attr[col])
170+
else:
171+
weights.append(0)
172+
for col in attr_columns:
173+
attr_columns[col].append(0)
174+
175+
weight_all_hours.append(weights)
176+
for col in attr_columns:
177+
other_attr_all_hours[col].append(attr_columns[col])
178+
179+
# Create ColumnDataSource
180+
sensors_edge_source = ColumnDataSource(data=dict(
181+
x0=x0s, y0=y0s, x1=x1s, y1=y1s,
182+
weight=[w[0] for w in weight_all_hours], # Initially hour 0
183+
weight_all_hours=weight_all_hours,
184+
**{col: [other_attr_all_hours[col][i][0] for i in range(len(edge_start))] for col in other_attr_all_hours},
185+
**{f"{col}_all_hours": other_attr_all_hours[col] for col in other_attr_all_hours}
186+
))
187+
normal_edge_source = ColumnDataSource(data=dict(
188+
x0=[pos[u][0] for u, v in G_normal.edges()],
189+
y0=[pos[u][1] for u, v in G_normal.edges()],
190+
x1=[pos[v][0] for u, v in G_normal.edges()],
191+
y1=[pos[v][1] for u, v in G_normal.edges()],
192+
))
193+
194+
# === CREATE FIGURE ===
195+
196+
color_mapper = LinearColorMapper(palette="RdYlGn11", low=0, high=1)
197+
198+
# plot = figure(title="Normalized Relative Error", width=800, height=600, tools="pan,wheel_zoom,box_zoom,reset,save")
199+
plot = figure(title="Normalized Relative Error", width=800, height=600, tools="pan,wheel_zoom,box_zoom,reset,save")
200+
201+
# Draw edges for sensors
202+
plot.segment('x0', 'y0', 'x1', 'y1', source=sensors_edge_source,
203+
line_width=10, color={'field': 'weight', 'transform': color_mapper})
204+
# # Draw edges for normal
205+
plot.segment('x0', 'y0', 'x1', 'y1', source=sensors_edge_source,
206+
line_width=1, color="black")
207+
plot.segment('x0', 'y0', 'x1', 'y1', source=normal_edge_source,
208+
line_width=1, color="black")
209+
# Draw nodes
210+
node_x = [pos[i][0] for i in pos]
211+
node_y = [pos[i][1] for i in pos]
212+
plot.scatter(node_x, node_y, size=5, color="black", alpha=0.2)
213+
214+
# Add hover tool
215+
tooltips = [(col, f"@{{{col}}}") for col in df.columns if col != 'Link Id']
216+
tooltips.insert(0, ("Weight", "@weight"))
217+
218+
hover = HoverTool(tooltips=tooltips)
219+
plot.add_tools(hover)
220+
221+
# Colorbar
222+
color_bar = ColorBar(color_mapper=color_mapper, location=(0,0))
223+
plot.add_layout(color_bar, 'right')
224+
225+
# === CREATE DROPDOWN TO SELECT HOUR ===
226+
227+
fields = [col for col in df.columns if col != 'Link Id']
228+
229+
hour_selector = Select(title="Select Hour", value="1", options=[str(i+1) for i in range(hour_count)])
230+
231+
callback = CustomJS(args=dict(source=sensors_edge_source, hour_selector=hour_selector), code=f"""
232+
var data = source.data;
233+
var hour = parseInt(hour_selector.value-1);
234+
var n = data['weight_all_hours'].length;
235+
236+
for (var i = 0; i < n; i++) {{
237+
data['weight'][i] = data['weight_all_hours'][i][hour];
238+
{"".join([f"data['{field}'][i] = data['{field}_all_hours'][i][hour];" for field in fields])}
239+
}}
240+
source.change.emit();
241+
""")
242+
243+
hour_selector.js_on_change('value', callback)
244+
245+
# === LAYOUT AND SHOW ===
246+
247+
layout = column(row(plot, hour_selector))
248+
show(layout)
249+
250+
# Save the plot as an HTML file
251+
save(layout)
252+
253+
print("Plot saved as 'sensor_edges_graph.html'")
254+
130255
def build_abs_diff_graph(dataset, link_flows, sensor_idxs, target_flows, hour, title, save_path):
131256
hour_idx = hour
132257
hour_val = hour + 1

0 commit comments

Comments
 (0)