
Commit e3b8993

format

1 parent 8d9beb9 · commit e3b8993

File tree

3 files changed: +101 −54 lines changed

    dictionary_learning/dictionary.py
    tests/test_cache.py
    tests/test_running_stat_welford.py

dictionary_learning/dictionary.py

Lines changed: 27 additions & 8 deletions

@@ -348,7 +348,13 @@ def from_pretrained(
 
 
 class BatchTopKSAE(Dictionary, nn.Module):
-    def __init__(self, activation_dim: int, dict_size: int, k: int, activation_normalizer: ActivationNormalizer | None = None):
+    def __init__(
+        self,
+        activation_dim: int,
+        dict_size: int,
+        k: int,
+        activation_normalizer: ActivationNormalizer | None = None,
+    ):
         super().__init__()
         self.activation_dim = activation_dim
         self.dict_size = dict_size
@@ -371,7 +377,11 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, activation_norma
         self.b_dec = nn.Parameter(th.zeros(activation_dim))
 
     def encode(
-        self, x: th.Tensor, return_active: bool = False, use_threshold: bool = True, normalize_activations: bool = True
+        self,
+        x: th.Tensor,
+        return_active: bool = False,
+        use_threshold: bool = True,
+        normalize_activations: bool = True,
     ):
         if normalize_activations:
             x = self.normalize_activations(x)
@@ -405,7 +415,12 @@ def normalize_activations(self, x: th.Tensor) -> th.Tensor:
             return self.activation_normalizer(x)
         return x
 
-    def forward(self, x: th.Tensor, output_features: bool = False, normalize_activations: bool = True):
+    def forward(
+        self,
+        x: th.Tensor,
+        output_features: bool = False,
+        normalize_activations: bool = True,
+    ):
         encoded_acts_BF = self.encode(x, normalize_activations=normalize_activations)
         x_hat_BD = self.decode(encoded_acts_BF)
 
@@ -983,7 +998,6 @@ def from_pretrained(
         )
         num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape
 
-
        crosscoder = cls(
            activation_dim,
            dict_size,
@@ -1093,7 +1107,7 @@ def encode(
 
         Returns:
             If return_active is False: encoded features tensor
-            If return_active is True: tuple of (features, scaled_features, active_mask,
+            If return_active is True: tuple of (features, scaled_features, active_mask,
                 post_relu_features, post_relu_scaled_features)
         """
         if normalize_activations:
@@ -1151,7 +1165,7 @@ def encode_decoupled(
 
         Returns:
             If return_active is False: encoded features tensor of shape (batch_size, num_layers, dict_size)
-            If return_active is True: tuple of (features, scaled_features, active_mask,
+            If return_active is True: tuple of (features, scaled_features, active_mask,
                 post_relu_features, post_relu_scaled_features)
 
         Raises:
@@ -1228,7 +1242,12 @@ def encode_decoupled(
             return f
 
     def get_activations(
-        self, x: th.Tensor, use_threshold: bool = True, select_features=None, normalize_activations: bool = True, **kwargs
+        self,
+        x: th.Tensor,
+        use_threshold: bool = True,
+        select_features=None,
+        normalize_activations: bool = True,
+        **kwargs,
     ):
         """
         Get scaled feature activations for the input.
@@ -1314,7 +1333,7 @@ def from_pretrained(
            ), f"k in kwargs ({kwargs['k']}) does not match k in state_dict ({state_dict['k']})"
            kwargs.pop("k")
        kwargs.update()
-
+
        crosscoder = cls(
            activation_dim,
            dict_size,
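
Note: the signatures reformatted above belong to a batch top-k sparse autoencoder, where k is fixed at construction and the selection budget is shared across the whole batch rather than enforced per sample. The sketch below illustrates that selection rule only; batch_topk_mask and the example sizes are hypothetical, and this is not the repository's BatchTopKSAE implementation.

import torch as th


def batch_topk_mask(acts: th.Tensor, k: int) -> th.Tensor:
    """Keep the k * batch_size largest activations across the whole batch, zero the rest."""
    batch, dict_size = acts.shape
    flat = acts.flatten()
    top = flat.topk(k * batch, sorted=False)
    out = th.zeros_like(flat)
    out.scatter_(0, top.indices, top.values)
    return out.reshape(batch, dict_size)


# k active features per sample on average, with the budget shared across the batch.
post_relu = th.relu(th.randn(4, 16))
sparse = batch_topk_mask(post_relu, k=3)
assert (sparse != 0).sum() <= 3 * 4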

tests/test_cache.py

Lines changed: 66 additions & 40 deletions

@@ -8,6 +8,7 @@
 from dictionary_learning.cache import ActivationCache
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+
 @pytest.fixture
 def temp_dir():
     """Create a temporary directory for test files."""
@@ -149,38 +150,40 @@ def test_activation_cache_with_normalizer(temp_dir):
     """Test ActivationCache collection and normalizer against direct model activations."""
     # Set flag to handle meta tensors properly
     th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
-
+
     # Skip test if CUDA not available to avoid device mapping issues
     if not th.cuda.is_available():
         pytest.skip("CUDA not available, skipping test to avoid device mapping issues")
-
+
     # Test strings
     test_strings = [
         "The quick brown fox jumps over the lazy dog.",
         "Machine learning is a subset of artificial intelligence.",
         "Python is a popular programming language for data science.",
         "Neural networks are inspired by biological brain structures.",
-        "Deep learning has revolutionized computer vision and natural language processing."
+        "Deep learning has revolutionized computer vision and natural language processing.",
     ]
-
+
     # Use the list directly - it already implements __len__ and __getitem__
     dataset = test_strings
-
+
     # Load GPT-2 model - use auto device mapping but force concrete tensors
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", torch_dtype=th.float32)
+    model = AutoModelForCausalLM.from_pretrained(
+        "gpt2", device_map="auto", torch_dtype=th.float32
+    )
     model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
     model.tokenizer.pad_token = model.tokenizer.eos_token
 
     # Get a transformer block to extract activations from
     target_layer = model.transformer.h[6]  # Middle layer of GPT-2
     submodule_name = "transformer_h_6"
-
+
     # Parameters for activation collection
     batch_size = 2
     context_len = 64
     d_model = 768  # GPT-2 hidden size
-
+
     # Collect activations using ActivationCache
     ActivationCache.collect(
         data=dataset,
@@ -197,17 +200,16 @@ def test_activation_cache_with_normalizer(temp_dir):
         store_tokens=True,
     )
 
-
     # Load the cached activations
     cache = ActivationCache(temp_dir, submodule_name + "_out")
-
+
     # Collect activations directly from model for comparison
     direct_activations = []
     direct_tokens = []
-
+
     for i in range(0, len(test_strings), batch_size):
-        batch_texts = test_strings[i:i+batch_size]
-
+        batch_texts = test_strings[i : i + batch_size]
+
         # Tokenize
         tokens = model.tokenizer(
             batch_texts,
@@ -217,60 +219,84 @@ def test_activation_cache_with_normalizer(temp_dir):
             padding=True,
             add_special_tokens=True,
         )
-
+
         # Get activations directly
         with model.trace(tokens):
             layer_output = target_layer.output[0].save()
-
+
         # Extract valid tokens (non-padding)
         attention_mask = tokens["attention_mask"]
-        valid_activations = layer_output.reshape(-1, d_model)[attention_mask.reshape(-1).bool()]
-        valid_tokens = tokens["input_ids"].reshape(-1)[attention_mask.reshape(-1).bool()]
-
+        valid_activations = layer_output.reshape(-1, d_model)[
+            attention_mask.reshape(-1).bool()
+        ]
+        valid_tokens = tokens["input_ids"].reshape(-1)[
+            attention_mask.reshape(-1).bool()
+        ]
+
         direct_activations.append(valid_activations.cpu())
         direct_tokens.append(valid_tokens.cpu())
-
+
     # Concatenate direct activations
     direct_activations = th.cat(direct_activations, dim=0)
     direct_tokens = th.cat(direct_tokens, dim=0)
-
+
     # Test that we have the same number of activations
-    assert len(cache) == direct_activations.shape[0], f"Cache length {len(cache)} != direct activations length {direct_activations.shape[0]}"
-
+    assert (
+        len(cache) == direct_activations.shape[0]
+    ), f"Cache length {len(cache)} != direct activations length {direct_activations.shape[0]}"
+
     # Test that tokens match
-    assert th.equal(cache.tokens, direct_tokens), "Cached tokens don't match direct tokens"
-
+    assert th.equal(
+        cache.tokens, direct_tokens
+    ), "Cached tokens don't match direct tokens"
+
     # Test that activations match (within tolerance for numerical precision)
     cached_activations = th.stack([cache[i] for i in range(len(cache))], dim=0)
-    assert th.allclose(cached_activations, direct_activations, atol=1e-5, rtol=1e-5), "Cached activations don't match direct activations"
-
+    assert th.allclose(
+        cached_activations, direct_activations, atol=1e-5, rtol=1e-5
+    ), "Cached activations don't match direct activations"
+
     # Test mean and std computation
     computed_mean = direct_activations.mean(dim=0)
     computed_std = direct_activations.std(dim=0, unbiased=True)
-
-    assert th.allclose(cache.mean, computed_mean, atol=1e-5, rtol=1e-5), "Cached mean doesn't match computed mean"
-    assert th.allclose(cache.std, computed_std, atol=1e-5, rtol=1e-5), "Cached std doesn't match computed std"
-
+
+    assert th.allclose(
+        cache.mean, computed_mean, atol=1e-5, rtol=1e-5
+    ), "Cached mean doesn't match computed mean"
+    assert th.allclose(
+        cache.std, computed_std, atol=1e-5, rtol=1e-5
+    ), "Cached std doesn't match computed std"
+
     # Test normalizer functionality
     normalizer = cache.normalizer
-
+
     # Test normalization of a sample activation
     sample_activation = cached_activations[0]
     normalized = normalizer(sample_activation)
-
+
     # Verify normalization: (x - mean) / std (with small epsilon for numerical stability)
     expected_normalized = (sample_activation - cache.mean) / (cache.std + 1e-8)
-    assert th.allclose(normalized, expected_normalized, atol=1e-6), "Normalizer doesn't work correctly"
-
+    assert th.allclose(
+        normalized, expected_normalized, atol=1e-6
+    ), "Normalizer doesn't work correctly"
+
     # Test batch normalization
     batch_normalized = normalizer(cached_activations[:5])
-    expected_batch_normalized = (cached_activations[:5] - cache.mean) / (cache.std + 1e-8)
-    assert th.allclose(batch_normalized, expected_batch_normalized, atol=1e-6), "Batch normalization doesn't work correctly"
-
+    expected_batch_normalized = (cached_activations[:5] - cache.mean) / (
+        cache.std + 1e-8
+    )
+    assert th.allclose(
+        batch_normalized, expected_batch_normalized, atol=1e-6
+    ), "Batch normalization doesn't work correctly"
+
     # Test that normalization preserves shape
-    assert normalized.shape == sample_activation.shape, "Normalization changed tensor shape"
-    assert batch_normalized.shape == cached_activations[:5].shape, "Batch normalization changed tensor shape"
-
+    assert (
+        normalized.shape == sample_activation.shape
+    ), "Normalization changed tensor shape"
+    assert (
+        batch_normalized.shape == cached_activations[:5].shape
+    ), "Batch normalization changed tensor shape"
+
     print(f"✓ Successfully tested ActivationCache with {len(cache)} activations")
     print(f"✓ Mean shape: {cache.mean.shape}, Std shape: {cache.std.shape}")
     print(f"✓ Normalizer tests passed")

tests/test_running_stat_welford.py

Lines changed: 8 additions & 6 deletions

@@ -7,7 +7,7 @@
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
-@pytest.mark.parametrize("D", [5, 127]) # feature dimensionalities
+@pytest.mark.parametrize("D", [5, 127])  # feature dimensionalities
 def test_streaming_matches_reference(dtype, D):
     """
     Stream random data through RunningStatWelford in random-sized batches
@@ -30,11 +30,11 @@ def test_streaming_matches_reference(dtype, D):
 
     # Ground-truth (double precision to remove numeric noise)
     ref_mean = full.double().mean(dim=0)
-    ref_std = full.double().std(dim=0, unbiased=True)
+    ref_std = full.double().std(dim=0, unbiased=True)
 
     # Compare
     torch.testing.assert_close(acc.mean, ref_mean, rtol=1e-6, atol=1e-7)
-    torch.testing.assert_close(acc.std(), ref_std, rtol=1e-6, atol=1e-7)
+    torch.testing.assert_close(acc.std(), ref_std, rtol=1e-6, atol=1e-7)
     assert acc.n == N_total
 
 
@@ -60,10 +60,10 @@ def test_merge_two_accumulators():
 
     # Reference
     ref_mean = data.double().mean(dim=0)
-    ref_std = data.double().std(dim=0, unbiased=True)
+    ref_std = data.double().std(dim=0, unbiased=True)
 
     torch.testing.assert_close(acc1.mean, ref_mean, rtol=1e-6, atol=1e-7)
-    torch.testing.assert_close(acc1.std(), ref_std, rtol=1e-6, atol=1e-7)
+    torch.testing.assert_close(acc1.std(), ref_std, rtol=1e-6, atol=1e-7)
     assert acc1.n == N_total
 
 
@@ -87,4 +87,6 @@ def test_edge_cases():
     acc.update(torch.tensor([[2.0, 4.0, 6.0]], dtype=dtype))
     assert acc.n == 2
     torch.testing.assert_close(acc.mean, torch.tensor([1.5, 3.0, 4.5], dtype=dtype))
-    torch.testing.assert_close(acc.std(), torch.tensor([0.70710678, 1.41421356, 2.12132034], dtype=dtype))
+    torch.testing.assert_close(
+        acc.std(), torch.tensor([0.70710678, 1.41421356, 2.12132034], dtype=dtype)
+    )
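
Note: the invariant these tests check is that a Welford-style streaming accumulator reproduces the one-shot double-precision mean and unbiased std no matter how the data is batched or merged. The class below is an illustrative sketch of such an accumulator under that assumption, not the repository's RunningStatWelford:

import torch


class WelfordSketch:
    """Streaming per-dimension mean/std via Welford/Chan batch updates (illustrative only)."""

    def __init__(self, dim: int):
        self.n = 0
        self.mean = torch.zeros(dim, dtype=torch.float64)
        self.M2 = torch.zeros(dim, dtype=torch.float64)  # running sum of squared deviations

    def update(self, batch: torch.Tensor) -> None:
        batch = batch.double()
        b_n = batch.shape[0]
        b_mean = batch.mean(dim=0)
        b_M2 = ((batch - b_mean) ** 2).sum(dim=0)
        delta = b_mean - self.mean
        total = self.n + b_n
        # Chan et al. pairwise combination of (n, mean, M2) statistics.
        self.mean = self.mean + delta * (b_n / total)
        self.M2 = self.M2 + b_M2 + delta.pow(2) * (self.n * b_n / total)
        self.n = total

    def std(self) -> torch.Tensor:
        return (self.M2 / (self.n - 1)).sqrt()  # unbiased, matches std(dim=0, unbiased=True)


# Streaming in two chunks matches the one-shot double-precision reference.
data = torch.randn(64, 5)
acc = WelfordSketch(5)
acc.update(data[:20])
acc.update(data[20:])
torch.testing.assert_close(acc.mean, data.double().mean(dim=0), rtol=1e-6, atol=1e-7)
torch.testing.assert_close(acc.std(), data.double().std(dim=0, unbiased=True), rtol=1e-6, atol=1e-7)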
