@@ -27,6 +27,10 @@ def mock_model():
 
 # Simple implementation of PromptScorer for testing
 class SampleScorer(PromptScorer):
+    def __init__(self, mock_model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = mock_model
+
     def build_measure_prompt(self, example: Example) -> List[dict]:
         return [
             {"role": "system", "content": "Test system prompt"},
@@ -44,19 +48,19 @@ def success_check(self, **kwargs) -> bool:
 
 # Tests for PromptScorer
 class TestPromptScorer:
-    def test_init(self):
-        scorer = SampleScorer("test_scorer")
+    def test_init(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         assert scorer.name == "test_scorer"
         assert scorer.threshold == 0.5
         assert scorer.include_reason is True
         assert scorer.async_mode is True
 
-    def test_init_strict_mode(self):
-        scorer = SampleScorer("test_scorer", strict_mode=True)
+    def test_init_strict_mode(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model, strict_mode=True)
         assert scorer.threshold == 1
 
-    def test_enforce_prompt_format(self):
-        scorer = SampleScorer("test_scorer")
+    def test_enforce_prompt_format(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         prompt = [{"role": "system", "content": "Base prompt"}]
         schema = {"score": float, "reason": str}
 
@@ -65,23 +69,21 @@ def test_enforce_prompt_format(self):
         assert '"score": <score> (float)' in formatted[0]["content"]
         assert '"reason": <reason> (str)' in formatted[0]["content"]
 
-    def test_enforce_prompt_format_invalid_input(self):
-        scorer = SampleScorer("test_scorer")
+    def test_enforce_prompt_format_invalid_input(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         with pytest.raises(TypeError):
             scorer.enforce_prompt_format("invalid", {})
 
     @pytest.mark.asyncio
     async def test_a_score_example(self, example, mock_model):
-        scorer = SampleScorer("test_scorer")
-        scorer.model = mock_model
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
 
         result = await scorer.a_score_example(example, _show_indicator=False)
         assert result == 0.8
         assert scorer.reason == "Test reason"
 
     def test_score_example_sync(self, example, mock_model):
-        scorer = SampleScorer("test_scorer", async_mode=False)
-        scorer.model = mock_model
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model, async_mode=False)
 
         result = scorer.score_example(example, _show_indicator=False)
         assert result == 0.8
@@ -102,28 +104,28 @@ def classifier_options(self):
 
     def test_classifier_init(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
         assert scorer.conversation == classifier_conversation
         assert scorer.options == classifier_options
 
     def test_build_measure_prompt(self, example, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         prompt = scorer.build_measure_prompt(example)
         assert "This is a test response" in prompt[0]["content"]
 
     def test_process_response(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         response = {"choice": "positive", "reason": "Test reason"}
@@ -133,9 +135,9 @@ def test_process_response(self, classifier_conversation, classifier_options):
 
     def test_process_response_invalid_choice(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         response = {"choice": "invalid", "reason": "Test reason"}
@@ -144,9 +146,9 @@ def test_process_response_invalid_choice(self, classifier_conversation, classifier_options):
 
     def test_success_check(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         scorer.score = 1.0
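Note: the updated tests assume a mock_model fixture defined earlier in this file (the first hunk header anchors at def mock_model():). A minimal sketch of what such a fixture could look like, consistent with the asserted score of 0.8 and reason "Test reason"; the generate and a_generate method names and the JSON payload shown here are assumptions, not something this diff confirms:

import pytest
from unittest.mock import AsyncMock, MagicMock

@pytest.fixture
def mock_model():
    # Stub model whose sync and async generate calls both return the
    # payload the tests assert on: score 0.8, reason "Test reason".
    payload = '{"score": 0.8, "reason": "Test reason"}'
    model = MagicMock()
    model.generate = MagicMock(return_value=payload)
    model.a_generate = AsyncMock(return_value=payload)
    return model

Injecting the model through SampleScorer.__init__, rather than assigning scorer.model after construction as the old tests did, keeps each test's setup to a single line and lets the fixture flow through pytest's normal dependency injection.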