
Commit 6314cb5

Updated misleading warning message + format
1 parent 2ccfb8a commit 6314cb5

File tree: 3 files changed (+94, -48 lines)


dictionary_learning/cache.py

Lines changed: 81 additions & 43 deletions
@@ -272,7 +272,9 @@ def __init__(self, store_dir: str, submodule_name: str = None):
         self._range_to_shard_idx = np.cumsum([0] + [s.shape[0] for s in self.shards])
         if "store_tokens" in self.config and self.config["store_tokens"]:
             self._tokens = th.load(
-                os.path.join(store_dir, "tokens.pt"), weights_only=True, map_location=th.device("cpu")
+                os.path.join(store_dir, "tokens.pt"),
+                weights_only=True,
+                map_location=th.device("cpu"),
             )

         self._sequence_ranges = None
@@ -309,7 +311,6 @@ def std(self):
             )
         return self._std

-
     @property
     def running_stats(self):
         return RunningStatWelford.load_or_create_state(
@@ -318,12 +319,17 @@ def running_stats(self):

     def __len__(self):
         return self.config["total_size"]
+
     def __getitem__(self, index):
         if isinstance(index, slice):
             # Handle slice objects
             start, stop, step = index.indices(len(self))
-            start_shard_idx = np.searchsorted(self._range_to_shard_idx, start, side="right") - 1
-            stop_shard_idx = np.searchsorted(self._range_to_shard_idx, stop, side="right") - 1
+            start_shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, start, side="right") - 1
+            )
+            stop_shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, stop, side="right") - 1
+            )
             if start_shard_idx == stop_shard_idx:
                 offset = start - self._range_to_shard_idx[start_shard_idx]
                 end_offset = stop - self._range_to_shard_idx[stop_shard_idx]
@@ -335,7 +341,9 @@ def __getitem__(self, index):
             return th.stack([self[i] for i in range(start, stop, step)], dim=0)
         elif isinstance(index, int):
             # Handle single integer index
-            shard_idx = np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
+            shard_idx = (
+                np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
+            )
             offset = index - self._range_to_shard_idx[shard_idx]
             shard = self.shards[shard_idx]
             return shard[offset]
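
Aside: the searchsorted arithmetic reformatted in the two hunks above maps a flat index to a (shard, offset) pair via cumulative shard sizes. A minimal standalone sketch of that lookup; the shard sizes and the locate helper are illustrative, not part of the repository:

import numpy as np

# Cumulative shard boundaries let a flat index be mapped to (shard, offset).
shard_sizes = [4, 3, 5]                             # hypothetical shard lengths
range_to_shard_idx = np.cumsum([0] + shard_sizes)   # [0, 4, 7, 12]

def locate(flat_index: int) -> tuple[int, int]:
    # side="right" pushes an index that equals a boundary into the next shard,
    # and the trailing "- 1" turns the boundary position into a shard id.
    shard_idx = np.searchsorted(range_to_shard_idx, flat_index, side="right") - 1
    offset = flat_index - range_to_shard_idx[shard_idx]
    return shard_idx, offset

assert locate(0) == (0, 0)
assert locate(4) == (1, 0)   # first element of the second shard
assert locate(11) == (2, 4)  # last element of the last shard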
@@ -346,23 +354,29 @@ def __getitem__(self, index):
             else:
                 raise TypeError(f"Tensor index must be scalar, got shape {index.shape}")
         else:
-            raise TypeError(f"Index must be int, slice, or scalar tensor, got {type(index)}")
+            raise TypeError(
+                f"Index must be int, slice, or scalar tensor, got {type(index)}"
+            )

     @property
     def tokens(self):
         return self._tokens

     @property
     def sequence_ranges(self):
-        if hasattr(self, '_sequence_ranges') and self._sequence_ranges is not None:
+        if hasattr(self, "_sequence_ranges") and self._sequence_ranges is not None:
             return self._sequence_ranges
-
-        if ("store_sequence_ranges" in self.config and
-            self.config["store_sequence_ranges"] and
-            os.path.exists(os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"))):
+
+        if (
+            "store_sequence_ranges" in self.config
+            and self.config["store_sequence_ranges"]
+            and os.path.exists(
+                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt")
+            )
+        ):
             self._sequence_ranges = th.load(
-                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"),
-                weights_only=True
+                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"),
+                weights_only=True,
             ).cpu()
             return self._sequence_ranges
         else:
@@ -483,23 +497,27 @@ def exists(
         num_tokens = 0
         config = None
         for submodule_name in submodule_names:
-            config_path = os.path.join(store_dir, f"{submodule_name}_{io}", "config.json")
+            config_path = os.path.join(
+                store_dir, f"{submodule_name}_{io}", "config.json"
+            )
            if not os.path.exists(config_path):
                 return False, 0
             with open(config_path, "r") as f:
                 config = json.load(f)
                 num_tokens = config["total_size"]
-
+
         if store_tokens and not os.path.exists(os.path.join(store_dir, "tokens.pt")):
             return False, 0
-
+
         # Check for sequence ranges if they should exist
-        if (config and
-            "store_sequence_ranges" in config and
-            config["store_sequence_ranges"] and
-            not os.path.exists(os.path.join(store_dir, "sequence_ranges.pt"))):
+        if (
+            config
+            and "store_sequence_ranges" in config
+            and config["store_sequence_ranges"]
+            and not os.path.exists(os.path.join(store_dir, "sequence_ranges.pt"))
+        ):
             return False, 0
-
+
         return True, num_tokens

     @th.no_grad()
@@ -530,19 +548,16 @@ def collect(
         assert (
             not shuffle_shards or not store_tokens
         ), "Shuffling shards and storing tokens is not supported yet"
-
-        store_sequence_ranges = (
-            store_tokens and
-            not shuffle_shards
-        )
-
+
+        store_sequence_ranges = store_tokens and not shuffle_shards
+
         dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)

         activation_cache = [[] for _ in submodules]
         tokens_cache = []
         sequence_ranges_cache = []
         current_token_position = 0  # Track position in flattened token stream
-
+
         store_sub_dirs = [
             os.path.join(store_dir, f"{submodule_names[i]}_{io}")
             for i in range(len(submodules))
@@ -594,11 +609,13 @@ def collect(
                 store_mask = attention_mask.clone()
                 if ignore_first_n_tokens_per_sample > 0:
                     store_mask[:, :ignore_first_n_tokens_per_sample] = 0
-
+
                 # Track sequence ranges if needed
                 if store_sequence_ranges:
                     batch_lengths = store_mask.sum(dim=1).tolist()
-                    batch_sequence_ranges = np.cumsum([0] + batch_lengths[:-1]) + current_token_position
+                    batch_sequence_ranges = (
+                        np.cumsum([0] + batch_lengths[:-1]) + current_token_position
+                    )
                     sequence_ranges_cache.extend(batch_sequence_ranges.tolist())
                     current_token_position += sum(batch_lengths)

@@ -733,7 +750,9 @@ def collect(
             sequence_ranges_cache.append(current_token_position)
             assert sequence_ranges_cache[-1] == total_size
             sequence_ranges_tensor = th.tensor(sequence_ranges_cache, dtype=th.long)
-            th.save(sequence_ranges_tensor, os.path.join(store_dir, "sequence_ranges.pt"))
+            th.save(
+                sequence_ranges_tensor, os.path.join(store_dir, "sequence_ranges.pt")
+            )
             print(f"Stored {len(sequence_ranges_cache)} sequence ranges")

         # store running stats
@@ -755,24 +774,38 @@ def __init__(self, store_dir_1: str, store_dir_2: str, submodule_name: str = Non
         self.activation_cache_2 = ActivationCache(store_dir_2, submodule_name)
         if len(self.activation_cache_1) != len(self.activation_cache_2):
             min_len = min(len(self.activation_cache_1), len(self.activation_cache_2))
-            assert self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None, "Caches have not the same length and tokens are not stored"
-            assert torch.all(self.activation_cache_1.tokens[:min_len] == self.activation_cache_2.tokens[:min_len]), "Tokens do not match"
+            assert (
+                self.activation_cache_1.tokens is not None
+                and self.activation_cache_2.tokens is not None
+            ), "Caches have not the same length and tokens are not stored"
+            assert torch.all(
+                self.activation_cache_1.tokens[:min_len]
+                == self.activation_cache_2.tokens[:min_len]
+            ), "Tokens do not match"
             self._len = min_len
-            print(f"Warning: Caches have not the same length and tokens are not stored. Using the first {min_len} tokens.")
+            print(
+                f"Warning: Caches have not the same length. Using the first {min_len} tokens."
+            )
             if len(self.activation_cache_1) > self._len:
                 self._sequence_ranges = self.activation_cache_2.sequence_ranges
             else:
                 self._sequence_ranges = self.activation_cache_1.sequence_ranges
         else:
-            assert len(self.activation_cache_1) == len(self.activation_cache_2), f"Lengths do not match: {len(self.activation_cache_1)} != {len(self.activation_cache_2)}"
+            assert len(self.activation_cache_1) == len(
+                self.activation_cache_2
+            ), f"Lengths do not match: {len(self.activation_cache_1)} != {len(self.activation_cache_2)}"
             self._len = len(self.activation_cache_1)
             self._sequence_ranges = self.activation_cache_1.sequence_ranges
-
-        if self.activation_cache_1.tokens is not None and self.activation_cache_2.tokens is not None:
-            assert torch.all(self.activation_cache_1.tokens[:self._len] == self.activation_cache_2.tokens[:self._len]), "Tokens do not match"
-
-
-
+
+        if (
+            self.activation_cache_1.tokens is not None
+            and self.activation_cache_2.tokens is not None
+        ):
+            assert torch.all(
+                self.activation_cache_1.tokens[: self._len]
+                == self.activation_cache_2.tokens[: self._len]
+            ), "Tokens do not match"
+
     def __len__(self):
         return self._len

@@ -794,7 +827,13 @@ def __getitem__(self, index):

     @property
     def tokens(self):
-        return th.stack((self.activation_cache_1.tokens[:self._len], self.activation_cache_2.tokens[:self._len]), dim=0)
+        return th.stack(
+            (
+                self.activation_cache_1.tokens[: self._len],
+                self.activation_cache_2.tokens[: self._len],
+            ),
+            dim=0,
+        )

     @property
     def sequence_ranges(self):
@@ -813,7 +852,6 @@ def std(self):
         )


-
 class ActivationCacheTuple:
     def __init__(self, *store_dirs: str, submodule_name: str = None):
         self.activation_caches = [

dictionary_learning/dictionary.py

Lines changed: 9 additions & 3 deletions
@@ -74,7 +74,9 @@ def __init__(
             activation_global_scale = self.target_rms / th.sqrt(total_var + 1e-8)
             self.register_buffer("activation_global_scale", activation_global_scale)
         else:
-            self.register_buffer("activation_global_scale", th.ones(activation_shape[:-1]))
+            self.register_buffer(
+                "activation_global_scale", th.ones(activation_shape[:-1])
+            )

     @property
     def has_activation_normalizer(self) -> bool:
@@ -98,7 +100,9 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tenso
         if self.has_activation_normalizer:
             if not inplace:
                 x = x.clone()
-            assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
+            assert (
+                x.shape[1:-1] == self.activation_global_scale.shape
+            ), "Normalization shape mismatch"
             x = x - self.activation_mean

             if self.keep_relative_variance:
@@ -121,7 +125,9 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Ten
         if self.has_activation_normalizer:
             if not inplace:
                 x = x.clone()
-            assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
+            assert (
+                x.shape[1:-1] == self.activation_global_scale.shape
+            ), "Normalization shape mismatch"

             if self.keep_relative_variance:
                 x = x / (self.activation_global_scale.unsqueeze(0).unsqueeze(-1) + 1e-8)
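
Aside: a hedged, self-contained sketch of the RMS-style normalization these hunks touch, using only what is visible in the diff context (target_rms, total_var, activation_global_scale, the 1e-8 epsilon). The mean/variance computation below is an assumption for illustration, not the repository's implementation:

import torch as th

target_rms = 1.0
x = th.randn(64, 8, 512) * 3.0 + 1.5            # (batch, positions, d_model)

mean = x.mean(dim=0)                            # per-position, per-feature mean
total_var = (x - mean).pow(2).mean(dim=(0, 2))  # one variance per position
scale = target_rms / th.sqrt(total_var + 1e-8)  # global scale per position

# Normalize: subtract the mean, then apply the shared per-position scale.
x_norm = (x - mean) * scale.unsqueeze(0).unsqueeze(-1)
# Denormalize: divide by the (epsilon-guarded) scale and add the mean back.
x_back = x_norm / (scale.unsqueeze(0).unsqueeze(-1) + 1e-8) + mean

print(th.allclose(x_back, x, atol=1e-4))        # round trip recovers the input

The scale has shape x.shape[1:-1], which is exactly what the reformatted asserts above check.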

dictionary_learning/training.py

Lines changed: 4 additions & 2 deletions
@@ -307,7 +307,9 @@ def trainSAE(
                 epoch_idx_per_step=epoch_idx_per_step,
                 num_tokens=num_tokens,
             )
-            if isinstance(trainer, BatchTopKCrossCoderTrainer) or isinstance(trainer, BatchTopKTrainer):
+            if isinstance(trainer, BatchTopKCrossCoderTrainer) or isinstance(
+                trainer, BatchTopKTrainer
+            ):
                 log_stats(
                     trainer,
                     step,
@@ -374,4 +376,4 @@ def trainSAE(
     if return_last_eval_logs:
         return get_model(trainer), last_eval_logs
     else:
-        return get_model(trainer)
+        return get_model(trainer)
