9
9
from transformers import AutoModelForCausalLM , AutoTokenizer
10
10
import numpy as np
11
11
12
+
12
13
@pytest .fixture
13
14
def temp_dir ():
14
15
"""Create a temporary directory for test files."""
@@ -274,7 +275,7 @@ def test_activation_cache_with_normalizer(temp_dir):
274
275
def test_sequence_ranges_no_bos_token (temp_dir ):
275
276
"""Test that sequence ranges are stored when model has no BOS token."""
276
277
# Set flag to handle meta tensors properly
277
- if hasattr (th .fx , ' experimental' ):
278
+ if hasattr (th .fx , " experimental" ):
278
279
th .fx .experimental ._config .meta_nonzero_assume_all_nonzero = True
279
280
280
281
# Skip test if CUDA not available
@@ -296,12 +297,18 @@ def test_sequence_ranges_no_bos_token(temp_dir):
296
297
)
297
298
model = LanguageModel (model , torch_dtype = th .float32 , tokenizer = tokenizer )
298
299
model .tokenizer .pad_token = model .tokenizer .eos_token
299
-
300
+
300
301
# Simulate model without BOS token
301
302
original_bos_token_id = model .tokenizer .bos_token_id
302
303
model .tokenizer .bos_token_id = None
303
304
304
- tokens = model .tokenizer (test_strings , add_special_tokens = True , return_tensors = "pt" , padding = True , truncation = True )
305
+ tokens = model .tokenizer (
306
+ test_strings ,
307
+ add_special_tokens = True ,
308
+ return_tensors = "pt" ,
309
+ padding = True ,
310
+ truncation = True ,
311
+ )
305
312
lengths = tokens ["attention_mask" ].sum (dim = 1 ).tolist ()
306
313
ranges = np .cumsum ([0 ] + lengths )
307
314
try :
@@ -335,28 +342,40 @@ def test_sequence_ranges_no_bos_token(temp_dir):
335
342
336
343
# Verify sequence ranges were stored
337
344
sequence_ranges = cache .sequence_ranges
338
- assert sequence_ranges is not None , "sequence ranges should be stored for model without BOS token"
339
-
345
+ assert (
346
+ sequence_ranges is not None
347
+ ), "sequence ranges should be stored for model without BOS token"
348
+
340
349
# Should have one sequence start per input string plus one for the last sequence
341
- assert len (sequence_ranges ) == len (test_strings ) + 1 , f"Expected { len (test_strings )} sequence ranges, got { len (sequence_ranges )} "
342
-
350
+ assert (
351
+ len (sequence_ranges ) == len (test_strings ) + 1
352
+ ), f"Expected { len (test_strings )} sequence ranges, got { len (sequence_ranges )} "
353
+
343
354
# First sequence should start at position 0
344
- assert sequence_ranges [0 ].item () == 0 , "First sequence should start at position 0"
355
+ assert (
356
+ sequence_ranges [0 ].item () == 0
357
+ ), "First sequence should start at position 0"
345
358
346
359
# sequence ranges should be the same as the ranges computed from the tokens
347
- assert np .allclose (sequence_ranges , ranges ), "sequence ranges should be the same as the ranges computed from the tokens"
348
-
360
+ assert np .allclose (
361
+ sequence_ranges , ranges
362
+ ), "sequence ranges should be the same as the ranges computed from the tokens"
363
+
349
364
# sequence ranges should be in ascending order
350
365
for i in range (1 , len (sequence_ranges )):
351
- assert sequence_ranges [i ] > sequence_ranges [i - 1 ], f"sequence ranges should be ascending: { sequence_ranges } "
366
+ assert (
367
+ sequence_ranges [i ] > sequence_ranges [i - 1 ]
368
+ ), f"sequence ranges should be ascending: { sequence_ranges } "
352
369
353
370
# Verify sequence ranges align with token boundaries
354
371
tokens = cache .tokens
355
372
total_tokens = len (tokens )
356
-
373
+
357
374
# All sequence ranges should be valid indices
358
375
for start_idx in sequence_ranges :
359
- assert 0 <= start_idx <= total_tokens , f"Invalid sequence start index: { start_idx } "
376
+ assert (
377
+ 0 <= start_idx <= total_tokens
378
+ ), f"Invalid sequence start index: { start_idx } "
360
379
361
380
finally :
362
381
# Restore original BOS token
@@ -366,7 +385,7 @@ def test_sequence_ranges_no_bos_token(temp_dir):
366
385
def test_sequence_ranges_with_bos_token (temp_dir ):
367
386
"""Test that sequence ranges are NOT stored when model has BOS token."""
368
387
# Set flag to handle meta tensors properly
369
- if hasattr (th .fx , ' experimental' ):
388
+ if hasattr (th .fx , " experimental" ):
370
389
th .fx .experimental ._config .meta_nonzero_assume_all_nonzero = True
371
390
372
391
# Skip test if CUDA not available
@@ -382,7 +401,7 @@ def test_sequence_ranges_with_bos_token(temp_dir):
382
401
)
383
402
model = LanguageModel (model , torch_dtype = th .float32 , tokenizer = tokenizer )
384
403
model .tokenizer .pad_token = model .tokenizer .eos_token
385
-
404
+
386
405
# Ensure model has BOS token (set it explicitly)
387
406
model .tokenizer .bos_token_id = model .tokenizer .eos_token_id
388
407
@@ -411,7 +430,9 @@ def test_sequence_ranges_with_bos_token(temp_dir):
411
430
412
431
# Verify sequence ranges were NOT stored
413
432
sequence_ranges = cache .sequence_ranges
414
- assert sequence_ranges is None , "sequence ranges should not be stored for model with BOS token"
433
+ assert (
434
+ sequence_ranges is None
435
+ ), "sequence ranges should not be stored for model with BOS token"
415
436
416
437
417
438
def test_activation_cache_slice_indexing_cross_shard (temp_dir ):
@@ -469,39 +490,45 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
469
490
470
491
# Load the cached activations
471
492
cache = ActivationCache (temp_dir , submodule_name + "_out" )
472
-
493
+
473
494
# Verify we have multiple shards
474
- assert len (cache .shards ) >= 2 , f"Expected at least 2 shards, got { len (cache .shards )} "
475
-
495
+ assert (
496
+ len (cache .shards ) >= 2
497
+ ), f"Expected at least 2 shards, got { len (cache .shards )} "
498
+
476
499
total_size = len (cache )
477
500
print (f"Cache has { len (cache .shards )} shards with total size { total_size } " )
478
-
501
+
479
502
# Print shard boundaries for debugging
480
503
shard_boundaries = cache ._range_to_shard_idx
481
504
print (f"Shard boundaries: { shard_boundaries } " )
482
-
505
+
483
506
# Test 1: Slice that crosses exactly one shard boundary
484
507
if len (cache .shards ) >= 2 :
485
508
# Find a slice that starts in first shard and ends in second shard
486
509
first_shard_end = shard_boundaries [1 ]
487
510
start_idx = max (0 , first_shard_end - 10 )
488
511
end_idx = min (total_size , first_shard_end + 10 )
489
-
512
+
490
513
# Get slice result
491
514
slice_result = cache [start_idx :end_idx ]
492
-
515
+
493
516
# Get individual results for comparison
494
- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx )], dim = 0 )
495
-
517
+ individual_results = th .stack (
518
+ [cache [i ] for i in range (start_idx , end_idx )], dim = 0
519
+ )
520
+
496
521
# Verify they match
497
- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
498
- f"Slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
499
-
522
+ assert th .allclose (
523
+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
524
+ ), f"Slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
525
+
500
526
# Verify correct shape
501
527
expected_length = end_idx - start_idx
502
- assert slice_result .shape [0 ] == expected_length , \
503
- f"Expected slice length { expected_length } , got { slice_result .shape [0 ]} "
504
-
528
+ assert (
529
+ slice_result .shape [0 ] == expected_length
530
+ ), f"Expected slice length { expected_length } , got { slice_result .shape [0 ]} "
531
+
505
532
print (f"✓ Cross-shard slice test 1 passed: indices { start_idx } :{ end_idx } " )
506
533
507
534
# Test 2: Slice that spans multiple shards
@@ -510,54 +537,70 @@ def test_activation_cache_slice_indexing_cross_shard(temp_dir):
510
537
second_shard_end = shard_boundaries [2 ]
511
538
start_idx = max (0 , shard_boundaries [1 ] - 5 ) # Start near end of first shard
512
539
end_idx = min (total_size , second_shard_end + 5 ) # End in third shard
513
-
540
+
514
541
slice_result = cache [start_idx :end_idx ]
515
- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx )], dim = 0 )
516
-
517
- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
518
- f"Multi-shard slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
519
-
542
+ individual_results = th .stack (
543
+ [cache [i ] for i in range (start_idx , end_idx )], dim = 0
544
+ )
545
+
546
+ assert th .allclose (
547
+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
548
+ ), f"Multi-shard slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } "
549
+
520
550
expected_length = end_idx - start_idx
521
- assert slice_result .shape [0 ] == expected_length , \
522
- f"Expected multi-shard slice length { expected_length } , got { slice_result .shape [0 ]} "
523
-
551
+ assert (
552
+ slice_result .shape [0 ] == expected_length
553
+ ), f"Expected multi-shard slice length { expected_length } , got { slice_result .shape [0 ]} "
554
+
524
555
print (f"✓ Multi-shard slice test passed: indices { start_idx } :{ end_idx } " )
525
556
526
557
# Test 3: Slice with step parameter across shards
527
558
if total_size >= 50 :
528
559
start_idx = 5
529
560
end_idx = min (total_size , 45 )
530
561
step = 3
531
-
562
+
532
563
slice_result = cache [start_idx :end_idx :step ]
533
- individual_results = th .stack ([cache [i ] for i in range (start_idx , end_idx , step )], dim = 0 )
534
-
535
- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
536
- f"Stepped slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } :{ step } "
537
-
564
+ individual_results = th .stack (
565
+ [cache [i ] for i in range (start_idx , end_idx , step )], dim = 0
566
+ )
567
+
568
+ assert th .allclose (
569
+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
570
+ ), f"Stepped slice result doesn't match individual indexing for indices { start_idx } :{ end_idx } :{ step } "
571
+
538
572
expected_length = len (range (start_idx , end_idx , step ))
539
- assert slice_result .shape [0 ] == expected_length , \
540
- f"Expected stepped slice length { expected_length } , got { slice_result .shape [0 ]} "
541
-
573
+ assert (
574
+ slice_result .shape [0 ] == expected_length
575
+ ), f"Expected stepped slice length { expected_length } , got { slice_result .shape [0 ]} "
576
+
542
577
print (f"✓ Stepped slice test passed: indices { start_idx } :{ end_idx } :{ step } " )
543
578
544
579
# Test 4: Edge cases - slice at boundaries
545
580
if len (cache .shards ) >= 2 :
546
581
# Test slice starting exactly at shard boundary
547
582
boundary_idx = shard_boundaries [1 ]
548
583
if boundary_idx < total_size - 5 :
549
- slice_result = cache [boundary_idx :boundary_idx + 5 ]
550
- individual_results = th .stack ([cache [i ] for i in range (boundary_idx , boundary_idx + 5 )], dim = 0 )
551
-
552
- assert th .allclose (slice_result , individual_results , atol = 1e-5 , rtol = 1e-5 ), \
553
- f"Boundary slice result doesn't match individual indexing"
554
-
555
- print (f"✓ Boundary slice test passed: starting at shard boundary { boundary_idx } " )
584
+ slice_result = cache [boundary_idx : boundary_idx + 5 ]
585
+ individual_results = th .stack (
586
+ [cache [i ] for i in range (boundary_idx , boundary_idx + 5 )], dim = 0
587
+ )
588
+
589
+ assert th .allclose (
590
+ slice_result , individual_results , atol = 1e-5 , rtol = 1e-5
591
+ ), f"Boundary slice result doesn't match individual indexing"
592
+
593
+ print (
594
+ f"✓ Boundary slice test passed: starting at shard boundary { boundary_idx } "
595
+ )
556
596
557
597
# Test 5: Empty slice
558
598
empty_slice = cache [10 :10 ]
559
- assert empty_slice .shape [0 ] == 0 , f"Expected empty slice, got shape { empty_slice .shape } "
599
+ assert (
600
+ empty_slice .shape [0 ] == 0
601
+ ), f"Expected empty slice, got shape { empty_slice .shape } "
560
602
print ("✓ Empty slice test passed" )
561
-
562
603
563
- print (f"✓ All slice indexing tests passed for cache with { len (cache .shards )} shards" )
604
+ print (
605
+ f"✓ All slice indexing tests passed for cache with { len (cache .shards )} shards"
606
+ )
0 commit comments