1
1
import jax .numpy as jnp
2
2
3
- from evox import use_state
4
-
5
3
6
4
def plot_dec_space (
7
5
population_history ,
@@ -64,7 +62,6 @@ def plot_dec_space(
64
62
sliders = [
65
63
{
66
64
"currentvalue" : {"prefix" : "Generation: " },
67
- "pad" : {"t" : 50 },
68
65
"pad" : {"b" : 1 , "t" : 10 },
69
66
"len" : 0.8 ,
70
67
"x" : 0.2 ,
@@ -82,7 +79,6 @@ def plot_dec_space(
82
79
"x" : 1 ,
83
80
"y" : 1 ,
84
81
"xanchor" : "auto" ,
85
- "xanchor" : "auto" ,
86
82
},
87
83
margin = {"l" : 0 , "r" : 0 , "t" : 0 , "b" : 0 },
88
84
sliders = sliders ,
@@ -115,7 +111,6 @@ def plot_dec_space(
115
111
"frame" : {"duration" : 0 , "redraw" : False },
116
112
"mode" : "immediate" ,
117
113
"transition" : {"duration" : 0 },
118
- "mode" : "immediate" ,
119
114
},
120
115
],
121
116
"label" : "Pause" ,
@@ -171,7 +166,6 @@ def plot_obj_space_1d_no_animation(fitness_history, **kwargs):
171
166
"x" : 1 ,
172
167
"y" : 1 ,
173
168
"xanchor" : "auto" ,
174
- "xanchor" : "auto" ,
175
169
},
176
170
margin = {"l" : 0 , "r" : 0 , "t" : 0 , "b" : 0 },
177
171
),
@@ -268,7 +262,6 @@ def plot_obj_space_1d_animation(fitness_history, **kwargs):
268
262
"x" : 1 ,
269
263
"y" : 1 ,
270
264
"xanchor" : "auto" ,
271
- "xanchor" : "auto" ,
272
265
},
273
266
margin = {"l" : 0 , "r" : 0 , "t" : 0 , "b" : 0 },
274
267
sliders = sliders ,
@@ -397,7 +390,6 @@ def plot_obj_space_2d(fitness_history, problem_pf=None, sort_points=False, **kwa
397
390
"x" : 1 ,
398
391
"y" : 1 ,
399
392
"xanchor" : "auto" ,
400
- "xanchor" : "auto" ,
401
393
},
402
394
margin = {"l" : 0 , "r" : 0 , "t" : 0 , "b" : 0 },
403
395
sliders = sliders ,
@@ -479,6 +471,17 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa
479
471
480
472
frames = []
481
473
steps = []
474
+
475
+ if problem_pf is not None :
476
+ pf_scatter = go .Scatter3d (
477
+ x = problem_pf [:, 0 ],
478
+ y = problem_pf [:, 1 ],
479
+ z = problem_pf [:, 2 ],
480
+ mode = "markers" ,
481
+ marker = {"color" : "#FFA15A" , "size" : 2 },
482
+ name = "Pareto Front" ,
483
+ )
484
+
482
485
for i , fit in enumerate (fitness_history ):
483
486
# it will make the animation look nicer
484
487
if sort_points :
@@ -492,7 +495,10 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa
492
495
mode = "markers" ,
493
496
marker = {"color" : "#636EFA" , "size" : 2 },
494
497
)
495
- frames .append (go .Frame (data = [scatter ], name = str (i )))
498
+ if problem_pf is not None :
499
+ frames .append (go .Frame (data = [pf_scatter , scatter ], name = str (i )))
500
+ else :
501
+ frames .append (go .Frame (data = [scatter ], name = str (i )))
496
502
497
503
step = {
498
504
"label" : i ,
0 commit comments