Skip to content

Commit 08649f0

Browse files
committed
enforce all tests to run
Signed-off-by: Max Jeblick <maximilianjeblick@gmail.com>
1 parent 3bf1d63 commit 08649f0

File tree

9 files changed

+28
-17
lines changed

9 files changed

+28
-17
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Description of your PR. Fixes # (issue) (if applicable)
44

55
## Checklist
66

7+
Before submitting a PR, please make sure:
8+
79
- Tests are working (`make test`)
810
- Code is formatted correctly (`make style`, on errors try fix with `make format`)
911
- Copyright header is included

Makefile

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,16 @@ reports:
4242
.PHONY: test
4343
test: reports
4444
$(UV) add optimum-quanto
45-
$(UV) add flash-attn --no-build-isolation
45+
$(UV) add flash-attn
4646
PYTHONPATH=. \
4747
$(UV) run pytest \
4848
--cov-report xml:reports/coverage.xml \
4949
--cov=kvpress/ \
5050
--junitxml=./reports/junit.xml \
51-
tests/
51+
-v \
52+
tests/ | tee reports/pytest_output.log
53+
@if grep -q "SKIPPED" reports/pytest_output.log; then \
54+
echo "Error: Tests were skipped. All tests must run."; \
55+
grep "SKIPPED" reports/pytest_output.log; \
56+
exit 1; \
57+
fi

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ dependencies = [
2222
"accelerate>=1.0.0,<2",
2323
"requests>=2.32.3,<3",
2424
"cachetools>=5.5.2,<6",
25+
"optimum-quanto>=0.2.7",
26+
"hatch>=1.14.1",
27+
"flash-attn>=2.8.2",
2528
]
2629

2730
[project.optional-dependencies]
@@ -89,4 +92,4 @@ disable_error_code = ["attr-defined"]
8992

9093
[[tool.mypy.overrides]]
9194
module = "kvpress.pipeline"
92-
disable_error_code = ["attr-defined", "assignment", "override"]
95+
disable_error_code = ["attr-defined", "assignment", "override"]

tests/integration/test_ruler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press
2727
kwargs = press_dict["kwargs"][0]
2828
press = cls(**kwargs)
2929
if not hasattr(cls, "compression_ratio"):
30-
pytest.skip(reason="Press does not support compression_ratio")
30+
return # "Press does not support compression_ratio"
3131
# set compression ratio to a small value for testing
3232
try:
3333
press.compression_ratio = 0.1
3434
except AttributeError:
35-
pytest.skip(reason="Press does not support setting compression_ratio")
35+
return # "Press does not support setting compression_ratio"
3636

3737
if cache == "dynamic":
3838
cache = DynamicCache()

tests/presses/test_block_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_block_press_is_streaming_top_k(unit_test_model): # noqa: F811
3333
"""
3434
press = HiddenStatesPress(compression_ratio=0.5)
3535
generator = torch.Generator().manual_seed(0)
36-
input_ids = torch.randint(0, 1024, (1, 256), generator=generator)
36+
input_ids = torch.randint(0, 1024, (1, 256), generator=generator).to(unit_test_model.device)
3737
keys_hash = []
3838
values_hash = []
3939

tests/presses/test_finch_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ def test_finch_press(unit_test_model): # noqa: F811
1616
]:
1717
press.delimiter_token_id = unit_test_model.config.eos_token_id
1818
with press(unit_test_model):
19-
input_ids = torch.arange(10, 20)
19+
input_ids = torch.arange(10, 20).to(unit_test_model.device)
2020
input_ids[8] = press.delimiter_token_id
2121
unit_test_model(input_ids.unsqueeze(0))

tests/presses/test_head_compression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra
2828
p = KnormPress(compression_ratio=compression_ratio)
2929
press = wrapper_press(press=p)
3030
with press(unit_test_model):
31-
input_ids = torch.randint(0, 1024, (1, 128))
31+
input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
3232
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
3333

3434
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
@@ -47,7 +47,7 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra
4747
def test_head_compression(unit_test_model, press, compression_ratio, layerwise): # noqa: F811
4848
press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise)
4949
with press(unit_test_model):
50-
input_ids = torch.randint(0, 1024, (1, 128))
50+
input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
5151
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
5252

5353
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None

tests/presses/test_observed_attention_press.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313

1414
@torch.no_grad()
1515
def test_observed_drops_attention_output(unit_test_model, unit_test_model_output_attention, caplog): # noqa: F811
16-
input_ids = unit_test_model.dummy_inputs["input_ids"]
16+
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
1717
output = unit_test_model(input_ids, past_key_values=DynamicCache())
1818
assert output.attentions is None
1919

20-
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]
20+
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"].to(unit_test_model.device)
2121
attentions = unit_test_model_output_attention(input_ids, past_key_values=DynamicCache()).attentions
2222
assert all([isinstance(attention, torch.Tensor) for attention in attentions])
2323

2424
with caplog.at_level(logging.DEBUG):
2525
press = ObservedAttentionPress(compression_ratio=0.4)
2626
with press(unit_test_model_output_attention):
27-
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]
27+
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"].to(unit_test_model.device)
2828
output = unit_test_model_output_attention(input_ids, past_key_values=DynamicCache())
2929

3030
# There's a slight mismatch in outputs when using a model that has output_attentions=True
@@ -36,7 +36,7 @@ def test_observed_drops_attention_output(unit_test_model, unit_test_model_output
3636

3737
press = ObservedAttentionPress(compression_ratio=0.4, output_attentions=True)
3838
with press(unit_test_model_output_attention):
39-
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"]
39+
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"].to(unit_test_model.device)
4040
output = unit_test_model_output_attention(input_ids, past_key_values=DynamicCache())
4141

4242
assert all(

tests/presses/test_wrappers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_composed_press_qfilter_without_post_init(unit_test_model): # noqa: F81
1414
composed_press = ComposedPress([press1, press2])
1515
with pytest.raises(ValueError, match="post_init_from_model"):
1616
with composed_press(unit_test_model):
17-
input_ids = unit_test_model.dummy_inputs["input_ids"]
17+
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
1818
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
1919

2020

@@ -28,7 +28,7 @@ def test_composed_press_duo_attention_without_post_init(unit_test_model): # noq
2828
composed_press = ComposedPress([press1, press2])
2929
with pytest.raises(ValueError, match="post_init_from_model"):
3030
with composed_press(unit_test_model):
31-
input_ids = unit_test_model.dummy_inputs["input_ids"]
31+
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
3232
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
3333

3434

@@ -45,7 +45,7 @@ def test_composed_qfilter_press_with_post_init(unit_test_model): # noqa: F811
4545

4646
composed_press = ComposedPress([press1, press2])
4747
with composed_press(unit_test_model):
48-
input_ids = unit_test_model.dummy_inputs["input_ids"]
48+
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
4949
with pytest.raises(RuntimeError, match="The size of tensor"):
5050
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
5151

@@ -63,5 +63,5 @@ def test_composed_duo_attention_press_with_post_init(unit_test_model): # noqa:
6363

6464
composed_press = ComposedPress([press1, press2])
6565
with composed_press(unit_test_model):
66-
input_ids = unit_test_model.dummy_inputs["input_ids"]
66+
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
6767
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

0 commit comments

Comments (0)