|
1 |
| -import jax |
2 | 1 | import jax.numpy as jnp
|
3 |
| -import matplotlib.pyplot as plt |
4 | 2 | import optax
|
5 | 3 |
|
6 | 4 |
|
@@ -51,62 +49,3 @@ def schedule_fn(step):
|
51 | 49 | return lr
|
52 | 50 |
|
53 | 51 | return schedule_fn
|
54 |
| - |
55 |
| - |
56 |
if __name__ == "__main__":
    ### Testing the lr_scheduler
    # Total schedule length is split into warmup -> exponential decay -> cooldown.
    num_iters = 50_000
    warmup_steps = 500
    cooldown_steps = 1_000

    # Decay phase fills whatever remains of the total budget.
    decay_steps = num_iters - warmup_steps - cooldown_steps

    base_lr = 1e-3
    clip_value = 1.0
    decay_rate = 0.1

    # Build the schedule under test (defined earlier in this file).
    lr_scheduler = warmup_exponential_decay_cooldown_scheduler(
        warmup_steps, base_lr, decay_steps, decay_rate, cooldown_steps, min_lr=0.0
    )

    ### Small training loop
    # Tiny dummy parameter matrix; the loop only exists to exercise the schedule.
    params = jax.random.normal(jax.random.PRNGKey(0), (3, 3))

    # Gradient clipping followed by AdamW driven by the schedule under test.
    optimizer = optax.chain(
        optax.clip_by_global_norm(clip_value),
        optax.adamw(lr_scheduler, b2=0.98, weight_decay=0.01),
    )

    opt_state = optimizer.init(params)

    @jax.jit
    def update(params, opt_state):
        """One optimizer step on a toy quadratic loss; returns (params, opt_state, loss)."""
        loss, grads = jax.value_and_grad(lambda params: jnp.sum(params**2))(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    num_steps = warmup_steps + decay_steps + cooldown_steps

    steps = []
    lrs = []
    for step in range(num_steps + 1):
        # NOTE: the previous version created an unused dummy `batch` here;
        # the toy loss needs no data, so it has been removed.
        params, opt_state, loss = update(params, opt_state)
        # Query the schedule directly so the plot reflects the exact LR curve.
        current_lr = lr_scheduler(step)
        if step % 1000 == 0:
            print(f"Step {step} | Loss: {loss} | Learning Rate: {current_lr}")
        steps.append(step)
        lrs.append(current_lr)

    ### Plot the LR
    plt.figure(figsize=(10, 6))
    plt.plot(steps, lrs, label="Learning rate")

    plt.xlabel("Steps")
    plt.ylabel("Learning Rate")
    plt.title("Warmup, Exponential Decay, and Cooldown Learning Rate Schedule")
    plt.legend()
    plt.grid(True)
    plt.savefig("lr_exponential.png")
    plt.show()
0 commit comments