 from dictionary_learning.cache import ActivationCache
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+
 @pytest.fixture
 def temp_dir():
     """Create a temporary directory for test files."""
@@ -149,38 +150,40 @@ def test_activation_cache_with_normalizer(temp_dir):
     """Test ActivationCache collection and normalizer against direct model activations."""
     # Set flag to handle meta tensors properly
     th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
-
+
     # Skip test if CUDA not available to avoid device mapping issues
     if not th.cuda.is_available():
         pytest.skip("CUDA not available, skipping test to avoid device mapping issues")
-
+
     # Test strings
     test_strings = [
         "The quick brown fox jumps over the lazy dog.",
         "Machine learning is a subset of artificial intelligence.",
         "Python is a popular programming language for data science.",
         "Neural networks are inspired by biological brain structures.",
-        "Deep learning has revolutionized computer vision and natural language processing."
+        "Deep learning has revolutionized computer vision and natural language processing.",
     ]
-
+
     # Use the list directly - it already implements __len__ and __getitem__
     dataset = test_strings
-
+
     # Load GPT-2 model - use auto device mapping but force concrete tensors
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", torch_dtype=th.float32)
+    model = AutoModelForCausalLM.from_pretrained(
+        "gpt2", device_map="auto", torch_dtype=th.float32
+    )
     model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
     model.tokenizer.pad_token = model.tokenizer.eos_token
 
     # Get a transformer block to extract activations from
     target_layer = model.transformer.h[6]  # Middle layer of GPT-2
     submodule_name = "transformer_h_6"
-
+
     # Parameters for activation collection
     batch_size = 2
     context_len = 64
     d_model = 768  # GPT-2 hidden size
-
+
     # Collect activations using ActivationCache
     ActivationCache.collect(
         data=dataset,
@@ -197,17 +200,16 @@ def test_activation_cache_with_normalizer(temp_dir):
         store_tokens=True,
     )
 
-
     # Load the cached activations
     cache = ActivationCache(temp_dir, submodule_name + "_out")
-
+
     # Collect activations directly from model for comparison
     direct_activations = []
     direct_tokens = []
-
+
     for i in range(0, len(test_strings), batch_size):
-        batch_texts = test_strings[i : i + batch_size]
-
+        batch_texts = test_strings[i : i + batch_size]
+
         # Tokenize
         tokens = model.tokenizer(
             batch_texts,
@@ -217,60 +219,84 @@ def test_activation_cache_with_normalizer(temp_dir):
             padding=True,
             add_special_tokens=True,
         )
-
+
         # Get activations directly
         with model.trace(tokens):
             layer_output = target_layer.output[0].save()
-
+
         # Extract valid tokens (non-padding)
         attention_mask = tokens["attention_mask"]
-        valid_activations = layer_output.reshape(-1, d_model)[attention_mask.reshape(-1).bool()]
-        valid_tokens = tokens["input_ids"].reshape(-1)[attention_mask.reshape(-1).bool()]
-
+        valid_activations = layer_output.reshape(-1, d_model)[
+            attention_mask.reshape(-1).bool()
+        ]
+        valid_tokens = tokens["input_ids"].reshape(-1)[
+            attention_mask.reshape(-1).bool()
+        ]
+
         direct_activations.append(valid_activations.cpu())
         direct_tokens.append(valid_tokens.cpu())
-
+
     # Concatenate direct activations
     direct_activations = th.cat(direct_activations, dim=0)
     direct_tokens = th.cat(direct_tokens, dim=0)
-
+
     # Test that we have the same number of activations
-    assert len(cache) == direct_activations.shape[0], f"Cache length {len(cache)} != direct activations length {direct_activations.shape[0]}"
-
+    assert (
+        len(cache) == direct_activations.shape[0]
+    ), f"Cache length {len(cache)} != direct activations length {direct_activations.shape[0]}"
+
     # Test that tokens match
-    assert th.equal(cache.tokens, direct_tokens), "Cached tokens don't match direct tokens"
-
+    assert th.equal(
+        cache.tokens, direct_tokens
+    ), "Cached tokens don't match direct tokens"
+
     # Test that activations match (within tolerance for numerical precision)
     cached_activations = th.stack([cache[i] for i in range(len(cache))], dim=0)
-    assert th.allclose(cached_activations, direct_activations, atol=1e-5, rtol=1e-5), "Cached activations don't match direct activations"
-
+    assert th.allclose(
+        cached_activations, direct_activations, atol=1e-5, rtol=1e-5
+    ), "Cached activations don't match direct activations"
+
     # Test mean and std computation
     computed_mean = direct_activations.mean(dim=0)
     computed_std = direct_activations.std(dim=0, unbiased=True)
-
-    assert th.allclose(cache.mean, computed_mean, atol=1e-5, rtol=1e-5), "Cached mean doesn't match computed mean"
-    assert th.allclose(cache.std, computed_std, atol=1e-5, rtol=1e-5), "Cached std doesn't match computed std"
-
+
+    assert th.allclose(
+        cache.mean, computed_mean, atol=1e-5, rtol=1e-5
+    ), "Cached mean doesn't match computed mean"
+    assert th.allclose(
+        cache.std, computed_std, atol=1e-5, rtol=1e-5
+    ), "Cached std doesn't match computed std"
+
     # Test normalizer functionality
     normalizer = cache.normalizer
-
+
     # Test normalization of a sample activation
     sample_activation = cached_activations[0]
     normalized = normalizer(sample_activation)
-
+
     # Verify normalization: (x - mean) / std (with small epsilon for numerical stability)
     expected_normalized = (sample_activation - cache.mean) / (cache.std + 1e-8)
-    assert th.allclose(normalized, expected_normalized, atol=1e-6), "Normalizer doesn't work correctly"
-
+    assert th.allclose(
+        normalized, expected_normalized, atol=1e-6
+    ), "Normalizer doesn't work correctly"
+
     # Test batch normalization
     batch_normalized = normalizer(cached_activations[:5])
-    expected_batch_normalized = (cached_activations[:5] - cache.mean) / (cache.std + 1e-8)
-    assert th.allclose(batch_normalized, expected_batch_normalized, atol=1e-6), "Batch normalization doesn't work correctly"
-
+    expected_batch_normalized = (cached_activations[:5] - cache.mean) / (
+        cache.std + 1e-8
+    )
+    assert th.allclose(
+        batch_normalized, expected_batch_normalized, atol=1e-6
+    ), "Batch normalization doesn't work correctly"
+
     # Test that normalization preserves shape
-    assert normalized.shape == sample_activation.shape, "Normalization changed tensor shape"
-    assert batch_normalized.shape == cached_activations[:5].shape, "Batch normalization changed tensor shape"
-
+    assert (
+        normalized.shape == sample_activation.shape
+    ), "Normalization changed tensor shape"
+    assert (
+        batch_normalized.shape == cached_activations[:5].shape
+    ), "Batch normalization changed tensor shape"
+
     print(f"✓ Successfully tested ActivationCache with {len(cache)} activations")
     print(f"✓ Mean shape: {cache.mean.shape}, Std shape: {cache.std.shape}")
     print(f"✓ Normalizer tests passed")