
Commit 5322400 (1 parent: bc372ab)

Edit UT to pass GH requirements >:)

File tree: 1 file changed (+29, -27 lines)

tests/scorers/test_prompt_scorer.py (29 additions, 27 deletions)
@@ -27,6 +27,10 @@ def mock_model():
 
 # Simple implementation of PromptScorer for testing
 class SampleScorer(PromptScorer):
+    def __init__(self, mock_model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = mock_model
+
     def build_measure_prompt(self, example: Example) -> List[dict]:
         return [
             {"role": "system", "content": "Test system prompt"},
@@ -44,19 +48,19 @@ def success_check(self, **kwargs) -> bool:
 
 # Tests for PromptScorer
 class TestPromptScorer:
-    def test_init(self):
-        scorer = SampleScorer("test_scorer")
+    def test_init(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         assert scorer.name == "test_scorer"
         assert scorer.threshold == 0.5
         assert scorer.include_reason is True
         assert scorer.async_mode is True
 
-    def test_init_strict_mode(self):
-        scorer = SampleScorer("test_scorer", strict_mode=True)
+    def test_init_strict_mode(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model, strict_mode=True)
         assert scorer.threshold == 1
 
-    def test_enforce_prompt_format(self):
-        scorer = SampleScorer("test_scorer")
+    def test_enforce_prompt_format(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         prompt = [{"role": "system", "content": "Base prompt"}]
         schema = {"score": float, "reason": str}
 
@@ -65,23 +69,21 @@ def test_enforce_prompt_format(self):
         assert '"score": <score> (float)' in formatted[0]["content"]
         assert '"reason": <reason> (str)' in formatted[0]["content"]
 
-    def test_enforce_prompt_format_invalid_input(self):
-        scorer = SampleScorer("test_scorer")
+    def test_enforce_prompt_format_invalid_input(self, mock_model):
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
         with pytest.raises(TypeError):
             scorer.enforce_prompt_format("invalid", {})
 
     @pytest.mark.asyncio
     async def test_a_score_example(self, example, mock_model):
-        scorer = SampleScorer("test_scorer")
-        scorer.model = mock_model
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model)
 
         result = await scorer.a_score_example(example, _show_indicator=False)
         assert result == 0.8
         assert scorer.reason == "Test reason"
 
     def test_score_example_sync(self, example, mock_model):
-        scorer = SampleScorer("test_scorer", async_mode=False)
-        scorer.model = mock_model
+        scorer = SampleScorer(name="test_scorer", mock_model=mock_model, async_mode=False)
 
         result = scorer.score_example(example, _show_indicator=False)
         assert result == 0.8
@@ -102,28 +104,28 @@ def classifier_options(self):
 
     def test_classifier_init(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
         assert scorer.conversation == classifier_conversation
         assert scorer.options == classifier_options
 
     def test_build_measure_prompt(self, example, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         prompt = scorer.build_measure_prompt(example)
         assert "This is a test response" in prompt[0]["content"]
 
     def test_process_response(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         response = {"choice": "positive", "reason": "Test reason"}
@@ -133,9 +135,9 @@ def test_process_response(self, classifier_conversation, classifier_options):
 
     def test_process_response_invalid_choice(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         response = {"choice": "invalid", "reason": "Test reason"}
@@ -144,9 +146,9 @@ def test_process_response_invalid_choice(self, classifier_conversation, classifier_options):
 
     def test_success_check(self, classifier_conversation, classifier_options):
         scorer = ClassifierScorer(
-            "test_classifier",
-            classifier_conversation,
-            classifier_options
+            name="test_classifier",
+            conversation=classifier_conversation,
+            options=classifier_options
         )
 
         scorer.score = 1.0
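
For reference, the tests above use constructor injection: the mocked model is passed to SampleScorer and ClassifierScorer as a keyword argument instead of being assigned to scorer.model after construction. Below is a minimal, self-contained sketch of that pattern; the Scorer and InjectedScorer classes and the MagicMock-based mock_model fixture are illustrative stand-ins and assumptions, not the repository's actual PromptScorer API.

# Illustrative sketch only: stand-in classes, not the real PromptScorer API.
from unittest.mock import MagicMock

import pytest


class Scorer:
    # Stand-in base class: stores a name; the model starts unset.
    def __init__(self, name: str):
        self.name = name
        self.model = None


class InjectedScorer(Scorer):
    # Stand-in for the pattern used above: the (mock) model is injected
    # through __init__ rather than patched onto the instance afterwards.
    def __init__(self, mock_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = mock_model


@pytest.fixture
def mock_model():
    # A generic mock; the real fixture presumably stubs whatever model
    # interface the scorer calls (method names are not shown in this diff).
    return MagicMock()


def test_constructor_injection(mock_model):
    scorer = InjectedScorer(name="test_scorer", mock_model=mock_model)
    assert scorer.name == "test_scorer"
    assert scorer.model is mock_model

Running pytest on such a file exercises the injection path directly, with no monkeypatching of scorer.model needed after construction.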
