Commit 6e34acb

Merge pull request #26 from JudgmentLabs/alex/add-unit-tests
Add UT for Data and Scorers library
2 parents 685f864 + 50e2c9a commit 6e34acb

20 files changed: +2697 -75 lines

Pipfile

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ uvicorn = "*"
 deepeval = "*"
 supabase = "*"
 requests = "*"
+pandas = "*"
+anthropic = "*"
 
 [dev-packages]
 pytest = "*"

judgeval/constants.py

Lines changed: 8 additions & 2 deletions
@@ -5,7 +5,7 @@
 from enum import Enum
 import litellm
 
-class APIScorer(Enum):
+class APIScorer(str, Enum):
     """
     Collection of proprietary scorers implemented by Judgment.
@@ -20,7 +20,13 @@ class APIScorer(Enum):
     CONTEXTUAL_RELEVANCY = "contextual_relevancy"
     CONTEXTUAL_PRECISION = "contextual_precision"
     TOOL_CORRECTNESS = "tool_correctness"
-
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle case-insensitive lookup
+        for member in cls:
+            if member.value == value.lower():
+                return member
 
 ROOT_API = "http://127.0.0.1:8000"
 # ROOT_API = "https://api.judgmentlabs.ai" # TODO replace this with the actual API root
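
The str mixin plus the _missing_ hook make APIScorer values usable directly as strings and tolerant of uppercase lookups. A minimal, self-contained sketch of the intended behavior, assuming only the members shown in this hunk:

from enum import Enum

class APIScorer(str, Enum):
    CONTEXTUAL_RELEVANCY = "contextual_relevancy"
    CONTEXTUAL_PRECISION = "contextual_precision"
    TOOL_CORRECTNESS = "tool_correctness"

    @classmethod
    def _missing_(cls, value):
        # Enum calls this when a by-value lookup fails; lowercasing the
        # input lets "TOOL_CORRECTNESS" resolve to the same member.
        for member in cls:
            if member.value == value.lower():
                return member

assert APIScorer("TOOL_CORRECTNESS") is APIScorer.TOOL_CORRECTNESS
# Because of the str mixin, members compare equal to their raw values,
# which keeps JSON payloads simple.
assert APIScorer.TOOL_CORRECTNESS == "tool_correctness"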

judgeval/data/api_example.py

Lines changed: 26 additions & 30 deletions
@@ -13,28 +13,24 @@ class ProcessExample(BaseModel):
     """
     name: str
     input: Optional[str] = None
-    actual_output: Optional[str] = Field(None, alias="actualOutput")
-    expected_output: Optional[str] = Field(None, alias="expectedOutput")
-    context: Optional[list] = Field(None)
-    retrieval_context: Optional[list] = Field(None, alias="retrievalContext")
-    tools_called: Optional[list] = Field(None, alias="toolsCalled")
-    expected_tools: Optional[list] = Field(None, alias="expectedTools")
+    actual_output: Optional[str] = None
+    expected_output: Optional[str] = None
+    context: Optional[list] = None
+    retrieval_context: Optional[list] = None
+    tools_called: Optional[list] = None
+    expected_tools: Optional[list] = None
 
     # make these optional, not all test cases in a conversation will be evaluated
-    success: Union[bool, None] = Field(None)
-    scorers_data: Union[List[ScorerData], None] = Field(
-        None, alias="scorersData"
-    )
-    run_duration: Union[float, None] = Field(None, alias="runDuration")
-    evaluation_cost: Union[float, None] = Field(None, alias="evaluationCost")
+    success: Optional[bool] = None
+    scorers_data: Optional[List[ScorerData]] = None
+    run_duration: Optional[float] = None
+    evaluation_cost: Optional[float] = None
 
-    order: Union[int, None] = Field(None)
+    order: Optional[int] = None
     # These should map 1 to 1 from golden
-    additional_metadata: Optional[Dict] = Field(
-        None, alias="additionalMetadata"
-    )
-    comments: Optional[str] = Field(None)
-    trace_id: Optional[str] = Field(None)
+    additional_metadata: Optional[Dict] = None
+    comments: Optional[str] = None
+    trace_id: Optional[str] = None
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def update_scorer_data(self, scorer_data: ScorerData):
@@ -65,12 +61,12 @@ def update_run_duration(self, run_duration: float):
     @model_validator(mode="before")
     def check_input(cls, values: Dict[str, Any]):
         input = values.get("input")
-        actual_output = values.get("actualOutput")
+        actual_output = values.get("actual_output")
 
         if (input is None or actual_output is None):
-            error(f"Validation error: Required fields missing. input={input}, actualOutput={actual_output}")
+            error(f"Validation error: Required fields missing. input={input}, actual_output={actual_output}")
             raise ValueError(
-                "'input' and 'actualOutput' must be provided."
+                "'input' and 'actual_output' must be provided."
             )
 
         return values
@@ -97,18 +93,18 @@ def create_process_example(
     process_ex = ProcessExample(
         name=name,
        input=example.input,
-        actualOutput=example.actual_output,
-        expectedOutput=example.expected_output,
+        actual_output=example.actual_output,
+        expected_output=example.expected_output,
         context=example.context,
-        retrievalContext=example.retrieval_context,
-        toolsCalled=example.tools_called,
-        expectedTools=example.expected_tools,
+        retrieval_context=example.retrieval_context,
+        tools_called=example.tools_called,
+        expected_tools=example.expected_tools,
         success=success,
-        scorersData=scorers_data,
-        runDuration=None,
-        evaluationCost=None,
+        scorers_data=scorers_data,
+        run_duration=None,
+        evaluation_cost=None,
         order=order,
-        additionalMetadata=example.additional_metadata,
+        additional_metadata=example.additional_metadata,
         trace_id=example.trace_id
     )
     return process_ex
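
With the camelCase aliases gone, ProcessExample is constructed and validated purely with snake_case field names, and the before-validator rejects examples missing input or actual_output. An illustrative sketch (field values are hypothetical; it assumes ProcessExample is importable from the module shown above):

from judgeval.data.api_example import ProcessExample

ok = ProcessExample(
    name="summarize-ticket",                      # hypothetical data
    input="Summarize the support ticket",
    actual_output="The ticket reports a login failure.",
)

try:
    ProcessExample(name="bad", input="Summarize the support ticket")
except ValueError as err:
    # check_input raised because actual_output was not provided;
    # pydantic surfaces it as a validation error.
    print(err)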

judgeval/data/datasets/utils.py

Lines changed: 9 additions & 0 deletions
@@ -14,6 +14,11 @@ def examples_to_ground_truths(examples: List[Example]) -> List[GroundTruthExampl
     Returns:
         List[GroundTruthExample]: A list of `GroundTruthExample` objects.
     """
+
+    if not isinstance(examples, list):
+        raise TypeError("Input should be a list of `Example` objects")
+
+    ground_truths = []
     ground_truths = []
     for e in examples:
         g_truth = {
@@ -45,6 +50,10 @@ def ground_truths_to_examples(
     Returns:
         List[Example]: A list of `Example` objects.
     """
+
+    if not isinstance(ground_truths, list):
+        raise TypeError("Input should be a list of `GroundTruthExample` objects")
+
     examples = []
     for index, ground_truth in enumerate(ground_truths):
         e = Example(
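
Both helpers now fail fast on the wrong container type instead of failing somewhere inside the loop. A quick sketch of the new guard's behavior (assuming the helpers are importable from the module shown above):

from judgeval.data.datasets.utils import examples_to_ground_truths

try:
    examples_to_ground_truths("not a list")  # deliberately the wrong type
except TypeError as err:
    print(err)  # Input should be a list of `Example` objects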

judgeval/data/example.py

Lines changed: 0 additions & 35 deletions
@@ -37,41 +37,6 @@ class Example(BaseModel):
     timestamp: Optional[str] = None
     trace_id: Optional[str] = None
 
-    def __post_init__(self):
-        # Ensure `context` is None or a list of strings
-        if self.context is not None:
-            if not isinstance(self.context, list) or not all(
-                isinstance(item, str) for item in self.context
-            ):
-                raise TypeError("'context' must be None or a list of strings")
-
-        # Ensure `retrieval_context` is None or a list of strings
-        if self.retrieval_context is not None:
-            if not isinstance(self.retrieval_context, list) or not all(
-                isinstance(item, str) for item in self.retrieval_context
-            ):
-                raise TypeError(
-                    "'retrieval_context' must be None or a list of strings"
-                )
-
-        # Ensure `tools_called` is None or a list of strings
-        if self.tools_called is not None:
-            if not isinstance(self.tools_called, list) or not all(
-                isinstance(item, str) for item in self.tools_called
-            ):
-                raise TypeError(
-                    "'tools_called' must be None or a list of strings"
-                )
-
-        # Ensure `expected_tools` is None or a list of strings
-        if self.expected_tools is not None:
-            if not isinstance(self.expected_tools, list) or not all(
-                isinstance(item, str) for item in self.expected_tools
-            ):
-                raise TypeError(
-                    "'expected_tools' must be None or a list of strings"
-                )
-
     def __init__(self, **data):
         super().__init__(**data)
         # Set timestamp if not provided
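
The deleted __post_init__ is a dataclasses hook, so Pydantic's BaseModel never invoked it in the first place; the removal drops dead code rather than behavior. A small standalone illustration of that difference:

from dataclasses import dataclass
from pydantic import BaseModel

@dataclass
class AsDataclass:
    x: int

    def __post_init__(self):
        print("dataclass hook runs")

class AsModel(BaseModel):
    x: int

    def __post_init__(self):
        print("never reached")

AsDataclass(x=1)  # prints "dataclass hook runs"
AsModel(x=1)      # prints nothing; BaseModel ignores __post_init__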

judgeval/scorers/custom_scorer.py

Lines changed: 2 additions & 3 deletions
@@ -9,7 +9,6 @@
 from abc import abstractmethod
 
 from judgeval.common.logger import debug, info, warning, error
-from judgeval.data import Example
 from judgeval.judges import judgevalJudge
 from judgeval.judges.utils import create_judge
 
@@ -84,7 +83,7 @@ def _add_model(self, model: Optional[Union[str, List[str], judgevalJudge]] = Non
         self.evaluation_model = self.model.get_model_name()
 
     @abstractmethod
-    def score_example(self, example: Example, *args, **kwargs) -> float:
+    def score_example(self, example, *args, **kwargs) -> float:
         """
         Measures the score on a single example
         """
@@ -93,7 +92,7 @@ def score_example(self, example: Example, *args, **kwargs) -> float:
         raise NotImplementedError("You must implement the `score` method in your custom scorer")
 
     @abstractmethod
-    async def a_score_example(self, example: Example, *args, **kwargs) -> float:
+    async def a_score_example(self, example, *args, **kwargs) -> float:
         """
         Asynchronously measures the score on a single example
         """

judgeval/scorers/prompt_scorer.py

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ def score_example(
         """
         Synchronous method for scoring an example using the prompt criteria.
         """
-        with scorer_progress_meter(self, _show_indicator=_show_indicator):
+        with scorer_progress_meter(self, display_meter=_show_indicator):
             if self.async_mode:
                 loop = get_or_create_event_loop()
                 loop.run_until_complete(
@@ -217,7 +217,7 @@ def enforce_prompt_format(self, judge_prompt: List[dict], schema: dict):
         # create formatting string for schema enforcement
         # schema is a map between key and type of the value
         for key, key_type in schema.items():
-            SCHEMA_ENFORCEMENT_PROMPT += f'"{key}": <{key}> ({key_type}), '
+            SCHEMA_ENFORCEMENT_PROMPT += f'"{key}": <{key}> ({key_type.__name__}), '
         SCHEMA_ENFORCEMENT_PROMPT = SCHEMA_ENFORCEMENT_PROMPT[:-2] + "}"  # remove trailing comma and space
         judge_prompt[0]["content"] += SCHEMA_ENFORCEMENT_PROMPT
         return judge_prompt
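
The diff does not show how SCHEMA_ENFORCEMENT_PROMPT is initialized, so the sketch below reproduces only the loop, assuming the schema maps keys to Python types; key_type.__name__ renders "float" rather than "<class 'float'>" in the judge prompt:

schema = {"score": float, "reason": str}  # hypothetical schema

suffix = ""
for key, key_type in schema.items():
    # __name__ gives the bare type name instead of the full class repr
    suffix += f'"{key}": <{key}> ({key_type.__name__}), '
suffix = suffix[:-2] + "}"  # remove trailing comma and space

print(suffix)
# "score": <score> (float), "reason": <reason> (str)}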

judgeval/scorers/score.py

Lines changed: 10 additions & 3 deletions
@@ -273,8 +273,15 @@ async def a_execute_scoring(
     semaphore = asyncio.Semaphore(max_concurrent)
 
     async def execute_with_semaphore(func: Callable, *args, **kwargs):
-        async with semaphore:
-            return await func(*args, **kwargs)
+        try:
+            async with semaphore:
+                return await func(*args, **kwargs)
+        except Exception as e:
+            error(f"Error executing function: {e}")
+            if kwargs.get('ignore_errors', False):
+                # Return None when ignoring errors
+                return None
+            raise
 
     if verbose_mode is not None:
         for scorer in scorers:
@@ -406,7 +413,7 @@ async def a_eval_examples_helper(
     # the results and update the process example with the scorer data
     for scorer in scorers:
         # At this point, the scorer has been executed and already contains data.
-        if scorer.skipped:
+        if getattr(scorer, 'skipped', False):
             continue
 
         scorer_data = create_scorer_data(scorer)  # Fetch scorer data from completed scorer evaluation
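
The wrapper above bounds concurrency with a shared semaphore and, when the caller passes ignore_errors, turns a failed scorer run into a None result instead of cancelling the whole batch. A standalone sketch of the same pattern (names are illustrative; here ignore_errors is a plain parameter rather than a key read from the wrapped call's kwargs):

import asyncio

async def run_all(coro_fns, max_concurrent=5, ignore_errors=True):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def execute_with_semaphore(coro_fn):
        try:
            async with semaphore:      # at most max_concurrent in flight
                return await coro_fn()
        except Exception as e:
            print(f"Error executing function: {e}")
            if ignore_errors:
                return None            # keep the other tasks running
            raise

    return await asyncio.gather(*(execute_with_semaphore(fn) for fn in coro_fns))

async def flaky(i):
    if i == 2:
        raise RuntimeError("boom")
    return i

print(asyncio.run(run_all([lambda i=i: flaky(i) for i in range(4)])))
# -> [0, 1, None, 3]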
