
Commit 173f7fa

format
1 parent b870fab commit 173f7fa

File tree

3 files changed: +106, -61 lines changed


dictionary_learning/trainers/crosscoder.py

Lines changed: 3 additions & 1 deletion

@@ -212,7 +212,9 @@ def loss(
             "deads": deads if return_deads else None,
         }
         for layer in range(x.shape[1]):
-            log_dict[f"rms_norm_l{layer}"] = th.sqrt((x[:, layer, :].pow(2).sum(-1)).mean()).item()
+            log_dict[f"rms_norm_l{layer}"] = th.sqrt(
+                (x[:, layer, :].pow(2).sum(-1)).mean()
+            ).item()
         return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])(
             x,
             x_hat,
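
Note: the rewrapped statement logs, per layer, the square root of the batch-mean squared L2 norm of the activations, i.e. an RMS of the per-sample activation norms. A minimal standalone sketch of the same computation (the batch/layer/hidden sizes are illustrative, not taken from the repo):

    import torch as th

    x = th.randn(32, 2, 512)  # hypothetical activations: (batch, n_layers, d_model)
    log_dict = {}
    for layer in range(x.shape[1]):
        # mean over the batch of each sample's squared L2 norm, then sqrt
        log_dict[f"rms_norm_l{layer}"] = th.sqrt(
            (x[:, layer, :].pow(2).sum(-1)).mean()
        ).item()
    print(log_dict)  # {'rms_norm_l0': ..., 'rms_norm_l1': ...}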

dictionary_learning/utils.py

Lines changed: 1 addition & 1 deletion

@@ -122,4 +122,4 @@ def remove_gradient_parallel_to_decoder_directions(
         normed_W_dec_DF,
         "d_sae, d_in d_sae -> d_in d_sae",
    )
-    return W_dec_DF_grad
+    return W_dec_DF_grad
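
The visible change here is whitespace-only (black normalizes the end of the file). For context, a common way to implement this kind of decoder-gradient projection, sketched under the assumption that W_dec_DF has shape (d_in, d_sae) with one dictionary direction per column and that einops-style einsum is in use; only the second pattern below appears in the diff, the rest is inferred, not confirmed against the repo:

    import torch as th
    from einops import einsum

    def project_out_parallel_grad(
        W_dec_DF: th.Tensor,       # (d_in, d_sae) decoder weights
        W_dec_DF_grad: th.Tensor,  # (d_in, d_sae) gradient w.r.t. the decoder
    ) -> th.Tensor:
        # unit-normalize each dictionary direction (column)
        normed_W_dec_DF = W_dec_DF / W_dec_DF.norm(dim=0, keepdim=True)
        # per-direction component of the gradient along that direction
        parallel = einsum(
            W_dec_DF_grad, normed_W_dec_DF, "d_in d_sae, d_in d_sae -> d_sae"
        )
        # broadcast back along d_in and subtract, leaving only the component
        # orthogonal to each decoder direction (keeps decoder norms stable)
        W_dec_DF_grad = W_dec_DF_grad - einsum(
            parallel,
            normed_W_dec_DF,
            "d_sae, d_in d_sae -> d_in d_sae",
        )
        return W_dec_DF_grad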

tests/test_cache.py

Lines changed: 102 additions & 59 deletions

@@ -9,6 +9,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import numpy as np
 
+
 @pytest.fixture
 def temp_dir():
     """Create a temporary directory for test files."""
@@ -274,7 +275,7 @@ def test_activation_cache_with_normalizer(temp_dir):
 def test_sequence_ranges_no_bos_token(temp_dir):
     """Test that sequence ranges are stored when model has no BOS token."""
     # Set flag to handle meta tensors properly
-    if hasattr(th.fx, 'experimental'):
+    if hasattr(th.fx, "experimental"):
         th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
 
     # Skip test if CUDA not available
@@ -296,12 +297,18 @@ def test_sequence_ranges_no_bos_token(temp_dir):
     )
     model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
     model.tokenizer.pad_token = model.tokenizer.eos_token
-
+
     # Simulate model without BOS token
     original_bos_token_id = model.tokenizer.bos_token_id
     model.tokenizer.bos_token_id = None
 
-    tokens = model.tokenizer(test_strings, add_special_tokens=True, return_tensors="pt", padding=True, truncation=True)
+    tokens = model.tokenizer(
+        test_strings,
+        add_special_tokens=True,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+    )
     lengths = tokens["attention_mask"].sum(dim=1).tolist()
     ranges = np.cumsum([0] + lengths)
     try:
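
The expected boundaries here come straight from the attention mask: per-sequence token counts, cumulatively summed with a leading 0. A standalone illustration (the mask values are made up):

    import numpy as np

    # three padded sequences with true lengths 4, 2, 3
    attention_mask = np.array([
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
    ])
    lengths = attention_mask.sum(axis=1).tolist()  # [4, 2, 3]
    ranges = np.cumsum([0] + lengths)              # [0, 4, 6, 9]
    # ranges[i]:ranges[i + 1] indexes sequence i in the flat token stream,
    # hence the len(test_strings) + 1 entries asserted below.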
@@ -335,28 +342,40 @@ def test_sequence_ranges_no_bos_token(temp_dir):
 
         # Verify sequence ranges were stored
         sequence_ranges = cache.sequence_ranges
-        assert sequence_ranges is not None, "sequence ranges should be stored for model without BOS token"
-
+        assert (
+            sequence_ranges is not None
+        ), "sequence ranges should be stored for model without BOS token"
+
         # Should have one sequence start per input string plus one for the last sequence
-        assert len(sequence_ranges) == len(test_strings) + 1, f"Expected {len(test_strings)} sequence ranges, got {len(sequence_ranges)}"
-
+        assert (
+            len(sequence_ranges) == len(test_strings) + 1
+        ), f"Expected {len(test_strings)} sequence ranges, got {len(sequence_ranges)}"
+
         # First sequence should start at position 0
-        assert sequence_ranges[0].item() == 0, "First sequence should start at position 0"
+        assert (
+            sequence_ranges[0].item() == 0
+        ), "First sequence should start at position 0"
 
         # sequence ranges should be the same as the ranges computed from the tokens
-        assert np.allclose(sequence_ranges, ranges), "sequence ranges should be the same as the ranges computed from the tokens"
-
+        assert np.allclose(
+            sequence_ranges, ranges
+        ), "sequence ranges should be the same as the ranges computed from the tokens"
+
         # sequence ranges should be in ascending order
         for i in range(1, len(sequence_ranges)):
-            assert sequence_ranges[i] > sequence_ranges[i-1], f"sequence ranges should be ascending: {sequence_ranges}"
+            assert (
+                sequence_ranges[i] > sequence_ranges[i - 1]
+            ), f"sequence ranges should be ascending: {sequence_ranges}"
 
         # Verify sequence ranges align with token boundaries
         tokens = cache.tokens
         total_tokens = len(tokens)
-
+
         # All sequence ranges should be valid indices
         for start_idx in sequence_ranges:
-            assert 0 <= start_idx <= total_tokens, f"Invalid sequence start index: {start_idx}"
+            assert (
+                0 <= start_idx <= total_tokens
+            ), f"Invalid sequence start index: {start_idx}"
 
     finally:
         # Restore original BOS token
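
A note on the assert style black produces throughout this commit: only the condition is parenthesized, never the whole condition-message pair, because a parenthesized pair would be a 2-tuple, which is always truthy. Tiny illustration:

    x = 1
    assert (
        x == 1
    ), "x should be 1"                  # correct: parens wrap the condition only
    # assert (x == 1, "x should be 1")  # bug: non-empty tuple, always passes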
@@ -366,7 +385,7 @@ def test_sequence_ranges_no_bos_token(temp_dir):
 def test_sequence_ranges_with_bos_token(temp_dir):
     """Test that sequence ranges are NOT stored when model has BOS token."""
     # Set flag to handle meta tensors properly
-    if hasattr(th.fx, 'experimental'):
+    if hasattr(th.fx, "experimental"):
         th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
 
     # Skip test if CUDA not available
@@ -382,7 +401,7 @@ def test_sequence_ranges_with_bos_token(temp_dir):
     )
     model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
     model.tokenizer.pad_token = model.tokenizer.eos_token
-
+
     # Ensure model has BOS token (set it explicitly)
     model.tokenizer.bos_token_id = model.tokenizer.eos_token_id
 
@@ -411,7 +430,9 @@ def test_sequence_ranges_with_bos_token(temp_dir):
 
     # Verify sequence ranges were NOT stored
     sequence_ranges = cache.sequence_ranges
-    assert sequence_ranges is None, "sequence ranges should not be stored for model with BOS token"
+    assert (
+        sequence_ranges is None
+    ), "sequence ranges should not be stored for model with BOS token"
 
 
 def test_activation_cache_slice_indexing_cross_shard(temp_dir):
@@ -469,39 +490,45 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
 
     # Load the cached activations
     cache = ActivationCache(temp_dir, submodule_name + "_out")
-
+
     # Verify we have multiple shards
-    assert len(cache.shards) >= 2, f"Expected at least 2 shards, got {len(cache.shards)}"
-
+    assert (
+        len(cache.shards) >= 2
+    ), f"Expected at least 2 shards, got {len(cache.shards)}"
+
     total_size = len(cache)
     print(f"Cache has {len(cache.shards)} shards with total size {total_size}")
-
+
     # Print shard boundaries for debugging
     shard_boundaries = cache._range_to_shard_idx
     print(f"Shard boundaries: {shard_boundaries}")
-
+
     # Test 1: Slice that crosses exactly one shard boundary
     if len(cache.shards) >= 2:
         # Find a slice that starts in first shard and ends in second shard
         first_shard_end = shard_boundaries[1]
         start_idx = max(0, first_shard_end - 10)
         end_idx = min(total_size, first_shard_end + 10)
-
+
         # Get slice result
         slice_result = cache[start_idx:end_idx]
-
+
         # Get individual results for comparison
-        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx)], dim=0)
-
+        individual_results = th.stack(
+            [cache[i] for i in range(start_idx, end_idx)], dim=0
+        )
+
         # Verify they match
-        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
-            f"Slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
-
+        assert th.allclose(
+            slice_result, individual_results, atol=1e-5, rtol=1e-5
+        ), f"Slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
+
         # Verify correct shape
         expected_length = end_idx - start_idx
-        assert slice_result.shape[0] == expected_length, \
-            f"Expected slice length {expected_length}, got {slice_result.shape[0]}"
-
+        assert (
+            slice_result.shape[0] == expected_length
+        ), f"Expected slice length {expected_length}, got {slice_result.shape[0]}"
+
         print(f"✓ Cross-shard slice test 1 passed: indices {start_idx}:{end_idx}")
 
     # Test 2: Slice that spans multiple shards
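
For intuition on what cache._range_to_shard_idx provides: a sharded cache typically keeps cumulative shard boundaries and bisects into them to turn a global index into a (shard, offset) pair. A hypothetical sketch of that lookup; the real ActivationCache internals are not shown in this diff:

    import numpy as np

    boundaries = np.array([0, 100, 250, 300])  # cumulative shard end offsets

    def locate(global_idx: int) -> tuple[int, int]:
        # side="right" so an index equal to a boundary lands in the next shard
        shard = int(np.searchsorted(boundaries, global_idx, side="right")) - 1
        return shard, global_idx - int(boundaries[shard])

    print(locate(99))   # (0, 99): last element of shard 0
    print(locate(100))  # (1, 0): first element of shard 1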
@@ -510,54 +537,70 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
         second_shard_end = shard_boundaries[2]
         start_idx = max(0, shard_boundaries[1] - 5)  # Start near end of first shard
         end_idx = min(total_size, second_shard_end + 5)  # End in third shard
-
+
         slice_result = cache[start_idx:end_idx]
-        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx)], dim=0)
-
-        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
-            f"Multi-shard slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
-
+        individual_results = th.stack(
+            [cache[i] for i in range(start_idx, end_idx)], dim=0
+        )
+
+        assert th.allclose(
+            slice_result, individual_results, atol=1e-5, rtol=1e-5
+        ), f"Multi-shard slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
+
         expected_length = end_idx - start_idx
-        assert slice_result.shape[0] == expected_length, \
-            f"Expected multi-shard slice length {expected_length}, got {slice_result.shape[0]}"
-
+        assert (
+            slice_result.shape[0] == expected_length
+        ), f"Expected multi-shard slice length {expected_length}, got {slice_result.shape[0]}"
+
         print(f"✓ Multi-shard slice test passed: indices {start_idx}:{end_idx}")
 
     # Test 3: Slice with step parameter across shards
     if total_size >= 50:
         start_idx = 5
         end_idx = min(total_size, 45)
         step = 3
-
+
         slice_result = cache[start_idx:end_idx:step]
-        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx, step)], dim=0)
-
-        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
-            f"Stepped slice result doesn't match individual indexing for indices {start_idx}:{end_idx}:{step}"
-
+        individual_results = th.stack(
+            [cache[i] for i in range(start_idx, end_idx, step)], dim=0
+        )
+
+        assert th.allclose(
+            slice_result, individual_results, atol=1e-5, rtol=1e-5
+        ), f"Stepped slice result doesn't match individual indexing for indices {start_idx}:{end_idx}:{step}"
+
         expected_length = len(range(start_idx, end_idx, step))
-        assert slice_result.shape[0] == expected_length, \
-            f"Expected stepped slice length {expected_length}, got {slice_result.shape[0]}"
-
+        assert (
+            slice_result.shape[0] == expected_length
+        ), f"Expected stepped slice length {expected_length}, got {slice_result.shape[0]}"
+
         print(f"✓ Stepped slice test passed: indices {start_idx}:{end_idx}:{step}")
 
     # Test 4: Edge cases - slice at boundaries
     if len(cache.shards) >= 2:
         # Test slice starting exactly at shard boundary
         boundary_idx = shard_boundaries[1]
         if boundary_idx < total_size - 5:
-            slice_result = cache[boundary_idx:boundary_idx + 5]
-            individual_results = th.stack([cache[i] for i in range(boundary_idx, boundary_idx + 5)], dim=0)
-
-            assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
-                f"Boundary slice result doesn't match individual indexing"
-
-            print(f"✓ Boundary slice test passed: starting at shard boundary {boundary_idx}")
+            slice_result = cache[boundary_idx : boundary_idx + 5]
+            individual_results = th.stack(
+                [cache[i] for i in range(boundary_idx, boundary_idx + 5)], dim=0
+            )
+
+            assert th.allclose(
+                slice_result, individual_results, atol=1e-5, rtol=1e-5
+            ), f"Boundary slice result doesn't match individual indexing"
+
+            print(
+                f"✓ Boundary slice test passed: starting at shard boundary {boundary_idx}"
+            )
 
     # Test 5: Empty slice
     empty_slice = cache[10:10]
-    assert empty_slice.shape[0] == 0, f"Expected empty slice, got shape {empty_slice.shape}"
+    assert (
+        empty_slice.shape[0] == 0
+    ), f"Expected empty slice, got shape {empty_slice.shape}"
     print("✓ Empty slice test passed")
-
 
-    print(f"✓ All slice indexing tests passed for cache with {len(cache.shards)} shards")
+    print(
+        f"✓ All slice indexing tests passed for cache with {len(cache.shards)} shards"
+    )
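
All five tests assert one contract: slicing must agree with stacking single-index lookups, with the empty slice handled gracefully. A minimal in-memory stand-in satisfying that contract (hypothetical; not the repo's ActivationCache):

    import torch as th

    class TinyCache:
        def __init__(self, shards: list[th.Tensor]):
            self.shards = shards
            self._data = th.cat(shards, dim=0)  # a real cache would read lazily

        def __len__(self) -> int:
            return self._data.shape[0]

        def __getitem__(self, idx):
            if isinstance(idx, slice):
                items = [self._data[i] for i in range(*idx.indices(len(self)))]
                if not items:             # empty slice, as in Test 5
                    return self._data[0:0]
                return th.stack(items, dim=0)
            return self._data[idx]

    cache = TinyCache([th.randn(10, 4), th.randn(10, 4)])
    assert th.allclose(cache[5:15], th.stack([cache[i] for i in range(5, 15)], dim=0))
    assert cache[10:10].shape[0] == 0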
