Skip to content

Commit 3a554be

Browse files
committed
dev: enable plotting 3d obj space with pf
1 parent 9d592fc commit 3a554be

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/evox/vis_tools/plot.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import jax.numpy as jnp
22

3-
from evox import use_state
4-
53

64
def plot_dec_space(
75
population_history,
@@ -64,7 +62,6 @@ def plot_dec_space(
6462
sliders = [
6563
{
6664
"currentvalue": {"prefix": "Generation: "},
67-
"pad": {"t": 50},
6865
"pad": {"b": 1, "t": 10},
6966
"len": 0.8,
7067
"x": 0.2,
@@ -82,7 +79,6 @@ def plot_dec_space(
8279
"x": 1,
8380
"y": 1,
8481
"xanchor": "auto",
85-
"xanchor": "auto",
8682
},
8783
margin={"l": 0, "r": 0, "t": 0, "b": 0},
8884
sliders=sliders,
@@ -115,7 +111,6 @@ def plot_dec_space(
115111
"frame": {"duration": 0, "redraw": False},
116112
"mode": "immediate",
117113
"transition": {"duration": 0},
118-
"mode": "immediate",
119114
},
120115
],
121116
"label": "Pause",
@@ -171,7 +166,6 @@ def plot_obj_space_1d_no_animation(fitness_history, **kwargs):
171166
"x": 1,
172167
"y": 1,
173168
"xanchor": "auto",
174-
"xanchor": "auto",
175169
},
176170
margin={"l": 0, "r": 0, "t": 0, "b": 0},
177171
),
@@ -268,7 +262,6 @@ def plot_obj_space_1d_animation(fitness_history, **kwargs):
268262
"x": 1,
269263
"y": 1,
270264
"xanchor": "auto",
271-
"xanchor": "auto",
272265
},
273266
margin={"l": 0, "r": 0, "t": 0, "b": 0},
274267
sliders=sliders,
@@ -397,7 +390,6 @@ def plot_obj_space_2d(fitness_history, problem_pf=None, sort_points=False, **kwa
397390
"x": 1,
398391
"y": 1,
399392
"xanchor": "auto",
400-
"xanchor": "auto",
401393
},
402394
margin={"l": 0, "r": 0, "t": 0, "b": 0},
403395
sliders=sliders,
@@ -479,6 +471,17 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa
479471

480472
frames = []
481473
steps = []
474+
475+
if problem_pf is not None:
476+
pf_scatter = go.Scatter3d(
477+
x=problem_pf[:, 0],
478+
y=problem_pf[:, 1],
479+
z=problem_pf[:, 2],
480+
mode="markers",
481+
marker={"color": "#FFA15A", "size": 2},
482+
name="Pareto Front",
483+
)
484+
482485
for i, fit in enumerate(fitness_history):
483486
# it will make the animation look nicer
484487
if sort_points:
@@ -492,7 +495,10 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa
492495
mode="markers",
493496
marker={"color": "#636EFA", "size": 2},
494497
)
495-
frames.append(go.Frame(data=[scatter], name=str(i)))
498+
if problem_pf is not None:
499+
frames.append(go.Frame(data=[pf_scatter, scatter], name=str(i)))
500+
else:
501+
frames.append(go.Frame(data=[scatter], name=str(i)))
496502

497503
step = {
498504
"label": i,

0 commit comments

Comments
 (0)