Skip to content

Commit 9595c13

Browse files
committed
Remove scripts
1 parent 8991080 commit 9595c13

File tree

2 files changed

+0
-97
lines changed

2 files changed

+0
-97
lines changed

common.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import jax
21
import jax.numpy as jnp
3-
import matplotlib.pyplot as plt
42
import optax
53

64

@@ -51,62 +49,3 @@ def schedule_fn(step):
5149
return lr
5250

5351
return schedule_fn
54-
55-
56-
if __name__ == "__main__":
57-
### Testing the lr_scheduler
58-
num_iters = 50_000
59-
warmup_steps = 500
60-
cooldown_steps = 1_000
61-
62-
decay_steps = num_iters - warmup_steps - cooldown_steps
63-
64-
base_lr = 1e-3
65-
clip_value = 1.0
66-
decay_rate = 0.1
67-
68-
lr_scheduler = warmup_exponential_decay_cooldown_scheduler(
69-
warmup_steps, base_lr, decay_steps, decay_rate, cooldown_steps, min_lr=0.0
70-
)
71-
72-
### Small training loop
73-
params = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
74-
75-
optimizer = optax.chain(
76-
optax.clip_by_global_norm(clip_value),
77-
optax.adamw(lr_scheduler, b2=0.98, weight_decay=0.01),
78-
)
79-
80-
opt_state = optimizer.init(params)
81-
82-
@jax.jit
83-
def update(params, opt_state):
84-
loss, grads = jax.value_and_grad(lambda params: jnp.sum(params**2))(params)
85-
updates, opt_state = optimizer.update(grads, opt_state, params)
86-
params = optax.apply_updates(params, updates)
87-
return params, opt_state, loss
88-
89-
num_steps = warmup_steps + decay_steps + cooldown_steps
90-
91-
steps = []
92-
lrs = []
93-
for step in range(num_steps + 1):
94-
batch = jnp.ones((3, 3)) # Dummy batch data
95-
params, opt_state, loss = update(params, opt_state)
96-
current_lr = lr_scheduler(step)
97-
if step % 1000 == 0:
98-
print(f"Step {step} | Loss: {loss} | Learning Rate: {current_lr}")
99-
steps.append(step)
100-
lrs.append(current_lr)
101-
102-
### Plot the LR
103-
plt.figure(figsize=(10, 6))
104-
plt.plot(steps, lrs, label="Learning rate")
105-
106-
plt.xlabel("Steps")
107-
plt.ylabel("Learning Rate")
108-
plt.title("Warmup, Exponential Decay, and Cooldown Learning Rate Schedule")
109-
plt.legend()
110-
plt.grid(True)
111-
plt.savefig("lr_exponential.png")
112-
plt.show()

dataset.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import collections
22
import itertools
3-
import os
4-
from glob import glob
53

64
import jax
75
import tensorflow as tf
8-
from einops import rearrange
9-
from matplotlib import pyplot as plt
106

11-
from transformations.native_aspect_ratio_resize import NativeAspectRatioResize
127

138
AUTOTUNE = tf.data.AUTOTUNE
149

@@ -137,34 +132,3 @@ def enqueue(n):
137132
while queue:
138133
yield queue.popleft()
139134
enqueue(1)
140-
141-
142-
if __name__ == "__main__":
143-
tf.random.set_seed(0)
144-
batch_size = 1
145-
image_dir = "./tfrecords"
146-
train_files = glob(os.path.join(image_dir, "*.tfrec"))
147-
train_dataset = load_dataset(train_files, 14, [NativeAspectRatioResize(224, 14)])
148-
train_ds = prefetch(
149-
train_dataset.shuffle(10 * batch_size, seed=1)
150-
.batch(batch_size)
151-
.prefetch(tf.data.AUTOTUNE)
152-
.repeat()
153-
.as_numpy_iterator()
154-
)
155-
156-
for batch in train_ds:
157-
patches, patch_indices, label, attention_matrix, loss_mask = batch
158-
h, w = 1 + patch_indices[0].max(axis=0)
159-
print(h, w)
160-
image = rearrange(
161-
patches[0][: h * w],
162-
"(h w) (p1 p2 c) -> (h p1) (w p2) c",
163-
h=h,
164-
w=w,
165-
p1=14,
166-
p2=14,
167-
c=3,
168-
)
169-
plt.imshow(image)
170-
plt.show()

0 commit comments

Comments
 (0)