@@ -272,7 +272,9 @@ def __init__(self, store_dir: str, submodule_name: str = None):
         self._range_to_shard_idx = np.cumsum([0] + [s.shape[0] for s in self.shards])
         if "store_tokens" in self.config and self.config["store_tokens"]:
             self._tokens = th.load(
-                os.path.join(store_dir, "tokens.pt"), weights_only=True, map_location=th.device("cpu")
+                os.path.join(store_dir, "tokens.pt"),
+                weights_only=True,
+                map_location=th.device("cpu"),
             )

         self._sequence_ranges = None
@@ -309,7 +311,6 @@ def std(self):
             )
         return self._std

-
     @property
     def running_stats(self):
         return RunningStatWelford.load_or_create_state(
@@ -318,12 +319,17 @@ def running_stats(self):

     def __len__(self):
         return self.config["total_size"]
+
     def __getitem__(self, index):
         if isinstance(index, slice):
             # Handle slice objects
             start, stop, step = index.indices(len(self))
-            start_shard_idx = np.searchsorted(self._range_to_shard_idx, start, side="right") - 1
-            stop_shard_idx = np.searchsorted(self._range_to_shard_idx, stop, side="right") - 1
+            start_shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, start, side="right") - 1
+            )
+            stop_shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, stop, side="right") - 1
+            )
             if start_shard_idx == stop_shard_idx:
                 offset = start - self._range_to_shard_idx[start_shard_idx]
                 end_offset = stop - self._range_to_shard_idx[stop_shard_idx]
@@ -335,7 +341,9 @@ def __getitem__(self, index):
                 return th.stack([self[i] for i in range(start, stop, step)], dim=0)
         elif isinstance(index, int):
             # Handle single integer index
-            shard_idx = np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
+            shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
+            )
             offset = index - self._range_to_shard_idx[shard_idx]
             shard = self.shards[shard_idx]
             return shard[offset]
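Side note on the reflowed lookups above: both the slice and the integer branches map a flat index to a (shard, offset) pair through `self._range_to_shard_idx`, the cumulative shard sizes built in `__init__`. A minimal, self-contained sketch of that mapping (the shard sizes below are made up for illustration):

import numpy as np

# Hypothetical shard sizes; in the cache these come from the stored shards.
shard_sizes = [4, 3, 5]
range_to_shard_idx = np.cumsum([0] + shard_sizes)  # array([ 0,  4,  7, 12])

def locate(index):
    # side="right" pushes an index that equals a boundary into the next shard
    shard_idx = np.searchsorted(range_to_shard_idx, index, side="right") - 1
    offset = index - range_to_shard_idx[shard_idx]
    return shard_idx, offset

assert locate(0) == (0, 0)    # first element of the first shard
assert locate(4) == (1, 0)    # first element of the second shard
assert locate(11) == (2, 4)   # last element of the third shard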
@@ -346,23 +354,29 @@ def __getitem__(self, index):
             else:
                 raise TypeError(f"Tensor index must be scalar, got shape {index.shape}")
         else:
-            raise TypeError(f"Index must be int, slice, or scalar tensor, got {type(index)}")
+            raise TypeError(
+                f"Index must be int, slice, or scalar tensor, got {type(index)}"
+            )

     @property
     def tokens(self):
         return self._tokens

     @property
     def sequence_ranges(self):
-        if hasattr(self, '_sequence_ranges') and self._sequence_ranges is not None:
+        if hasattr(self, "_sequence_ranges") and self._sequence_ranges is not None:
             return self._sequence_ranges
-
-        if ("store_sequence_ranges" in self.config and
-            self.config["store_sequence_ranges"] and
-            os.path.exists(os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"))):
+
+        if (
+            "store_sequence_ranges" in self.config
+            and self.config["store_sequence_ranges"]
+            and os.path.exists(
+                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt")
+            )
+        ):
             self._sequence_ranges = th.load(
-                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"),
-                weights_only=True
+                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"),
+                weights_only=True,
             ).cpu()
             return self._sequence_ranges
         else:
@@ -483,23 +497,27 @@ def exists(
         num_tokens = 0
         config = None
         for submodule_name in submodule_names:
-            config_path = os.path.join(store_dir, f"{submodule_name}_{io}", "config.json")
+            config_path = os.path.join(
+                store_dir, f"{submodule_name}_{io}", "config.json"
+            )
             if not os.path.exists(config_path):
                 return False, 0
             with open(config_path, "r") as f:
                 config = json.load(f)
                 num_tokens = config["total_size"]
-
+
         if store_tokens and not os.path.exists(os.path.join(store_dir, "tokens.pt")):
             return False, 0
-
+
         # Check for sequence ranges if they should exist
-        if (config and
-            "store_sequence_ranges" in config and
-            config["store_sequence_ranges"] and
-            not os.path.exists(os.path.join(store_dir, "sequence_ranges.pt"))):
+        if (
+            config
+            and "store_sequence_ranges" in config
+            and config["store_sequence_ranges"]
+            and not os.path.exists(os.path.join(store_dir, "sequence_ranges.pt"))
+        ):
             return False, 0
-
+
         return True, num_tokens

     @th.no_grad()
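The `exists` hunk only reflows the per-submodule lookup; the check itself is unchanged: a cache counts as complete when each `<submodule>_<io>` directory has a `config.json` (whose `total_size` gives the token count) and the optional `tokens.pt` / `sequence_ranges.pt` companions are present when the config says they should be. A self-contained sketch of that pattern under a throwaway directory, with a made-up submodule name:

import json
import os
import tempfile

with tempfile.TemporaryDirectory() as store_dir:
    # Fake a single submodule cache directory; "layer5_out" is illustrative only.
    sub_dir = os.path.join(store_dir, "layer5_out")
    os.makedirs(sub_dir)
    with open(os.path.join(sub_dir, "config.json"), "w") as f:
        json.dump({"total_size": 1024, "store_sequence_ranges": False}, f)

    config_path = os.path.join(sub_dir, "config.json")
    if not os.path.exists(config_path):
        print(False, 0)
    else:
        with open(config_path, "r") as f:
            config = json.load(f)
        print(True, config["total_size"])  # True 1024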
@@ -530,19 +548,16 @@ def collect(
         assert (
             not shuffle_shards or not store_tokens
         ), "Shuffling shards and storing tokens is not supported yet"
-
-        store_sequence_ranges = (
-            store_tokens and
-            not shuffle_shards
-        )
-
+
+        store_sequence_ranges = store_tokens and not shuffle_shards
+
         dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)

         activation_cache = [[] for _ in submodules]
         tokens_cache = []
         sequence_ranges_cache = []
         current_token_position = 0  # Track position in flattened token stream
-
+
         store_sub_dirs = [
             os.path.join(store_dir, f"{submodule_names[i]}_{io}")
             for i in range(len(submodules))
@@ -594,11 +609,13 @@ def collect(
             store_mask = attention_mask.clone()
             if ignore_first_n_tokens_per_sample > 0:
                 store_mask[:, :ignore_first_n_tokens_per_sample] = 0
-
+
             # Track sequence ranges if needed
             if store_sequence_ranges:
                 batch_lengths = store_mask.sum(dim=1).tolist()
-                batch_sequence_ranges = np.cumsum([0] + batch_lengths[:-1]) + current_token_position
+                batch_sequence_ranges = (
+                    np.cumsum([0] + batch_lengths[:-1]) + current_token_position
+                )
                 sequence_ranges_cache.extend(batch_sequence_ranges.tolist())
                 current_token_position += sum(batch_lengths)

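The wrapped `batch_sequence_ranges` expression computes, for each sequence in a batch, the offset at which its tokens start in the flattened token stream: a cumulative sum of the preceding sequences' lengths, shifted by how many tokens earlier batches already contributed. A small sketch with invented lengths:

import numpy as np

current_token_position = 100   # tokens written by earlier batches (made-up value)
batch_lengths = [5, 3, 7]      # store_mask.sum(dim=1).tolist() for one batch

batch_sequence_ranges = np.cumsum([0] + batch_lengths[:-1]) + current_token_position
print(batch_sequence_ranges.tolist())   # [100, 105, 108]

current_token_position += sum(batch_lengths)
print(current_token_position)           # 115, where the next batch starts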
@@ -733,7 +750,9 @@ def collect(
             sequence_ranges_cache.append(current_token_position)
             assert sequence_ranges_cache[-1] == total_size
             sequence_ranges_tensor = th.tensor(sequence_ranges_cache, dtype=th.long)
-            th.save(sequence_ranges_tensor, os.path.join(store_dir, "sequence_ranges.pt"))
+            th.save(
+                sequence_ranges_tensor, os.path.join(store_dir, "sequence_ranges.pt")
+            )
             print(f"Stored {len(sequence_ranges_cache)} sequence ranges")

         # store running stats
@@ -755,24 +774,38 @@ def __init__(self, store_dir_1: str, store_dir_2: str, submodule_name: str = None
         self.activation_cache_2 = ActivationCache(store_dir_2, submodule_name)
         if len(self.activation_cache_1) != len(self.activation_cache_2):
             min_len = min(len(self.activation_cache_1), len(self.activation_cache_2))
-            assert self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None, "Caches have not the same length and tokens are not stored"
-            assert torch.all(self.activation_cache_1.tokens[:min_len] == self.activation_cache_2.tokens[:min_len]), "Tokens do not match"
+            assert (
+                self.activation_cache_1.tokens is not None
+                and self.activation_cache_2.tokens is not None
+            ), "Caches have not the same length and tokens are not stored"
+            assert torch.all(
+                self.activation_cache_1.tokens[:min_len]
+                == self.activation_cache_2.tokens[:min_len]
+            ), "Tokens do not match"
             self._len = min_len
-            print(f"Warning: Caches have not the same length and tokens are not stored. Using the first {min_len} tokens.")
+            print(
+                f"Warning: Caches have not the same length. Using the first {min_len} tokens."
+            )
             if len(self.activation_cache_1) > self._len:
                 self._sequence_ranges = self.activation_cache_2.sequence_ranges
             else:
                 self._sequence_ranges = self.activation_cache_1.sequence_ranges
         else:
-            assert len(self.activation_cache_1) == len(self.activation_cache_2), f"Lengths do not match: {len(self.activation_cache_1)} != {len(self.activation_cache_2)}"
+            assert len(self.activation_cache_1) == len(
+                self.activation_cache_2
+            ), f"Lengths do not match: {len(self.activation_cache_1)} != {len(self.activation_cache_2)}"
             self._len = len(self.activation_cache_1)
             self._sequence_ranges = self.activation_cache_1.sequence_ranges
-
-        if self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None:
-            assert torch.all(self.activation_cache_1.tokens[:self._len] == self.activation_cache_2.tokens[:self._len]), "Tokens do not match"
-
-
-
+
+        if (
+            self.activation_cache_1.tokens is not None
+            and self.activation_cache_2.tokens is not None
+        ):
+            assert torch.all(
+                self.activation_cache_1.tokens[: self._len]
+                == self.activation_cache_2.tokens[: self._len]
+            ), "Tokens do not match"
+
     def __len__(self):
         return self._len

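The reflowed assertions in `PairedActivationCache.__init__` all enforce one invariant: when the two caches disagree in length, the shorter length wins and the two token streams must agree on that shared prefix. A minimal sketch of the prefix check with toy token tensors (the values are invented):

import torch

tokens_1 = torch.tensor([5, 9, 2, 7, 7, 1])   # made-up token ids, longer cache
tokens_2 = torch.tensor([5, 9, 2, 7])         # shorter cache

min_len = min(len(tokens_1), len(tokens_2))
assert torch.all(tokens_1[:min_len] == tokens_2[:min_len]), "Tokens do not match"
print(f"Using the first {min_len} tokens")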
@@ -794,7 +827,13 @@ def __getitem__(self, index):

     @property
     def tokens(self):
-        return th.stack((self.activation_cache_1.tokens[:self._len], self.activation_cache_2.tokens[:self._len]), dim=0)
+        return th.stack(
+            (
+                self.activation_cache_1.tokens[: self._len],
+                self.activation_cache_2.tokens[: self._len],
+            ),
+            dim=0,
+        )

     @property
     def sequence_ranges(self):
@@ -813,7 +852,6 @@ def std(self):
         )


-
 class ActivationCacheTuple:
     def __init__(self, *store_dirs: str, submodule_name: str = None):
         self.activation_caches = [