Skip to content

Commit 423baa8

Browse files
committed
switch to Qwen/Qwen3-4B-Instruct-2507
Signed-off-by: Max Jeblick <maximilianjeblick@gmail.com>
1 parent 92d6e0e commit 423baa8

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tests/fixtures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def kv_press_llama3_2_flash_attn_pipeline():
8787

8888

8989
@pytest.fixture(scope="session")
90-
def kv_press_llama3_1_flash_attn_pipeline():
90+
def kv_press_qwen3_flash_attn_pipeline():
9191
device = "cuda:0"
92-
ckpt = "meta-llama/Llama-3.1-8B-Instruct"
92+
ckpt = "Qwen/Qwen3-4B-Instruct-2507"
9393
attn_implementation = "flash_attention_2"
9494
pipe = pipeline(
9595
"kv-press-text-generation",

tests/integration/test_ruler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
99

1010
from tests.default_presses import default_presses
11-
from tests.fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401
11+
from tests.fixtures import kv_press_qwen3_flash_attn_pipeline # noqa: F401
1212

1313

1414
@pytest.fixture(scope="session")
@@ -22,7 +22,7 @@ def df_ruler():
2222
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
2323
@pytest.mark.parametrize("press_dict", default_presses)
2424
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
25-
def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811
25+
def test_ruler_is_correct(kv_press_qwen3_flash_attn_pipeline, df_ruler, press_dict, cache): # noqa: F811
2626
cls = press_dict["cls"]
2727
kwargs = press_dict["kwargs"][0]
2828
press = cls(**kwargs)
@@ -49,5 +49,5 @@ def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press
4949
question = df_ruler.iloc[idx]["question"]
5050
true_answer = df_ruler.iloc[idx]["answer"][0]
5151

52-
pred_answer = kv_press_llama3_1_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"]
52+
pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)["answer"]
5353
assert true_answer in pred_answer

0 commit comments

Comments
 (0)