
Commit e8db087

Refactor imports and enhance activation normalization handling
This commit includes the following changes:

- Reformatted the import statements in `__init__.py` for improved readability.
- Increased the sleep duration in `ActivationCache` from 1 to 10 seconds to allow more time for save processes to complete.
- Updated the `BatchTopKSAE`, `CrossCoder`, and `BatchTopKCrossCoder` classes to load the `activation_normalizer` from the state dictionary, ensuring that normalization is applied correctly during model initialization.
- Refined the normalization checks in the `CrossCoderEncoder` and `CrossCoderDecoder` classes to ensure that normalization only occurs if an `activation_normalizer` is present.
- Made minor formatting adjustments in the `training.py` file for better code clarity.

These changes aim to enhance the clarity and maintainability of the code while ensuring proper handling of activation normalization across various components.
1 parent 17c0a92 commit e8db087

5 files changed: +76 -21 lines changed


dictionary_learning/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -1,2 +1,9 @@
-from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder, CrossCoder, BatchTopKSAE, BatchTopKCrossCoder
+from .dictionary import (
+    AutoEncoder,
+    GatedAutoEncoder,
+    JumpReluAutoEncoder,
+    CrossCoder,
+    BatchTopKSAE,
+    BatchTopKCrossCoder,
+)
 from .buffer import ActivationBuffer

dictionary_learning/cache.py

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ def cleanup_multiprocessing():
             print(
                 f"Waiting for {ActivationCache.__active_processes.value} save processes to finish"
             )
-            time.sleep(1)
+            time.sleep(10)
         ActivationCache.__pool.close()
         ActivationCache.__pool = None
         ActivationCache.__manager.shutdown()

dictionary_learning/dictionary.py

Lines changed: 37 additions & 9 deletions
@@ -36,8 +36,7 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
         """
         super().__init__()
         self.activation_normalizer = activation_normalizer
-        if self.activation_normalizer is not None:
-            self.activation_normalizer.to(self.device)
+

     def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         """
@@ -594,7 +593,16 @@ def from_pretrained(
         elif "k" in state_dict and k != state_dict["k"].item():
             raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

-        autoencoder = cls(activation_dim, dict_size, k)
+        # Load activation normalizer if present in kwargs
+        activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
+        activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
+        if activation_normalizer_mean is not None and activation_normalizer_std is not None:
+            activation_normalizer = ActivationNormalizer(
+                mean=activation_normalizer_mean, std=activation_normalizer_std
+            )
+        else:
+            activation_normalizer = None
+        autoencoder = cls(activation_dim, dict_size, k, activation_normalizer=activation_normalizer)
         autoencoder.load_state_dict(state_dict)
         if device is not None:
             autoencoder.to(device)
@@ -729,8 +737,6 @@ def __init__(
         self.weight = nn.Parameter(weight)
         self.bias = nn.Parameter(th.zeros(dict_size))
         self.activation_normalizer = activation_normalizer
-        if self.activation_normalizer is not None:
-            self.activation_normalizer.to(self.device)

     def forward(
         self,
@@ -763,7 +769,7 @@ def forward(
         - summed_features: shape (batch_size, dict_size)
         - per_layer_features: shape (batch_size, num_layers, dict_size)
         """
-        if normalize_activations:
+        if normalize_activations and self.activation_normalizer is not None:
             x = self.activation_normalizer.normalize(x, inplace=inplace_normalize)
         x = x[:, self.encoder_layers]
         if select_features is not None:
@@ -836,8 +842,7 @@ def __init__(
             weight = weight / weight.norm(dim=2, keepdim=True) * norm_init_scale
         self.weight = nn.Parameter(weight)
         self.activation_normalizer = activation_normalizer
-        if self.activation_normalizer is not None:
-            self.activation_normalizer.to(self.device)
+

     def forward(
         self,
@@ -873,7 +878,7 @@ def forward(
         x = th.einsum("blf, lfd -> bld", f, w)
         if add_bias:
             x += self.bias
-        if denormalize_activations:
+        if denormalize_activations and self.activation_normalizer is not None:
             x = self.activation_normalizer.denormalize(x, inplace=True)
         return x

@@ -1260,6 +1265,15 @@ def from_pretrained(
             code_normalization.value, dtype=th.int
         )
         num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape
+        # Load activation normalizer if present in kwargs
+        activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
+        activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
+        if activation_normalizer_mean is not None and activation_normalizer_std is not None:
+            activation_normalizer = ActivationNormalizer(
+                mean=activation_normalizer_mean, std=activation_normalizer_std
+            )
+        else:
+            activation_normalizer = None

         crosscoder = cls(
             activation_dim,
@@ -1268,6 +1282,7 @@ def from_pretrained(
             code_normalization=CodeNormalization._value2member_map_[
                 state_dict["code_normalization_id"].item()
             ],
+            activation_normalizer=activation_normalizer,
         )
         crosscoder.load_state_dict(state_dict)

@@ -1650,6 +1665,18 @@ def from_pretrained(
                 state_dict["k"] == kwargs["k"]
             ), f"k in kwargs ({kwargs['k']}) does not match k in state_dict ({state_dict['k']})"
             kwargs.pop("k")
+
+        # Load activation normalizer if present in kwargs
+        activation_normalizer_mean = state_dict.get("activation_normalizer.mean", None)
+        activation_normalizer_std = state_dict.get("activation_normalizer.std", None)
+        if activation_normalizer_mean is not None and activation_normalizer_std is not None:
+            activation_normalizer = ActivationNormalizer(
+                mean=activation_normalizer_mean, std=activation_normalizer_std
+            )
+
+        else:
+            activation_normalizer = None
+
         kwargs.update()

         crosscoder = cls(
@@ -1658,6 +1685,7 @@ def from_pretrained(
             num_layers,
             k=state_dict["k"],
             code_normalization=code_normalization,
+            activation_normalizer=activation_normalizer,
             **kwargs,
         )
         if "code_normalization_id" not in state_dict:

dictionary_learning/trainers/batch_top_k.py

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ def update(self, step, x):
         x = x.to(self.device)
         x = self.ae.normalize_activations(
             x,
-            inplace_normalize=True,  # Normalize inplace to avoid copying the activations during training
+            inplace=True,  # Normalize inplace to avoid copying the activations during training
         )
         loss = self.loss(x, step=step, normalize_activations=False)
         loss.backward()
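The rename matches the signature shown in the dictionary.py diff above, `normalize_activations(self, x, inplace=False)`; `inplace_normalize` is not a parameter of that method. A quick sketch of the corrected call outside the trainer, with illustrative sizes and identity statistics (whether `ActivationNormalizer` accepts 1-D `mean`/`std` tensors of this shape is an assumption):

import torch as th

from dictionary_learning import BatchTopKSAE
from dictionary_learning.dictionary import ActivationNormalizer  # assumed import path

activation_dim, dict_size, k = 512, 4096, 32  # illustrative sizes
normalizer = ActivationNormalizer(
    mean=th.zeros(activation_dim), std=th.ones(activation_dim)  # identity statistics
)
ae = BatchTopKSAE(activation_dim, dict_size, k, activation_normalizer=normalizer)

x = th.randn(8, activation_dim)
x = ae.normalize_activations(x, inplace=True)  # corrected keyword, as in update() above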

dictionary_learning/training.py

Lines changed: 29 additions & 9 deletions
@@ -68,6 +68,7 @@ def get_stats(
     out["frac_variance_explained"] = frac_variance_explained.item()
     return out

+
 def get_model(trainer):
     if hasattr(trainer, "ae"):
         model = trainer.ae
@@ -77,6 +78,7 @@ def get_model(trainer):
         model = model._orig_mod
     return model

+
 def log_stats(
     trainer,
     step: int,
@@ -106,7 +108,10 @@ def log_stats(
     for name, value in trainer_log.items():
         log[f"{stage}/{name}"] = value

-    wandb.log(log, step=step, epoch=epoch_idx_per_step[step] if epoch_idx_per_step is not None else None)
+    if epoch_idx_per_step is not None:
+        log["epoch"] = epoch_idx_per_step[step]
+    wandb.log(log, step=step)
+

 @th.no_grad()
 def run_validation(
@@ -177,7 +182,9 @@ def run_validation(
     ).mean()
     if step is not None:
         log["step"] = step
-    wandb.log(log, step=step, epoch=epoch_idx_per_step[step] if epoch_idx_per_step is not None else None)
+    if epoch_idx_per_step is not None:
+        log["epoch"] = epoch_idx_per_step[step]
+    wandb.log(log, step=step)

     return log

@@ -194,6 +201,7 @@ def trainSAE(
     use_wandb=False,
     wandb_entity="",
     wandb_project="",
+    wandb_group="",
     steps=None,
     save_steps=None,
     save_dir=None,
@@ -212,13 +220,14 @@
 ):
     """
     Train SAE using the given trainer
-
+
     Args:
         data: Training data iterator/dataloader
         trainer_config: Configuration dictionary for the trainer
         use_wandb: Whether to use Weights & Biases logging (default: False)
         wandb_entity: W&B entity name (default: "")
         wandb_project: W&B project name (default: "")
+        wandb_group: W&B group name (default: "")
        steps: Maximum number of training steps (default: None)
         save_steps: Frequency of model checkpointing (default: None)
         save_dir: Directory to save checkpoints and config (default: None)
@@ -234,10 +243,10 @@
         dtype: Training data type (default: torch.float32)
         run_wandb_finish: Whether to call wandb.finish() at end of training (default: True)
         epoch_idx_per_step: Optional mapping of training steps to epoch indices (default: None). Mainly used for logging when the dataset is pre-shuffled and contains multiple epochs.
-
+
     Returns:
         Trained model
-
+
     Raises:
         AssertionError: If validation_data is None but validate_every_n_steps is specified
     """
@@ -256,11 +265,12 @@
         config=wandb_config,
         name=wandb_config["wandb_name"],
         mode="disabled" if not use_wandb else "online",
+        group=wandb_group,
     )

     trainer.model.to(dtype)

-    # make save dir, export config
+    # make save dir, export config
     if save_dir is not None:
         os.makedirs(save_dir, exist_ok=True)
         # save config
@@ -317,7 +327,13 @@
             and (start_of_training_eval or step > 0)
         ):
             print(f"Validating at step {step}")
-            logs = run_validation(trainer, validation_data, step=step, dtype=dtype, epoch_idx_per_step=epoch_idx_per_step)
+            logs = run_validation(
+                trainer,
+                validation_data,
+                step=step,
+                dtype=dtype,
+                epoch_idx_per_step=epoch_idx_per_step,
+            )
             try:
                 os.makedirs(save_dir, exist_ok=True)
                 th.save(logs, os.path.join(save_dir, f"eval_logs_{step}.pt"))
@@ -328,7 +344,11 @@
         end_of_step_logging_fn(trainer, step)
     try:
         last_eval_logs = run_validation(
-            trainer, validation_data, step=step, dtype=dtype, epoch_idx_per_step=epoch_idx_per_step
+            trainer,
+            validation_data,
+            step=step,
+            dtype=dtype,
+            epoch_idx_per_step=epoch_idx_per_step,
         )
         if save_last_eval:
             os.makedirs(save_dir, exist_ok=True)
@@ -343,4 +363,4 @@
     if use_wandb and run_wandb_finish:
         wandb.finish()

-    return get_model(trainer)
+    return get_model(trainer)
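Note that `wandb.log` has no `epoch` keyword, so `log_stats` and `run_validation` now put the epoch index into the metrics dict and log it alongside the other values. A minimal sketch of that pattern in isolation (the helper name `log_with_epoch` is illustrative):

import wandb


def log_with_epoch(log: dict, step: int, epoch_idx_per_step=None):
    # Fold the epoch index into the metrics dict when a step -> epoch mapping exists,
    # then log everything against the training step.
    if epoch_idx_per_step is not None:
        log["epoch"] = epoch_idx_per_step[step]
    wandb.log(log, step=step)


# Example: log_with_epoch({"val/loss": 0.12}, step=100, epoch_idx_per_step={100: 1})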
