Commit f565d2e

Merge pull request #242 from EMI-Group/dev-index-fix
Several fixes about indexing and other bugs
2 parents: a98b2ba + 66bfd45

File tree: 14 files changed (+112 −29 lines)

src/evox/algorithms/so/es_variants/adam_step.py

Lines changed: 10 additions & 11 deletions
@@ -6,21 +6,20 @@ def adam_single_tensor(
     grad: torch.Tensor,
     exp_avg: torch.Tensor,
     exp_avg_sq: torch.Tensor,
-    beta1: torch.Tensor = torch.tensor(0.9),
-    beta2: torch.Tensor = torch.tensor(0.999),
-    lr: torch.Tensor = torch.tensor(1e-3),
-    weight_decay: torch.Tensor = torch.tensor(0),
-    eps: torch.Tensor = torch.tensor(1e-8),
+    beta1: float | torch.Tensor = 0.9,
+    beta2: float | torch.Tensor = 0.999,
+    lr: float | torch.Tensor = 1e-3,
+    weight_decay: float | torch.Tensor = 0,
+    eps: float | torch.Tensor = 1e-8,
+    decouple_weight_decay: bool = False,
 ):
     # weight decay
-    if weight_decay != 0:
-        weight_decay = weight_decay.to(param.device)
+    # if weight_decay != 0:
+    if decouple_weight_decay:
+        param = param * (1 - weight_decay * lr)
+    else:
         grad = grad + weight_decay * param
     # Decay the first and second moment running average coefficient
-    beta1 = beta1.to(param.device)
-    beta2 = beta2.to(param.device)
-    lr = lr.to(param.device)
-    eps = eps.to(param.device)
     exp_avg = torch.lerp(exp_avg, grad, 1 - beta1)
     exp_avg_sq = exp_avg_sq * beta2 + grad * grad.conj() * (1 - beta2)
     denom = exp_avg_sq.sqrt() + eps
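Two things change here: defaults become plain floats (removing the per-call `.to(device)` transfers the old code needed), and a new `decouple_weight_decay` flag selects between classic L2 regularization (decay folded into the gradient before the moment updates) and AdamW-style decoupled decay (the parameter shrunk directly). A minimal sketch of the distinction, mirroring the structure above; the function name, driver lines, and the final update step are illustrative assumptions, not part of the commit:

import torch

def adam_like_step(param, grad, exp_avg, exp_avg_sq,
                   lr=1e-3, beta1=0.9, beta2=0.999,
                   weight_decay=0.0, eps=1e-8,
                   decouple_weight_decay=False):
    if decouple_weight_decay:
        # AdamW-style: shrink parameters directly; the gradient stays clean.
        param = param * (1 - weight_decay * lr)
    else:
        # Classic L2: fold decay into the gradient, so it also feeds the moments.
        grad = grad + weight_decay * param
    exp_avg = torch.lerp(exp_avg, grad, 1 - beta1)
    exp_avg_sq = exp_avg_sq * beta2 + grad * grad.conj() * (1 - beta2)
    denom = exp_avg_sq.sqrt() + eps
    param = param - lr * exp_avg / denom  # no bias correction, as in the diff above
    return param, exp_avg, exp_avg_sq

p, m, v = torch.ones(3), torch.zeros(3), torch.zeros(3)
p, m, v = adam_like_step(p, torch.full((3,), 0.1), m, v,
                         weight_decay=1e-2, decouple_weight_decay=True)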

src/evox/algorithms/so/pso_variants/utils.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def min_by(
     values = torch.cat(values, dim=0)
     keys = torch.cat(keys, dim=0)
     min_index = torch.argmin(keys)
-    return values[min_index[None]][0], keys[min_index[None]][0]
+    return values[min_index], keys[min_index]
 
 
 def random_select_from_mask(mask: torch.Tensor, count: int, dim: int = -1) -> torch.Tensor:
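In eager PyTorch both spellings are equivalent; the removed `values[min_index[None]][0]` dance existed only to sidestep 0-dim tensor indexing under `vmap`, which the updated `TransformGetSetItemToIndex` mode (next file) now handles centrally. A quick eager-mode check:

import torch

values = torch.tensor([3.0, 1.0, 2.0])
keys = torch.tensor([30.0, 10.0, 20.0])
min_index = torch.argmin(keys)        # 0-dim tensor: tensor(1)

v_old = values[min_index[None]][0]    # old workaround: lift to 1-D, then unwrap
v_new = values[min_index]             # new direct form
assert torch.equal(v_old, v_new)      # both give tensor(1.)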

src/evox/core/module.py

Lines changed: 31 additions & 7 deletions
@@ -10,7 +10,7 @@
 
 
 from functools import wraps
-from typing import Callable, Dict, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -45,7 +45,9 @@ def Parameter(
     )
 
 
-def Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
+def Mutable(
+    value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None
+) -> torch.Tensor:
     """Wraps a value as a mutable tensor.
     This is often used to label a value in an algorithm as a mutable tensor that may changes during iteration(s).
 
@@ -84,6 +86,24 @@ def eval(self):
        assert False, "`ModuleBase.eval()` shall never be invoked to prevent ambiguity."
 
 
+def _transform_scalar_index(ori_index: Sequence[Any | torch.Tensor] | Any | torch.Tensor):
+    if isinstance(ori_index, Sequence):
+        index = tuple(ori_index)
+    else:
+        index = (ori_index,)
+    any_scalar_tensor = False
+    new_index = []
+    for idx in index:
+        if isinstance(idx, torch.Tensor) and idx.ndim == 0:
+            new_index.append(idx[None])
+            any_scalar_tensor = True
+        else:
+            new_index.append(idx)
+    if not isinstance(ori_index, Sequence):
+        new_index = new_index[0]
+    return new_index, any_scalar_tensor
+
+
 # We still need a fix for the vmap
 # related issue: https://github.com/pytorch/pytorch/issues/124423
 class TransformGetSetItemToIndex(TorchFunctionMode):
@@ -95,15 +115,19 @@ class TransformGetSetItemToIndex(TorchFunctionMode):
     # That is, we convert A[idx] to A[idx[None]][0], A[idx] += 1 to A[idx[None]] += 1.
     # This is a temporary solution until the issue is fixed in PyTorch.
     def __torch_function__(self, func, types, args, kwargs=None):
+        # A[idx]
         if func == torch.Tensor.__getitem__:
             x, index = args
-            if isinstance(index, torch.Tensor) and index.ndim == 0:
-                return func(x, index[None], **(kwargs or {}))[0]
-            # return torch.index_select(x, 0, index)
+            new_index, any_scalar = _transform_scalar_index(index)
+            x = func(x, new_index, **(kwargs or {}))
+            if any_scalar:
+                x = x.squeeze(0)
+            return x
+        # A[idx] = value
         elif func == torch.Tensor.__setitem__:
             x, index, value = args
-            if isinstance(index, torch.Tensor) and index.ndim == 0:
-                return func(x, index[None], value, **(kwargs or {}))
+            new_index, _ = _transform_scalar_index(index)
+            return func(x, new_index, value, **(kwargs or {}))
 
         return func(*args, **(kwargs or {}))
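The substantive generalization: the old hook only rewrote a single 0-dim tensor index (`A[i]`), while `_transform_scalar_index` now walks tuple indices such as `A[r, i, j]`, lifting every 0-dim tensor to shape `(1,)` and squeezing one leading dimension off the `__getitem__` result. A standalone restatement of the rewrite rule in plain PyTorch (variable names are mine):

import torch

x = torch.arange(24).reshape(2, 3, 4)
i = torch.tensor(1)                  # 0-dim advanced indices
j = torch.tensor(2)

# What the mode does under the hood for x[i, j]:
lifted = x[i[None], j[None]]         # each scalar lifted to shape (1,) -> result (1, 4)
result = lifted.squeeze(0)           # squeeze the broadcast dim back off -> (4,)
assert torch.equal(result, x[1, 2])  # identical to ordinary eager indexing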

src/evox/problems/hpo_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ def __init__(
 
        :param iterations: The number of iterations to be executed in the optimization process.
        :param num_instances: The number of instances to be executed in parallel in the optimization process, i.e., the population size of the outer algorithm.
-        :param workflow: The workflow to be used in the optimization process. Must be wrapped by `core.jit_class`.
+        :param workflow: The workflow to be used in the optimization process.
        :param num_repeats: The number of times to repeat the evaluation process for each instance. Defaults to 1.
        :param copy_init_state: Whether to copy the initial state of the workflow for each evaluation. Defaults to `True`. If your workflow contains operations that IN-PLACE modify the tensor(s) in initial state, this should be set to `True`. Otherwise, you can set it to `False` to save memory.
        """

unit_test/algorithms/test_moea.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,4 @@
-from unittest import TestCase, skip
+from unittest import TestCase, skipIf
 
 import torch
 
@@ -74,7 +74,10 @@ def test_moead(self):
         self.run_algorithm(algo)
         self.run_compiled_algorithm(algo)
 
-    @skip("Torch 2.7 bug when running on non-AVX512 CPU: https://github.com/pytorch/pytorch/issues/152172")
+    @skipIf(
+        torch.__version__.startswith("2.7."),
+        "Torch 2.7 bug when running on non-AVX512 CPU: https://github.com/pytorch/pytorch/issues/152172",
+    )
     def test_hype(self):
         algo = HypE(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub)
         self.run_algorithm(algo)
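Replacing the unconditional `@skip` with a version-gated `@skipIf` means the test resumes automatically on PyTorch releases other than 2.7.x. The same pattern in isolation (class and test names are placeholders):

import unittest
import torch

class Example(unittest.TestCase):
    @unittest.skipIf(
        torch.__version__.startswith("2.7."),
        "known upstream bug in this release",
    )
    def test_affected(self):
        self.assertTrue(True)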

unit_test/core/test_index_fix.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import unittest
+
+import torch
+
+from evox.core import compile, vmap
+
+
+class TestIndexFix(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def test_get(self):
+        def _get_vmap(x: torch.Tensor, index: torch.Tensor, range: torch.Tensor):
+            return x[range, index, index]
+
+        mapped = compile(vmap(_get_vmap, in_dims=(0, 0, None)))
+        x = torch.rand(3, 2, 5, 5)
+        indices = torch.randint(5, (3,))
+        print(x)
+        print(indices)
+        x = mapped(x, indices, torch.arange(2))
+        print(x)
+
+    def test_set(self):
+        def _set_vmap(x: torch.Tensor, index: torch.Tensor, range: torch.Tensor, value: torch.Tensor):
+            x[range, index, index] = value
+            return x
+
+        mapped = compile(vmap(_set_vmap, in_dims=(0, 0, None, 0)), fullgraph=True)
+        x = torch.rand(3, 2, 5, 5)
+        indices = torch.randint(5, (3,))
+        values = torch.rand(3, 2)
+        print(x)
+        print(indices)
+        print(values)
+        x = mapped(x, indices, torch.arange(2), values)
+        print(x)
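This new test file exercises exactly the pattern the core/module.py change enables: a tuple index mixing `torch.arange(...)` with per-batch scalars under `vmap`, for both reads and in-place writes. A rough sketch of the underlying pain point, hedged because behavior varies across PyTorch versions (see pytorch/pytorch#124423); `pick_diag` is an illustrative name:

import torch

def pick_diag(x, i):
    # Under torch.vmap, i arrives as a 0-dim tensor per batch element;
    # indexing with such scalars has historically been fragile inside vmap,
    # which is what TransformGetSetItemToIndex works around.
    return x[i, i]

x = torch.rand(3, 5, 5)
i = torch.randint(5, (3,))
out = torch.vmap(pick_diag)(x, i)
print(out.shape)  # expected torch.Size([3]) when the indexing path works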

unit_test/operators/test_rvea_selection.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 
 class TestRefVecGuided(unittest.TestCase):
     def setUp(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.n, self.m, self.nv = 12, 4, 5
         self.x = torch.randn(self.n, 10)
         self.f = torch.randn(self.n, self.m)
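This file and the three below gain the same `setUp` line. For reference, `torch.set_default_device` routes factory calls (`torch.randn`, `torch.rand`, `torch.arange`, ...) to the chosen device, so these tests exercise GPU kernels whenever one is available; a minimal demonstration:

import torch

torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(4, 4)  # allocated on the default device, no explicit device= needed
print(x.device)        # cuda:0 when a GPU is present, otherwise cpu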

unit_test/operators/test_sbx_and_pm.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 class TestOperators(TestCase):
     def setUp(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.n_individuals = 9
         self.n_genes = 10
         self.x = torch.randn(self.n_individuals, self.n_genes)

unit_test/problems/test_basic.py

Lines changed: 4 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 class TestBasic(unittest.TestCase):
     def setUp(self, dimensions: list = [10], pop_size: int = 7):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.dimensions = dimensions
         self.pop_size = pop_size
         self.problems = [
@@ -24,4 +25,6 @@ def test_evaluate(self):
             problem = problem(shift=torch.rand(dimension), affine=torch.rand(dimension, dimension))
             population = torch.randn(self.pop_size, dimension)
             fitness = problem.evaluate(population)
-            print(f"The fitness of {problem.__class__.__name__} function with {dimension} dimension is {fitness}")
+            print(
+                f"The fitness of {problem.__class__.__name__} function with {dimension} dimension is {fitness}"
+            )

unit_test/problems/test_cec2022.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ def setUp(self):
         self.dimensionality = [2, 10, 20]
         self.pop_size = 100
         torch.manual_seed(42)
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
 
     def test_evaluate(self):
         for i in range(1, 13):
