
Commit 0e94d50

Added support for encoder_init_norm and refactored normalization.
1 parent 3bfc5ef commit 0e94d50

File tree: 3 files changed, +96 -19 lines changed


dictionary_learning/dictionary.py

Lines changed: 72 additions & 18 deletions
@@ -26,7 +26,15 @@ class NormalizableMixin(nn.Module):
     pass through unchanged.
     """

-    def __init__(self, activation_mean: th.Tensor | None = None, activation_std: th.Tensor | None = None, activation_shape: tuple[int, ...] | None = None):
+    def __init__(
+        self,
+        activation_mean: th.Tensor | None = None,
+        activation_std: th.Tensor | None = None,
+        activation_shape: tuple[int, ...] | None = None,
+        *,
+        keep_relative_variance: bool = True,
+        target_rms: float = 1.0,
+    ):
         """
         Initialize the normalization mixin.

@@ -36,26 +44,44 @@ def __init__(self, activation_mean: th.Tensor | None = None, activation_std: th.Tensor | None = None, activation_shape: tuple[int, ...] | None = None):
             activation_std: Optional std tensor for normalization. If None,
                 normalization is a no-op.
             activation_shape: Shape of the activation tensor. Required if activation_mean and activation_std are None for proper initialization and registration of the buffers.
+            keep_relative_variance: If True, performs global scaling so that the
+                sum of variances is 1 while their relative magnitudes stay unchanged. If False, we normalize neuron-wise.
+            target_rms: Target RMS for input activation normalization.
         """
         super().__init__()
+        self.keep_relative_variance = keep_relative_variance
+        self.register_buffer("target_rms", th.tensor(target_rms))
         if activation_mean is not None and activation_std is not None:
             # Type assertion to help linter understand these are tensors
-            assert isinstance(activation_mean, th.Tensor), "Expected mean to be a tensor"
+            assert isinstance(
+                activation_mean, th.Tensor
+            ), "Expected mean to be a tensor"
             assert isinstance(activation_std, th.Tensor), "Expected std to be a tensor"
             assert not th.isnan(activation_mean).any(), "Expected mean to be non-NaN"
             assert not th.isnan(activation_std).any(), "Expected std to be non-NaN"
             self.register_buffer("activation_mean", activation_mean)
             self.register_buffer("activation_std", activation_std)
         else:
-            assert activation_shape is not None, "activation_shape must be provided if activation_mean and activation_std are None"
+            assert (
+                activation_shape is not None
+            ), "activation_shape must be provided if activation_mean and activation_std are None"
             self.register_buffer("activation_mean", th.nan * th.ones(activation_shape))
             self.register_buffer("activation_std", th.nan * th.ones(activation_shape))

+        if self.keep_relative_variance and self.has_activation_normalizer:
+            total_var = (self.activation_std**2).sum()
+            activation_global_scale = self.target_rms / th.sqrt(total_var + 1e-8)
+            self.register_buffer("activation_global_scale", activation_global_scale)
+        else:
+            self.register_buffer("activation_global_scale", th.tensor(1.0))
+
     @property
     def has_activation_normalizer(self) -> bool:
         """Check if activation normalization is enabled."""
-        return (not th.isnan(self.activation_mean).any() and
-                not th.isnan(self.activation_std).any())
+        return (
+            not th.isnan(self.activation_mean).any()
+            and not th.isnan(self.activation_std).any()
+        )

     def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         """
@@ -74,7 +100,12 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
             # Type assertions for linter
             assert isinstance(self.activation_mean, th.Tensor)
             assert isinstance(self.activation_std, th.Tensor)
-            return (x - self.activation_mean) / (self.activation_std + 1e-8)
+            x = x - self.activation_mean
+
+            if self.keep_relative_variance:
+                return x * self.activation_global_scale
+            else:
+                return x / (self.activation_std + 1e-8)
         return x

     def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
@@ -94,7 +125,13 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
             # Type assertions for linter
             assert isinstance(self.activation_mean, th.Tensor)
             assert isinstance(self.activation_std, th.Tensor)
-            return x * (self.activation_std + 1e-8) + self.activation_mean
+
+            if self.keep_relative_variance:
+                x = x / (self.activation_global_scale + 1e-8)
+            else:
+                x = x * (self.activation_std + 1e-8)
+
+            return x + self.activation_mean
         return x


@@ -454,6 +491,8 @@ def __init__(
         k: int,
         activation_mean: th.Tensor | None = None,
         activation_std: th.Tensor | None = None,
+        target_rms: float = 1.0,
+        encoder_init_norm: float = 1.0,
     ):
         """
         Initialize the Batch Top-K SAE.
@@ -464,11 +503,17 @@ def __init__(
             k: Number of top features to keep active across the batch
             activation_mean: Optional mean tensor for input activation normalization. If None, no normalization is applied.
             activation_std: Optional std tensor for input activation normalization. If None, no normalization is applied.
+            target_rms: Target RMS for input activation normalization.
+            encoder_init_norm: Scale factor applied to the encoder weights at initialization.
         """

-        super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(activation_dim,))
-
-
+        super().__init__(
+            activation_mean=activation_mean,
+            activation_std=activation_std,
+            activation_shape=(activation_dim,),
+            target_rms=target_rms,
+        )
+
         self.activation_dim = activation_dim
         self.dict_size = dict_size

@@ -482,7 +527,7 @@ def __init__(
         )

         self.encoder = nn.Linear(activation_dim, dict_size)
-        self.encoder.weight.data = self.decoder.weight.T.clone()
+        self.encoder.weight.data = self.decoder.weight.T.clone() * encoder_init_norm
         self.encoder.bias.data.zero_()
         self.b_dec = nn.Parameter(th.zeros(activation_dim))

@@ -627,10 +672,10 @@ def from_pretrained(
         elif "k" in state_dict and k != state_dict["k"].item():
             raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

-
-
         autoencoder = cls(
-            activation_dim, dict_size, k,
+            activation_dim,
+            dict_size,
+            k,
         )
         autoencoder.load_state_dict(state_dict)
         if device is not None:
@@ -645,6 +690,7 @@ def dtype(self):
     def device(self):
         return self.encoder.weight.device

+
 # TODO merge this with AutoEncoder
 class AutoEncoderNew(Dictionary, nn.Module):
     """
@@ -994,6 +1040,7 @@ class CrossCoder(Dictionary, NormalizableMixin):
         code_normalization_alpha_cc: Weight for CrossCoder component in MIXED normalization
         activation_mean: Optional mean tensor for input/output activation normalization
         activation_std: Optional std tensor for input/output activation normalization
+        target_rms: Optional target RMS for input/output activation normalization
     """

     def __init__(
@@ -1012,6 +1059,7 @@ def __init__(
         code_normalization_alpha_cc: float | None = 0.1,
         activation_mean: th.Tensor | None = None,
         activation_std: th.Tensor | None = None,
+        target_rms: float | None = None,
     ):
         """
         Initialize a CrossCoder sparse autoencoder.
@@ -1031,11 +1079,16 @@ def __init__(
             code_normalization_alpha_cc: Weight for CrossCoder component in MIXED normalization
             activation_mean: Optional mean tensor for input/output activation normalization
             activation_std: Optional std tensor for input/output activation normalization
+            target_rms: Optional target RMS for input/output activation normalization
         """
         # First initialize the base classes that don't take normalization parameters
-        super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(num_layers, activation_dim))
+        super().__init__(
+            activation_mean=activation_mean,
+            activation_std=activation_std,
+            activation_shape=(num_layers, activation_dim),
+            target_rms=target_rms,
+        )

-
         if num_decoder_layers is None:
             num_decoder_layers = num_layers

@@ -1306,7 +1359,7 @@ def dtype(self):
     @property
     def device(self):
         return self.encoder.weight.device
-
+
     def resample_neurons(self, deads, activations):
         """
         Resample dead neurons by reinitializing their weights.
@@ -1401,6 +1454,7 @@ def __init__(
             norm_init_scale: Scale factor for weight initialization normalization
             activation_mean: Optional mean tensor for input/output activation normalization
             activation_std: Optional std tensor for input/output activation normalization
+            target_rms: Optional target RMS for input/output activation normalization
             *args: Additional positional arguments passed to parent class
             **kwargs: Additional keyword arguments passed to parent class
         """
@@ -1411,6 +1465,7 @@ def __init__(
             norm_init_scale=norm_init_scale,
             activation_mean=activation_mean,
             activation_std=activation_std,
+            target_rms=target_rms,
             *args,
             **kwargs,
         )
@@ -1687,7 +1742,6 @@ def from_pretrained(
             ), f"k in kwargs ({kwargs['k']}) does not match k in state_dict ({state_dict['k']})"
             kwargs.pop("k")

-
         crosscoder = cls(
             activation_dim,
             dict_size,

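Note on the normalization refactor above: with keep_relative_variance=True the mixin no longer divides by the per-feature std; it subtracts the mean and multiplies by a single global scalar, so the relative variance between features is preserved while the expected squared norm of a normalized sample becomes roughly target_rms**2. Below is a minimal standalone sketch of that arithmetic; the 1e-8 epsilon and the scale formula mirror the diff, but the helper names, toy statistics, and shapes are made up for illustration and are not part of the library.

```python
import torch as th

def make_global_scale(activation_std: th.Tensor, target_rms: float = 1.0) -> th.Tensor:
    # One scalar for all dimensions: total variance maps to target_rms**2,
    # so the relative variance between dimensions is unchanged.
    total_var = (activation_std ** 2).sum()
    return target_rms / th.sqrt(total_var + 1e-8)

def normalize(x, mean, std, scale, keep_relative_variance=True):
    x = x - mean
    return x * scale if keep_relative_variance else x / (std + 1e-8)

def denormalize(x, mean, std, scale, keep_relative_variance=True):
    x = x / (scale + 1e-8) if keep_relative_variance else x * (std + 1e-8)
    return x + mean

# Toy statistics and data, purely illustrative.
mean, std = th.randn(16), th.rand(16) + 0.1
scale = make_global_scale(std, target_rms=1.0)
x = mean + std * th.randn(4096, 16)

x_n = normalize(x, mean, std, scale)
assert th.allclose(denormalize(x_n, mean, std, scale), x, atol=1e-4)

# Expected squared norm of a normalized sample is ~target_rms**2; this is the
# same expression the trainer now logs as "rms_norm".
print(x_n.pow(2).sum(-1).mean().sqrt())  # roughly 1.0
```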
dictionary_learning/trainers/batch_top_k.py

Lines changed: 14 additions & 1 deletion
@@ -32,6 +32,8 @@ def __init__(
         wandb_name: str = "BatchTopKSAE",
         activation_mean: Optional[t.Tensor] = None,
         activation_std: Optional[t.Tensor] = None,
+        target_rms: float = 1.0,
+        encoder_init_norm: float = 1.0,
     ):
         super().__init__(seed)
         assert layer is not None and lm_name is not None
@@ -50,7 +52,13 @@ def __init__(
             t.cuda.manual_seed_all(seed)

         self.ae = dict_class(
-            activation_dim, dict_size, k, activation_mean=activation_mean, activation_std=activation_std
+            activation_dim,
+            dict_size,
+            k,
+            activation_mean=activation_mean,
+            activation_std=activation_std,
+            target_rms=target_rms,
+            encoder_init_norm=encoder_init_norm,
         )

         if device is None:
@@ -78,6 +86,7 @@ def __init__(
         self.effective_l0 = -1
         self.dead_features = -1
         self.pre_norm_auxk_loss = -1
+        self.encoder_init_norm = encoder_init_norm

         self.optimizer = t.optim.Adam(
             self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)
@@ -199,6 +208,9 @@ def loss(
                 "l2_loss": l2_loss.item(),
                 "auxk_loss": auxk_loss.item(),
                 "loss": loss.item(),
+                "deads": ~did_fire,
+                "threshold": self.ae.threshold.item(),
+                "rms_norm": t.sqrt((x.pow(2).sum(-1)).mean()).item(),
             },
         )

@@ -256,6 +268,7 @@ def config(self):
             "layer": self.layer,
             "lm_name": self.lm_name,
             "wandb_name": self.wandb_name,
+            "encoder_init_norm": self.encoder_init_norm,
         }

     @staticmethod

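The new encoder_init_norm option only affects initialization: the encoder still starts as the transposed decoder weights, but scaled by a constant factor. A hedged sketch of that single step in isolation is below; the dimensions and the 0.1 value are hypothetical, and the repository's actual decoder initialization and any unit-norm constraints are not reproduced here.

```python
import torch.nn as nn

activation_dim, dict_size = 64, 512  # hypothetical sizes
encoder_init_norm = 0.1              # hypothetical value; the diff's default is 1.0

decoder = nn.Linear(dict_size, activation_dim, bias=False)
encoder = nn.Linear(activation_dim, dict_size)

# Tied initialization as in the diff: encoder rows start as the transposed
# decoder columns, scaled by encoder_init_norm.
encoder.weight.data = decoder.weight.T.clone() * encoder_init_norm
encoder.bias.data.zero_()

# Each latent's encoder row now has encoder_init_norm times the norm of the
# corresponding decoder column.
print(encoder.weight.norm(dim=1)[:3])
print(decoder.weight.norm(dim=0)[:3] * encoder_init_norm)
```

The trainer stores the value and reports it from config(), so the chosen initialization scale ends up in the run's logged hyperparameters.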
dictionary_learning/trainers/crosscoder.py

Lines changed: 10 additions & 0 deletions
@@ -38,6 +38,7 @@ class CrossCoderTrainer(SAETrainer):
         use_mse_loss: Whether to use MSE loss instead of L2 loss for reconstruction (default: False)
         activation_mean: Optional activation mean (default: None)
         activation_std: Optional activation std (default: None)
+        target_rms: Target RMS for input activation normalization.
     """

     def __init__(
@@ -62,6 +63,7 @@ def __init__(
         use_mse_loss=False,
         activation_mean: Optional[th.Tensor] = None,
         activation_std: Optional[th.Tensor] = None,
+        target_rms: float = 1.0,
     ):
         super().__init__(seed)

@@ -71,6 +73,7 @@ def __init__(
         self.submodule_name = submodule_name
         self.compile = compile
         self.use_mse_loss = use_mse_loss
+        self.target_rms = target_rms
         if seed is not None:
             th.manual_seed(seed)
             th.cuda.manual_seed_all(seed)
@@ -83,6 +86,7 @@ def __init__(
                 num_layers=num_layers,
                 activation_mean=activation_mean,
                 activation_std=activation_std,
+                target_rms=target_rms,
                 **dict_class_kwargs,
             )
         else:
@@ -267,6 +271,7 @@ def config(self):
             "code_normalization": str(self.ae.code_normalization),
             "code_normalization_alpha_sae": self.ae.code_normalization_alpha_sae,
             "code_normalization_alpha_cc": self.ae.code_normalization_alpha_cc,
+            "target_rms": self.target_rms,
         }


@@ -302,6 +307,7 @@ class BatchTopKCrossCoderTrainer(SAETrainer):
         dict_class_kwargs: Additional arguments for the dictionary class (default: {})
         activation_mean: Optional activation mean (default: None)
         activation_std: Optional activation std (default: None)
+        target_rms: Target RMS for input activation normalization.
     """

     def __init__(
@@ -330,6 +336,7 @@ def __init__(
         dict_class_kwargs: dict = {},
         activation_mean: Optional[th.Tensor] = None,
         activation_std: Optional[th.Tensor] = None,
+        target_rms: float = 1.0,
     ):
         super().__init__(seed)
         assert layer is not None and lm_name is not None
@@ -348,6 +355,7 @@ def __init__(

         self.threshold_beta = threshold_beta
         self.threshold_start_step = threshold_start_step
+        self.target_rms = target_rms

         if seed is not None:
             th.manual_seed(seed)
@@ -362,6 +370,7 @@ def __init__(
                 self.k_initial,
                 activation_mean=activation_mean,
                 activation_std=activation_std,
+                target_rms=target_rms,
                 **dict_class_kwargs,
             )
         else:
@@ -689,6 +698,7 @@ def config(self):
             "wandb_name": self.wandb_name,
             "submodule_name": self.submodule_name,
             "dict_class_kwargs": {k: str(v) for k, v in self.dict_class_kwargs.items()},
+            "target_rms": self.target_rms,
         }

     @staticmethod

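The crosscoder.py changes are plumbing: both trainers accept target_rms, store it, forward it to the dictionary class, and report it from config(). For the CrossCoder itself the normalization statistics have shape (num_layers, activation_dim), so the single global scale from NormalizableMixin is computed over all layers and features at once, preserving relative scale across layers as well. A small sketch of that consequence, with made-up per-layer statistics:

```python
import torch as th

num_layers, activation_dim = 2, 8
target_rms = 1.0

# Made-up per-layer, per-feature std estimates: layer 1 is roughly 3x layer 0.
activation_std = th.stack([
    th.full((activation_dim,), 1.0),
    th.full((activation_dim,), 3.0),
])

# Same computation as NormalizableMixin: one scalar over all layers and features.
total_var = (activation_std ** 2).sum()
scale = target_rms / th.sqrt(total_var + 1e-8)

scaled = activation_std * scale
print(scaled.pow(2).sum())          # roughly target_rms**2 == 1.0
print(scaled[1, 0] / scaled[0, 0])  # still 3.0: cross-layer ratios are preserved
```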