diff --git a/src/evox/algorithms/so/es_variants/adam_step.py b/src/evox/algorithms/so/es_variants/adam_step.py
index cd6bb381..2a3d1bc5 100644
--- a/src/evox/algorithms/so/es_variants/adam_step.py
+++ b/src/evox/algorithms/so/es_variants/adam_step.py
@@ -6,21 +6,19 @@ def adam_single_tensor(
     grad: torch.Tensor,
     exp_avg: torch.Tensor,
     exp_avg_sq: torch.Tensor,
-    beta1: torch.Tensor = torch.tensor(0.9),
-    beta2: torch.Tensor = torch.tensor(0.999),
-    lr: torch.Tensor = torch.tensor(1e-3),
-    weight_decay: torch.Tensor = torch.tensor(0),
-    eps: torch.Tensor = torch.tensor(1e-8),
+    beta1: float | torch.Tensor = 0.9,
+    beta2: float | torch.Tensor = 0.999,
+    lr: float | torch.Tensor = 1e-3,
+    weight_decay: float | torch.Tensor = 0,
+    eps: float | torch.Tensor = 1e-8,
+    decouple_weight_decay: bool = False,
 ):
     # weight decay
-    if weight_decay != 0:
-        weight_decay = weight_decay.to(param.device)
+    if decouple_weight_decay:
+        param = param * (1 - weight_decay * lr)
+    else:
         grad = grad + weight_decay * param
     # Decay the first and second moment running average coefficient
-    beta1 = beta1.to(param.device)
-    beta2 = beta2.to(param.device)
-    lr = lr.to(param.device)
-    eps = eps.to(param.device)
     exp_avg = torch.lerp(exp_avg, grad, 1 - beta1)
     exp_avg_sq = exp_avg_sq * beta2 + grad * grad.conj() * (1 - beta2)
     denom = exp_avg_sq.sqrt() + eps
diff --git a/src/evox/algorithms/so/pso_variants/utils.py b/src/evox/algorithms/so/pso_variants/utils.py
index e74059c6..1d69a8d9 100644
--- a/src/evox/algorithms/so/pso_variants/utils.py
+++ b/src/evox/algorithms/so/pso_variants/utils.py
@@ -18,7 +18,7 @@ def min_by(
     values = torch.cat(values, dim=0)
     keys = torch.cat(keys, dim=0)
     min_index = torch.argmin(keys)
-    return values[min_index[None]][0], keys[min_index[None]][0]
+    return values[min_index], keys[min_index]
 
 
 def random_select_from_mask(mask: torch.Tensor, count: int, dim: int = -1) -> torch.Tensor:
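
Note on the adam_step.py change above: `decouple_weight_decay` switches between classic L2 regularization and AdamW-style decoupled weight decay. Coupled decay is folded into the gradient and is therefore rescaled by the adaptive denominator later on; decoupled decay shrinks the parameters directly and bypasses that rescaling. A minimal sketch of the two branches (`apply_weight_decay` is a hypothetical standalone helper, not the evox API):

    import torch

    def apply_weight_decay(param: torch.Tensor, grad: torch.Tensor,
                           lr: float, weight_decay: float,
                           decouple: bool) -> tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical helper restating the two branches of the diff above.
        if decouple:
            # AdamW-style (Loshchilov & Hutter): the decay multiplies the
            # parameters directly and is not divided by sqrt(exp_avg_sq) + eps.
            param = param * (1 - weight_decay * lr)
        else:
            # Classic Adam + L2: the decay is folded into the gradient and is
            # later rescaled by the adaptive denominator like the gradient itself.
            grad = grad + weight_decay * param
        return param, grad
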
diff --git a/src/evox/core/module.py b/src/evox/core/module.py
index 362a8973..b543a950 100644
--- a/src/evox/core/module.py
+++ b/src/evox/core/module.py
@@ -10,7 +10,7 @@
 from functools import wraps
-from typing import Callable, Dict, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, Union
 
 import torch
 import torch.nn as nn
 
 
@@ -45,7 +45,9 @@ def Parameter(
     )
 
 
-def Mutable(value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
+def Mutable(
+    value: torch.Tensor, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None
+) -> torch.Tensor:
     """Wraps a value as a mutable tensor.
 
     This is often used to label a value in an algorithm as a mutable tensor that may changes during iteration(s).
@@ -84,6 +86,24 @@ def eval(self):
         assert False, "`ModuleBase.eval()` shall never be invoked to prevent ambiguity."
 
 
+def _transform_scalar_index(ori_index: Sequence[Any | torch.Tensor] | Any | torch.Tensor):
+    if isinstance(ori_index, Sequence):
+        index = tuple(ori_index)
+    else:
+        index = (ori_index,)
+    any_scalar_tensor = False
+    new_index = []
+    for idx in index:
+        if isinstance(idx, torch.Tensor) and idx.ndim == 0:
+            new_index.append(idx[None])
+            any_scalar_tensor = True
+        else:
+            new_index.append(idx)
+    if not isinstance(ori_index, Sequence):
+        new_index = new_index[0]
+    return new_index, any_scalar_tensor
+
+
 # We still need a fix for the vmap
 # related issue: https://github.com/pytorch/pytorch/issues/124423
 class TransformGetSetItemToIndex(TorchFunctionMode):
@@ -95,15 +115,19 @@ class TransformGetSetItemToIndex(TorchFunctionMode):
     # That is, we convert A[idx] to A[idx[None]][0], A[idx] += 1 to A[idx[None]] += 1.
     # This is a temporary solution until the issue is fixed in PyTorch.
     def __torch_function__(self, func, types, args, kwargs=None):
+        # A[idx]
         if func == torch.Tensor.__getitem__:
             x, index = args
-            if isinstance(index, torch.Tensor) and index.ndim == 0:
-                return func(x, index[None], **(kwargs or {}))[0]
-            # return torch.index_select(x, 0, index)
+            new_index, any_scalar = _transform_scalar_index(index)
+            x = func(x, new_index, **(kwargs or {}))
+            if any_scalar:
+                x = x.squeeze(0)
+            return x
+        # A[idx] = value
         elif func == torch.Tensor.__setitem__:
             x, index, value = args
-            if isinstance(index, torch.Tensor) and index.ndim == 0:
-                return func(x, index[None], value, **(kwargs or {}))
+            new_index, _ = _transform_scalar_index(index)
+            return func(x, new_index, value, **(kwargs or {}))
         return func(*args, **(kwargs or {}))
 
 
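
Note on the module.py change above: both `__getitem__` and `__setitem__` now route through `_transform_scalar_index`, which promotes every 0-dim tensor inside a (possibly tuple) index to a 1-element tensor, so subscripts like `A[r, i, i]` are handled in addition to the old single-scalar case `A[i]`. The rewrite is easy to check in plain eager PyTorch; a minimal sketch of the equivalence (independent of vmap):

    import torch

    A = torch.arange(12).reshape(3, 4)
    i = torch.tensor(1)  # 0-dim tensor index
    # Single scalar: A[i] becomes A[i[None]] followed by squeeze(0).
    assert torch.equal(A[i[None]].squeeze(0), A[1])
    # Tuple index with scalar tensors: each 0-dim entry is unsqueezed to
    # 1 element, then the leading singleton dimension is squeezed away.
    r, c = torch.tensor(2), torch.tensor(0)
    assert torch.equal(A[r[None], c[None]].squeeze(0), A[2, 0])
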
""" diff --git a/unit_test/algorithms/test_moea.py b/unit_test/algorithms/test_moea.py index a3091023..fcb4296e 100644 --- a/unit_test/algorithms/test_moea.py +++ b/unit_test/algorithms/test_moea.py @@ -1,4 +1,4 @@ -from unittest import TestCase, skip +from unittest import TestCase, skipIf import torch @@ -74,7 +74,10 @@ def test_moead(self): self.run_algorithm(algo) self.run_compiled_algorithm(algo) - @skip("Torch 2.7 bug when running on non-AVX512 CPU: https://github.com/pytorch/pytorch/issues/152172") + @skipIf( + torch.__version__.startswith("2.7."), + "Torch 2.7 bug when running on non-AVX512 CPU: https://github.com/pytorch/pytorch/issues/152172", + ) def test_hype(self): algo = HypE(pop_size=self.pop_size, n_objs=3, lb=self.lb, ub=self.ub) self.run_algorithm(algo) diff --git a/unit_test/core/test_index_fix.py b/unit_test/core/test_index_fix.py new file mode 100644 index 00000000..b928a91a --- /dev/null +++ b/unit_test/core/test_index_fix.py @@ -0,0 +1,38 @@ +import unittest + +import torch + +from evox.core import compile, vmap + + +class TestIndexFix(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") + + def test_get(self): + def _get_vmap(x: torch.Tensor, index: torch.Tensor, range: torch.Tensor): + return x[range, index, index] + + mapped = compile(vmap(_get_vmap, in_dims=(0, 0, None))) + x = torch.rand(3, 2, 5, 5) + indices = torch.randint(5, (3,)) + print(x) + print(indices) + x = mapped(x, indices, torch.arange(2)) + print(x) + + def test_set(self): + def _set_vmap(x: torch.Tensor, index: torch.Tensor, range: torch.Tensor, value: torch.Tensor): + x[range, index, index] = value + return x + + mapped = compile(vmap(_set_vmap, in_dims=(0, 0, None, 0)), fullgraph=True) + x = torch.rand(3, 2, 5, 5) + indices = torch.randint(5, (3,)) + values = torch.rand(3, 2) + print(x) + print(indices) + print(values) + x = mapped(x, indices, torch.arange(2), values) + print(x) diff --git a/unit_test/operators/test_rvea_selection.py b/unit_test/operators/test_rvea_selection.py index e50e0504..f537eb7f 100644 --- a/unit_test/operators/test_rvea_selection.py +++ b/unit_test/operators/test_rvea_selection.py @@ -7,6 +7,7 @@ class TestRefVecGuided(unittest.TestCase): def setUp(self): + torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") self.n, self.m, self.nv = 12, 4, 5 self.x = torch.randn(self.n, 10) self.f = torch.randn(self.n, self.m) diff --git a/unit_test/operators/test_sbx_and_pm.py b/unit_test/operators/test_sbx_and_pm.py index 1e40c395..0129f032 100644 --- a/unit_test/operators/test_sbx_and_pm.py +++ b/unit_test/operators/test_sbx_and_pm.py @@ -8,6 +8,7 @@ class TestOperators(TestCase): def setUp(self): + torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") self.n_individuals = 9 self.n_genes = 10 self.x = torch.randn(self.n_individuals, self.n_genes) diff --git a/unit_test/problems/test_basic.py b/unit_test/problems/test_basic.py index 00c80f4d..85e7eaea 100644 --- a/unit_test/problems/test_basic.py +++ b/unit_test/problems/test_basic.py @@ -7,6 +7,7 @@ class TestBasic(unittest.TestCase): def setUp(self, dimensions: list = [10], pop_size: int = 7): + torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") self.dimensions = dimensions self.pop_size = pop_size self.problems = [ @@ -24,4 +25,6 @@ def test_evaluate(self): problem = problem(shift=torch.rand(dimension), affine=torch.rand(dimension, dimension)) population = torch.randn(self.pop_size, 
diff --git a/unit_test/operators/test_rvea_selection.py b/unit_test/operators/test_rvea_selection.py
index e50e0504..f537eb7f 100644
--- a/unit_test/operators/test_rvea_selection.py
+++ b/unit_test/operators/test_rvea_selection.py
@@ -7,6 +7,7 @@
 
 class TestRefVecGuided(unittest.TestCase):
     def setUp(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.n, self.m, self.nv = 12, 4, 5
         self.x = torch.randn(self.n, 10)
         self.f = torch.randn(self.n, self.m)
diff --git a/unit_test/operators/test_sbx_and_pm.py b/unit_test/operators/test_sbx_and_pm.py
index 1e40c395..0129f032 100644
--- a/unit_test/operators/test_sbx_and_pm.py
+++ b/unit_test/operators/test_sbx_and_pm.py
@@ -8,6 +8,7 @@
 
 class TestOperators(TestCase):
     def setUp(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.n_individuals = 9
         self.n_genes = 10
         self.x = torch.randn(self.n_individuals, self.n_genes)
diff --git a/unit_test/problems/test_basic.py b/unit_test/problems/test_basic.py
index 00c80f4d..85e7eaea 100644
--- a/unit_test/problems/test_basic.py
+++ b/unit_test/problems/test_basic.py
@@ -7,6 +7,7 @@
 
 class TestBasic(unittest.TestCase):
     def setUp(self, dimensions: list = [10], pop_size: int = 7):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         self.dimensions = dimensions
         self.pop_size = pop_size
         self.problems = [
@@ -24,4 +25,6 @@ def test_evaluate(self):
             problem = problem(shift=torch.rand(dimension), affine=torch.rand(dimension, dimension))
             population = torch.randn(self.pop_size, dimension)
             fitness = problem.evaluate(population)
-            print(f"The fitness of {problem.__class__.__name__} function with {dimension} dimension is {fitness}")
+            print(
+                f"The fitness of {problem.__class__.__name__} function with {dimension} dimension is {fitness}"
+            )
diff --git a/unit_test/problems/test_cec2022.py b/unit_test/problems/test_cec2022.py
index aca9335c..4949254d 100644
--- a/unit_test/problems/test_cec2022.py
+++ b/unit_test/problems/test_cec2022.py
@@ -12,6 +12,7 @@
         self.dimensionality = [2, 10, 20]
         self.pop_size = 100
         torch.manual_seed(42)
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
 
     def test_evaluate(self):
         for i in range(1, 13):
diff --git a/unit_test/problems/test_dtlz.py b/unit_test/problems/test_dtlz.py
index c5c5f325..0c0738c1 100644
--- a/unit_test/problems/test_dtlz.py
+++ b/unit_test/problems/test_dtlz.py
@@ -7,6 +7,7 @@
 
 class TestDTLZ(TestCase):
     def setUp(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
         d = 12
         m = 3
         self.pro = [
diff --git a/unit_test/problems/test_hpo_wrapper.py b/unit_test/problems/test_hpo_wrapper.py
index debd09b7..6f309e71 100644
--- a/unit_test/problems/test_hpo_wrapper.py
+++ b/unit_test/problems/test_hpo_wrapper.py
@@ -20,7 +20,9 @@ def evaluate(self, x: torch.Tensor):
 class BasicAlgorithm(Algorithm):
     def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor, device: torch.device | None = None):
         super().__init__()
-        assert lb.ndim == 1 and ub.ndim == 1, f"Lower and upper bounds shall have ndim of 1, got {lb.ndim} and {ub.ndim}"
+        assert lb.ndim == 1 and ub.ndim == 1, (
+            f"Lower and upper bounds shall have ndim of 1, got {lb.ndim} and {ub.ndim}"
+        )
         assert lb.shape == ub.shape, f"Lower and upper bounds shall have same shape, got {lb.ndim} and {ub.ndim}"
         device = torch.get_default_device() if device is None else device
         self.pop_size = pop_size
@@ -46,19 +48,25 @@ def setUp(self):
         self.prob = BasicProblem()
         self.monitor = HPOFitnessMonitor()
         self.workflow = StdWorkflow(self.algo, self.prob, monitor=self.monitor)
-        self.hpo_prob = HPOProblemWrapper(iterations=9, num_instances=7, workflow=self.workflow, copy_init_state=True)
+        self.hpo_prob = HPOProblemWrapper(
+            iterations=9, num_instances=7, workflow=self.workflow, copy_init_state=True
+        )
 
         self.algo_mo = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))
         self.prob_mo = DTLZ1(2, 2)
         self.monitor_mo = HPOFitnessMonitor(multi_obj_metric=lambda f: igd(f, self.prob_mo.pf()))
         self.workflow_mo = StdWorkflow(self.algo_mo, self.prob_mo, monitor=self.monitor_mo)
-        self.hpo_prob_mo = HPOProblemWrapper(iterations=9, num_instances=7, workflow=self.workflow_mo, copy_init_state=True)
+        self.hpo_prob_mo = HPOProblemWrapper(
+            iterations=9, num_instances=7, workflow=self.workflow_mo, copy_init_state=True
+        )
 
         self.algo_mo2 = BasicAlgorithm(10, -10 * torch.ones(2), 10 * torch.ones(2))
         self.prob_mo2 = DTLZ1(2, 2)
         self.monitor_mo2 = HPOFitnessMonitor(multi_obj_metric=lambda f: igd(f, self.prob_mo2.pf()))
         self.workflow_mo2 = StdWorkflow(self.algo_mo2, self.prob_mo2, monitor=self.monitor_mo2)
-        self.hpo_prob_mo2 = HPOProblemWrapper(iterations=9, num_instances=7, workflow=self.workflow_mo2, copy_init_state=True)
+        self.hpo_prob_mo2 = HPOProblemWrapper(
+            iterations=9, num_instances=7, workflow=self.workflow_mo2, copy_init_state=True
+        )
 
     def test_get_init_params(self):
         params = self.hpo_prob.get_init_params()
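
A note on the recurring `torch.set_default_device(...)` line added to these `setUp` methods: it redirects tensor factory calls such as `torch.randn`, `torch.rand`, and `torch.arange` to the chosen device, which is why the test bodies need no explicit `.to(...)`; tensors created before the call keep their original device. A minimal illustration (hypothetical snippet, not from the suite):

    import torch

    torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(4)  # allocated on the default device
    print(x.device)     # cuda:0 when available, otherwise cpu
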
diff --git a/unit_test/problems/test_supervised_learning.py b/unit_test/problems/test_supervised_learning.py
index b7174848..094d4494 100644
--- a/unit_test/problems/test_supervised_learning.py
+++ b/unit_test/problems/test_supervised_learning.py
@@ -89,7 +89,9 @@ def setUp(self):
                 for inputs, labels in self.train_loader
             ]
         )
-        self.pre_test_loader = tuple([(inputs.to(self.device), labels.to(self.device)) for inputs, labels in self.test_loader])
+        self.pre_test_loader = tuple(
+            [(inputs.to(self.device), labels.to(self.device)) for inputs, labels in self.test_loader]
+        )
         self.model = SampleCNN().to(self.device)
 
         self.adapter = ParamsAndVector(dummy_model=self.model)
diff --git a/unit_test/workflows/test_std_workflow.py b/unit_test/workflows/test_std_workflow.py
index 3831ba4c..d5c4a35b 100644
--- a/unit_test/workflows/test_std_workflow.py
+++ b/unit_test/workflows/test_std_workflow.py
@@ -24,7 +24,9 @@ def evaluate(self, pop: torch.Tensor):
 class BasicAlgorithm(Algorithm):
     def __init__(self, pop_size: int, lb: torch.Tensor, ub: torch.Tensor):
         super().__init__()
-        assert lb.ndim == 1 and ub.ndim == 1, f"Lower and upper bounds shall have ndim of 1, got {lb.ndim} and {ub.ndim}"
+        assert lb.ndim == 1 and ub.ndim == 1, (
+            f"Lower and upper bounds shall have ndim of 1, got {lb.ndim} and {ub.ndim}"
+        )
         assert lb.shape == ub.shape, f"Lower and upper bounds shall have same shape, got {lb.ndim} and {ub.ndim}"
         self.pop_size = pop_size
         self.lb = lb
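
On the assert reformatting in the last two files: the parentheses wrap only the message, so the condition stays outside and is still tested. The superficially similar form `assert (cond, "msg")` would instead build a 2-tuple, which is always truthy, so that assert would never fire. A tiny illustration with hypothetical values:

    lb_ndim, ub_ndim = 1, 1
    # Correct: parentheses group the message only; the comma stays outside.
    assert lb_ndim == 1 and ub_ndim == 1, (
        f"Lower and upper bounds shall have ndim of 1, got {lb_ndim} and {ub_ndim}"
    )
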