Skip to content

Commit e957775

Browse files
committed
expand to any multiple of 3 for pop size, select out bottom third and replenish with rest of the paired couples with approximate fitness uniform strategy (Hutter et al.)
1 parent d930538 commit e957775

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

vector_quantize_pytorch/evo_vq.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import torch
22
from torch import cat
3+
from torch.nn import Module
4+
5+
from einops import reduce
36

47
# helpers
58

@@ -9,23 +12,29 @@ def exists(v):
912
def default(v, d):
1013
return v if exists(v) else d
1114

15+
def divisible_by(num, den):
16+
return (num % den) == 0
17+
1218
# evolution - start with the most minimal, a population of 3
1319
# 1 is natural selected out, the other 2 performs crossover
1420

1521
def select_and_crossover(
1622
codes, # Float[3 ...]
1723
fitness, # Float[3]
1824
):
19-
assert codes.shape[0] == fitness.shape[0] == 3
25+
pop_size = codes.shape[0]
26+
assert pop_size == fitness.shape[0]
27+
assert divisible_by(pop_size, 3)
2028

2129
# selection
2230

23-
top2 = fitness.topk(2, dim = -1).indices
24-
codes = codes[top2]
31+
sorted_indices = fitness.sort().indices
32+
selected = sorted_indices[(pop_size // 3):] # bottom third wins darwin awards
33+
codes = codes[selected]
2534

2635
# crossover
2736

28-
child = codes.mean(dim = 0, keepdim = True)
37+
child = reduce(codes, '(two paired) ... -> paired ...', 'mean', two = 2)
2938
codes = cat((codes, child))
3039

3140
return codes

0 commit comments

Comments
 (0)