@@ -272,8 +272,8 @@ def __init__(self, store_dir: str, submodule_name: str = None):
         self._range_to_shard_idx = np.cumsum([0] + [s.shape[0] for s in self.shards])
         if "store_tokens" in self.config and self.config["store_tokens"]:
             self._tokens = th.load(
-                os.path.join(store_dir, "tokens.pt"), weights_only=True
-            ).cpu()
+                os.path.join(store_dir, "tokens.pt"), weights_only=True, map_location=th.device("cpu")
+            )
 
         self._sequence_ranges = None
         self._mean = None
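The `map_location` change matters on CPU-only machines: `.cpu()` only runs after `th.load` has already tried to restore the tensor onto the device it was saved from, while `map_location` remaps the storage during deserialization. A minimal sketch of the difference (the save step is illustrative):

import torch as th

# Suppose tokens.pt was saved from a CUDA run:
#     th.save(th.arange(10, device="cuda"), "tokens.pt")

# Old approach: restores onto the saving device first (raises on a
# CPU-only machine), then copies to CPU.
#     tokens = th.load("tokens.pt", weights_only=True).cpu()

# New approach: the storage is remapped to CPU while deserializing.
tokens = th.load("tokens.pt", weights_only=True, map_location=th.device("cpu"))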
@@ -753,10 +753,27 @@ class PairedActivationCache:
     def __init__(self, store_dir_1: str, store_dir_2: str, submodule_name: str = None):
         self.activation_cache_1 = ActivationCache(store_dir_1, submodule_name)
         self.activation_cache_2 = ActivationCache(store_dir_2, submodule_name)
-        assert len(self.activation_cache_1) == len(self.activation_cache_2)
-
+        if len(self.activation_cache_1) != len(self.activation_cache_2):
+            # Lengths differ: fall back to the longest shared prefix of the two caches.
+            min_len = min(len(self.activation_cache_1), len(self.activation_cache_2))
+            assert self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None, "Caches do not have the same length and tokens are not stored"
+            print(f"Warning: Caches do not have the same length. Using the first {min_len} tokens.")
+            self._len = min_len
+            # Take the sequence ranges from the shorter cache.
+            if len(self.activation_cache_1) > self._len:
+                self._sequence_ranges = self.activation_cache_2.sequence_ranges
+            else:
+                self._sequence_ranges = self.activation_cache_1.sequence_ranges
+        else:
+            self._len = len(self.activation_cache_1)
+            self._sequence_ranges = self.activation_cache_1.sequence_ranges
+
+        # Sanity check: the overlapping tokens of both caches must agree.
+        if self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None:
+            assert th.all(self.activation_cache_1.tokens[:self._len] == self.activation_cache_2.tokens[:self._len]), "Tokens do not match"
+
     def __len__(self):
-        return len(self.activation_cache_1)
+        return self._len
 
     def __getitem__(self, index):
         if isinstance(index, slice):
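Taken together, the constructor now tolerates caches of unequal length by keeping only the shared prefix, verified against the stored tokens. A usage sketch under assumed directory names (nothing below is from the PR itself):

# Hypothetical cache directories from two collection runs, where one run
# stopped early so the caches differ in length.
cache = PairedActivationCache("caches/run_base", "caches/run_finetuned")

# __len__ now reports the verified shared prefix, not the length of cache 1.
n = len(cache)
first_pair = cache[0]  # paired activations for the first shared token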
@@ -776,17 +793,11 @@ def __getitem__(self, index):
 
     @property
     def tokens(self):
-        return th.stack(
-            (self.activation_cache_1.tokens, self.activation_cache_2.tokens), dim=0
-        )
+        return th.stack((self.activation_cache_1.tokens[:self._len], self.activation_cache_2.tokens[:self._len]), dim=0)
 
     @property
     def sequence_ranges(self):
-        seq_starts_1 = self.activation_cache_1.sequence_ranges
-        seq_starts_2 = self.activation_cache_2.sequence_ranges
-        if seq_starts_1 is not None and seq_starts_2 is not None:
-            return th.stack((seq_starts_1, seq_starts_2), dim=0)
-        return None
+        return self._sequence_ranges
 
     @property
     def mean(self):
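With the truncation in `tokens`, the stacked result always has shape `(2, len(cache))` regardless of which cache was longer. A standalone sketch of that invariant with made-up token tensors:

import torch as th

tokens_1 = th.arange(12)  # longer cache
tokens_2 = th.arange(10)  # shorter cache
shared = min(tokens_1.shape[0], tokens_2.shape[0])

stacked = th.stack((tokens_1[:shared], tokens_2[:shared]), dim=0)
assert stacked.shape == (2, shared)  # row 0: cache 1, row 1: cache 2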