Skip to content

Commit 7fc322f

Browse files
authored
Pydantic Typing (#268)
* Pydantic Typing * UT fix
1 parent 186a33b commit 7fc322f

File tree

5 files changed

+32
-11
lines changed

5 files changed

+32
-11
lines changed

src/demo/sequence_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def generate_itinerary(destination, start_date, end_date):
159159
judgment.assert_test(
160160
project_name="travel_agent_demo",
161161
examples=[example],
162-
scorers=[ToolOrderScorer(threshold=0.5)],
162+
scorers=[ToolOrderScorer()],
163163
model="gpt-4.1-mini",
164164
function=generate_itinerary,
165165
tracer=tracer,

src/judgeval/data/example.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import BaseModel, Field, field_validator
99
from enum import Enum
1010
from datetime import datetime
11+
from judgeval.data.tool import Tool
1112
import time
1213

1314

@@ -31,7 +32,7 @@ class Example(BaseModel):
3132
retrieval_context: Optional[List[str]] = None
3233
additional_metadata: Optional[Dict[str, Any]] = None
3334
tools_called: Optional[List[str]] = None
34-
expected_tools: Optional[List[Dict[str, Any]]] = None
35+
expected_tools: Optional[List[Tool]] = None
3536
name: Optional[str] = None
3637
example_id: str = Field(default_factory=lambda: str(uuid4()))
3738
example_index: Optional[int] = None
@@ -82,17 +83,17 @@ def validate_expected_output(cls, v):
8283
raise ValueError(f"All items in expected_output must be strings but got {v}")
8384
return v
8485

85-
@field_validator('expected_tools', mode='before')
86+
@field_validator('expected_tools')
8687
@classmethod
8788
def validate_expected_tools(cls, v):
8889
if v is not None:
8990
if not isinstance(v, list):
90-
raise ValueError(f"Expected tools must be a list of dictionaries or None but got {v} of type {type(v)}")
91+
raise ValueError(f"Expected tools must be a list of Tools or None but got {v} of type {type(v)}")
9192

92-
# Check that each item in the list is a dictionary
93+
# Check that each item in the list is a Tool
9394
for i, item in enumerate(v):
94-
if not isinstance(item, dict):
95-
raise ValueError(f"Expected tools must be a list of dictionaries, but item at index {i} is {item} of type {type(item)}")
95+
if not isinstance(item, Tool):
96+
raise ValueError(f"Expected tools must be a list of Tools, but item at index {i} is {item} of type {type(item)}")
9697

9798
return v
9899

src/judgeval/data/tool.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pydantic import BaseModel, field_validator
2+
from typing import Dict, Any, Optional
3+
import warnings
4+
5+
class Tool(BaseModel):
6+
tool_name: str
7+
parameters: Optional[Dict[str, Any]] = None
8+
9+
@field_validator('tool_name')
10+
def validate_tool_name(cls, v):
11+
if not v:
12+
warnings.warn("Tool name is empty or None", UserWarning)
13+
return v
14+
15+
@field_validator('parameters')
16+
def validate_parameters(cls, v):
17+
if v is not None and not isinstance(v, dict):
18+
warnings.warn(f"Parameters should be a dictionary, got {type(v)}", UserWarning)
19+
return v

src/judgeval/data/trace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pydantic import BaseModel
22
from typing import Optional, Dict, Any, List
33
from judgeval.evaluation_run import EvaluationRun
4+
from judgeval.data.tool import Tool
45
import json
56
from datetime import datetime, timezone
67

@@ -17,7 +18,7 @@ class TraceSpan(BaseModel):
1718
duration: Optional[float] = None
1819
annotation: Optional[List[Dict[str, Any]]] = None
1920
evaluation_runs: Optional[List[EvaluationRun]] = []
20-
expected_tools: Optional[List[Dict[str, Any]]] = None
21+
expected_tools: Optional[List[Tool]] = None
2122
additional_metadata: Optional[Dict[str, Any]] = None
2223

2324
def model_dump(self, **kwargs):

src/tests/data/test_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import datetime
77
from pydantic import ValidationError
88
from judgeval.data import Example
9-
9+
from judgeval.data.tool import Tool
1010

1111
def test_basic_example_creation():
1212
example = Example(
@@ -30,7 +30,7 @@ def test_full_example_creation():
3030
retrieval_context=["retrieval1", "retrieval2"],
3131
additional_metadata={"key": "value"},
3232
tools_called=["tool1", "tool2"],
33-
expected_tools=[{"tool_name": "expected_tool1"}, {"tool_name": "expected_tool2"}],
33+
expected_tools=[Tool(tool_name="expected_tool1"), Tool(tool_name="expected_tool2")],
3434
name="test example",
3535
example_id="123",
3636
timestamp="20240101_120000",
@@ -43,7 +43,7 @@ def test_full_example_creation():
4343
assert example.retrieval_context == ["retrieval1", "retrieval2"]
4444
assert example.additional_metadata == {"key": "value"}
4545
assert example.tools_called == ["tool1", "tool2"]
46-
assert example.expected_tools == [{"tool_name": "expected_tool1"}, {"tool_name": "expected_tool2"}]
46+
assert example.expected_tools == [Tool(tool_name="expected_tool1"), Tool(tool_name="expected_tool2")]
4747
assert example.name == "test example"
4848
assert example.example_id == "123"
4949
assert example.timestamp == "20240101_120000"

0 commit comments

Comments
 (0)