8
8
from transformers .utils import is_flash_attn_2_available , is_optimum_quanto_available
9
9
10
10
from tests .default_presses import default_presses
11
- from tests .fixtures import kv_press_llama3_1_flash_attn_pipeline # noqa: F401
11
+ from tests .fixtures import kv_press_qwen3_flash_attn_pipeline # noqa: F401
12
12
13
13
14
14
@pytest .fixture (scope = "session" )
@@ -22,7 +22,7 @@ def df_ruler():
22
22
@pytest .mark .skipif (not is_flash_attn_2_available (), reason = "flash_attn is not installed" )
23
23
@pytest .mark .parametrize ("press_dict" , default_presses )
24
24
@pytest .mark .parametrize ("cache" , ["dynamic" , "quantized" ])
25
- def test_ruler_is_correct (kv_press_llama3_1_flash_attn_pipeline , df_ruler , press_dict , cache ): # noqa: F811
25
+ def test_ruler_is_correct (kv_press_qwen3_flash_attn_pipeline , df_ruler , press_dict , cache ): # noqa: F811
26
26
cls = press_dict ["cls" ]
27
27
kwargs = press_dict ["kwargs" ][0 ]
28
28
press = cls (** kwargs )
@@ -49,5 +49,5 @@ def test_ruler_is_correct(kv_press_llama3_1_flash_attn_pipeline, df_ruler, press
49
49
question = df_ruler .iloc [idx ]["question" ]
50
50
true_answer = df_ruler .iloc [idx ]["answer" ][0 ]
51
51
52
- pred_answer = kv_press_llama3_1_flash_attn_pipeline (context , question = question , press = press , cache = cache )["answer" ]
52
+ pred_answer = kv_press_qwen3_flash_attn_pipeline (context , question = question , press = press , cache = cache )["answer" ]
53
53
assert true_answer in pred_answer
0 commit comments