Commit 658a0ee
Added support for PairedActivationCaches where one of the two caches is longer. Both caches are cut to the shorter length; they must contain the same tokens for the first min(len(cache_1), len(cache_2)) tokens.
1 parent 30cfeb0 commit 658a0ee
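
Below is a minimal, hypothetical usage sketch of the behavior described in the commit message; the store directories and token counts are assumptions, not taken from the repository.

from dictionary_learning.cache import PairedActivationCache

# Hypothetical store directories; both caches are assumed to have been written
# with store_tokens=True so the shared token prefix can be verified.
paired = PairedActivationCache(
    store_dir_1="caches/base_model/layer_12",  # e.g. 1_000_000 stored activations (made up)
    store_dir_2="caches/chat_model/layer_12",  # e.g. 950_000 stored activations (made up)
)

# Instead of failing the old equal-length assert, the pair is now cut to
# min(len(cache_1), len(cache_2)); the first min_len tokens must be identical.
print(len(paired))          # 950_000 in this hypothetical setup
print(paired.tokens.shape)  # (2, 950_000): both token tensors truncated to the shared prefix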

File tree

1 file changed (+24 -13 lines)


dictionary_learning/cache.py

Lines changed: 24 additions & 13 deletions
@@ -272,8 +272,8 @@ def __init__(self, store_dir: str, submodule_name: str = None):
         self._range_to_shard_idx = np.cumsum([0] + [s.shape[0] for s in self.shards])
         if "store_tokens" in self.config and self.config["store_tokens"]:
             self._tokens = th.load(
-                os.path.join(store_dir, "tokens.pt"), weights_only=True
-            ).cpu()
+                os.path.join(store_dir, "tokens.pt"), weights_only=True, map_location=th.device("cpu")
+            )
 
         self._sequence_ranges = None
         self._mean = None
@@ -753,10 +753,27 @@ class PairedActivationCache:
     def __init__(self, store_dir_1: str, store_dir_2: str, submodule_name: str = None):
         self.activation_cache_1 = ActivationCache(store_dir_1, submodule_name)
         self.activation_cache_2 = ActivationCache(store_dir_2, submodule_name)
-        assert len(self.activation_cache_1) == len(self.activation_cache_2)
-
+        if len(self.activation_cache_1) != len(self.activation_cache_2):
+            min_len = min(len(self.activation_cache_1), len(self.activation_cache_2))
+            assert self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None, "Caches do not have the same length and tokens are not stored"
+            assert th.all(self.activation_cache_1.tokens[:min_len] == self.activation_cache_2.tokens[:min_len]), "Tokens do not match"
+            self._len = min_len
+            print(f"Warning: Caches do not have the same length. Using the first {min_len} tokens.")
+            if len(self.activation_cache_1) > self._len:
+                self._sequence_ranges = self.activation_cache_2.sequence_ranges
+            else:
+                self._sequence_ranges = self.activation_cache_1.sequence_ranges
+        else:
+            self._len = len(self.activation_cache_1)
+            self._sequence_ranges = self.activation_cache_1.sequence_ranges
+
+        if self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None:
+            assert th.all(self.activation_cache_1.tokens[:self._len] == self.activation_cache_2.tokens[:self._len]), "Tokens do not match"
+
+
+
     def __len__(self):
-        return len(self.activation_cache_1)
+        return self._len
 
     def __getitem__(self, index):
         if isinstance(index, slice):
@@ -776,17 +793,11 @@ def __getitem__(self, index):
 
     @property
     def tokens(self):
-        return th.stack(
-            (self.activation_cache_1.tokens, self.activation_cache_2.tokens), dim=0
-        )
+        return th.stack((self.activation_cache_1.tokens[:self._len], self.activation_cache_2.tokens[:self._len]), dim=0)
 
     @property
     def sequence_ranges(self):
-        seq_starts_1 = self.activation_cache_1.sequence_ranges
-        seq_starts_2 = self.activation_cache_2.sequence_ranges
-        if seq_starts_1 is not None and seq_starts_2 is not None:
-            return th.stack((seq_starts_1, seq_starts_2), dim=0)
-        return None
+        return self._sequence_ranges
 
     @property
     def mean(self):
