Skip to content

Commit 2d7de4c

Browse files
committed
Update tracer unit tests to work with the new Tracer implementation
1 parent f8041d3 commit 2d7de4c

File tree

1 file changed

+166
-120
lines changed

1 file changed

+166
-120
lines changed

tests/common/test_tracer.py

Lines changed: 166 additions & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -1,47 +1,51 @@
11
import pytest
22
import time
3-
from unittest.mock import Mock, patch
4-
from judgeval.common.tracer import tracer, TraceEntry, TraceClient
3+
from unittest.mock import Mock, patch, MagicMock
4+
from datetime import datetime
5+
from openai import OpenAI
6+
from together import Together
7+
from anthropic import Anthropic
8+
9+
from judgeval.common.tracer import Tracer, TraceEntry, TraceClient, wrap
510

611
@pytest.fixture
7-
def reset_tracer():
8-
"""Reset tracer state between tests"""
9-
tracer.depth = 0
10-
tracer._current_trace = None
11-
tracer.api_key = None
12-
yield
13-
12+
def tracer():
13+
"""Provide a configured tracer instance"""
14+
return Tracer(api_key="test_api_key")
15+
1416
@pytest.fixture
15-
def configured_tracer(reset_tracer):
16-
"""Provide a configured tracer"""
17-
tracer.configure("test_api_key")
18-
return tracer
17+
def trace_client(tracer):
18+
"""Provide a trace client instance"""
19+
with tracer.trace("test_trace") as client:
20+
yield client
1921

2022
def test_tracer_singleton():
2123
"""Test that Tracer maintains singleton pattern"""
22-
from judgeval.common.tracer import Tracer
23-
tracer1 = Tracer()
24-
tracer2 = Tracer()
24+
tracer1 = Tracer(api_key="test1")
25+
tracer2 = Tracer(api_key="test2")
2526
assert tracer1 is tracer2
27+
assert tracer1.api_key == "test2" # Should have new api_key
2628

27-
def test_tracer_configuration(reset_tracer):
28-
"""Test tracer configuration"""
29-
assert tracer.api_key is None
30-
tracer.configure("test_api_key")
31-
assert tracer.api_key == "test_api_key"
29+
def test_tracer_requires_api_key():
30+
"""Test that Tracer requires an API key"""
31+
# Clear any existing singleton instance first
32+
Tracer._instance = None
33+
34+
with pytest.raises(ValueError):
35+
tracer = Tracer(api_key=None)
36+
print(tracer.api_key)
3237

3338
def test_trace_entry_print(capsys):
3439
"""Test TraceEntry print formatting"""
35-
# Test each type of entry
3640
entries = [
37-
TraceEntry(type="enter", function="test_func", depth=1, message="", timestamp=0),
38-
TraceEntry(type="exit", function="test_func", depth=1, message="", timestamp=0, duration=0.5),
39-
TraceEntry(type="output", function="test_func", depth=1, message="", timestamp=0, output="result"),
40-
TraceEntry(type="input", function="test_func", depth=1, message="", timestamp=0, inputs={"arg": 1}),
41+
TraceEntry(type="enter", function="test_func", depth=1, message="test", timestamp=0),
42+
TraceEntry(type="exit", function="test_func", depth=1, message="test", timestamp=0, duration=0.5),
43+
TraceEntry(type="output", function="test_func", depth=1, message="test", timestamp=0, output="result"),
44+
TraceEntry(type="input", function="test_func", depth=1, message="test", timestamp=0, inputs={"arg": 1}),
4145
]
4246

4347
expected_outputs = [
44-
" → test_func\n",
48+
" → test_func (trace: test)\n",
4549
" ← test_func (0.500s)\n",
4650
" Output: result\n",
4751
" Input: {'arg': 1}\n",
@@ -52,124 +56,166 @@ def test_trace_entry_print(capsys):
5256
captured = capsys.readouterr()
5357
assert captured.out == expected
5458

55-
@pytest.mark.asyncio
56-
async def test_trace_client_operations(configured_tracer):
57-
"""Test TraceClient basic operations"""
58-
trace = configured_tracer.start_trace("test_trace")
59-
60-
# Test adding entries
59+
def test_trace_entry_to_dict():
60+
"""Test TraceEntry serialization"""
61+
# Test basic serialization
6162
entry = TraceEntry(
6263
type="enter",
6364
function="test_func",
64-
depth=0,
65+
depth=1,
6566
message="test",
66-
timestamp=time.time()
67+
timestamp=0
6768
)
68-
trace.add_entry(entry)
69-
assert len(trace.entries) == 1
70-
assert trace.entries[0] == entry
71-
72-
# Test duration calculation
73-
time.sleep(0.1)
74-
duration = trace.get_duration()
75-
assert duration > 0
69+
data = entry.to_dict()
70+
assert data["type"] == "enter"
71+
assert data["function"] == "test_func"
7672

77-
def test_observe_decorator(configured_tracer):
78-
"""Test the @tracer.observe decorator"""
79-
results = []
73+
# Test with non-serializable output
74+
class NonSerializable:
75+
pass
8076

81-
@tracer.observe
82-
def test_function(x, y):
83-
results.append(f"Function called with {x}, {y}")
84-
return x + y
85-
86-
with patch.object(TraceClient, 'add_entry') as mock_add_entry:
87-
trace = configured_tracer.start_trace("test_trace")
88-
result = test_function(1, 2)
89-
90-
# Verify function execution
91-
assert result == 3
92-
assert results == ["Function called with 1, 2"]
93-
94-
# Verify trace entries
95-
assert mock_add_entry.call_count == 4 # enter, input, output, exit
96-
97-
# Verify entry types
98-
entry_types = [call.args[0].type for call in mock_add_entry.call_args_list]
99-
assert entry_types == ["enter", "input", "output", "exit"]
100-
101-
@pytest.mark.asyncio
102-
async def test_save_trace(configured_tracer, mocker):
103-
"""Test saving trace data to API"""
104-
mock_post = mocker.patch('requests.post')
105-
mock_post.return_value.raise_for_status = Mock()
106-
107-
trace = configured_tracer.start_trace("test_trace")
108-
109-
# Add some test entries
110-
@tracer.observe
111-
def test_function():
112-
return "test_result"
77+
entry = TraceEntry(
78+
type="output",
79+
function="test_func",
80+
depth=1,
81+
message="test",
82+
timestamp=0,
83+
output=NonSerializable()
84+
)
11385

114-
test_function()
86+
with pytest.warns(UserWarning):
87+
data = entry.to_dict()
88+
assert data["output"] is None
89+
90+
def test_trace_client_span(trace_client):
91+
"""Test span context manager"""
92+
initial_entries = len(trace_client.entries) # Get initial count
11593

116-
# Save trace
117-
trace_id, trace_data = trace.save_trace()
94+
with trace_client.span("test_span") as span:
95+
assert trace_client._current_span == "test_span"
96+
assert len(trace_client.entries) == initial_entries + 1 # Compare to initial count
11897

119-
# Verify API call
120-
mock_post.assert_called_once()
121-
assert mock_post.call_args[1]['json']['trace_id'] == trace_id
122-
assert mock_post.call_args[1]['json']['name'] == "test_trace"
123-
assert len(mock_post.call_args[1]['json']['entries']) > 0
98+
assert len(trace_client.entries) == initial_entries + 2 # Account for both enter and exit
99+
assert trace_client.entries[-1].type == "exit"
100+
assert trace_client._current_span == "test_trace"
101+
102+
def test_trace_client_nested_spans(trace_client):
103+
"""Test nested spans maintain proper depth"""
104+
with trace_client.span("outer"):
105+
assert trace_client.tracer.depth == 2 # 1 for trace + 1 for span
106+
with trace_client.span("inner"):
107+
assert trace_client.tracer.depth == 3
108+
assert trace_client.tracer.depth == 2
109+
assert trace_client.tracer.depth == 1
110+
111+
def test_record_input_output(trace_client):
112+
"""Test recording inputs and outputs"""
113+
with trace_client.span("test_span"):
114+
trace_client.record_input({"arg": 1})
115+
trace_client.record_output("result")
116+
117+
# Filter entries to only include those for the current span
118+
entries = [e.type for e in trace_client.entries if e.function == "test_span"]
119+
assert entries == ["enter", "input", "output", "exit"]
124120

125-
def test_condense_trace(configured_tracer):
121+
def test_condense_trace(trace_client):
126122
"""Test trace condensing functionality"""
127-
trace = configured_tracer.start_trace("test_trace")
128-
129-
# Create sample entries
130123
entries = [
131124
{"type": "enter", "function": "test_func", "depth": 0, "timestamp": 1.0},
132-
{"type": "input", "function": "test_func", "depth": 0, "timestamp": 1.1, "inputs": {"x": 1}},
133-
{"type": "output", "function": "test_func", "depth": 0, "timestamp": 1.2, "output": "result"},
134-
{"type": "exit", "function": "test_func", "depth": 0, "timestamp": 1.3},
125+
{"type": "input", "function": "test_func", "depth": 1, "timestamp": 1.1, "inputs": {"x": 1}},
126+
{"type": "output", "function": "test_func", "depth": 1, "timestamp": 1.2, "output": "result"},
127+
{"type": "exit", "function": "test_func", "depth": 0, "timestamp": 2.0},
135128
]
136129

137-
condensed = trace.condense_trace(entries)
138-
130+
condensed = trace_client.condense_trace(entries)
139131
assert len(condensed) == 1
140132
assert condensed[0]["function"] == "test_func"
133+
assert condensed[0]["depth"] == 1
141134
assert condensed[0]["inputs"] == {"x": 1}
142135
assert condensed[0]["output"] == "result"
143-
assert condensed[0]["duration"] == pytest.approx(0.3)
136+
assert condensed[0]["duration"] == 1.0
144137

145-
def test_nested_function_depth(configured_tracer):
146-
"""Test depth tracking for nested function calls"""
138+
@patch('requests.post')
139+
def test_save_trace(mock_post, trace_client):
140+
"""Test saving trace data"""
141+
mock_post.return_value.raise_for_status = Mock()
142+
143+
with trace_client.span("test_span"):
144+
trace_client.record_input({"arg": 1})
145+
trace_client.record_output("result")
146+
147+
trace_id, data = trace_client.save()
148+
149+
assert mock_post.called
150+
assert data["trace_id"] == trace_client.trace_id
151+
assert data["name"] == "test_trace"
152+
assert len(data["entries"]) > 0
153+
assert isinstance(data["created_at"], str)
154+
assert isinstance(data["duration"], float)
155+
156+
def test_observe_decorator(tracer):
157+
"""Test the @tracer.observe decorator"""
147158
@tracer.observe
148-
def outer():
149-
@tracer.observe
150-
def inner():
151-
pass
152-
inner()
153-
154-
trace = configured_tracer.start_trace("test_trace")
155-
outer()
156-
157-
# Verify depths
158-
print(f"{trace.entries=}")
159-
depths = [entry.depth for entry in trace.entries]
160-
assert depths == [0, 1, 1, 2, 2, 1, 1, 0] # outer(enter), inner(enter), inner(input), inner(output), inner(exit), outer(exit)
161-
162-
def test_error_handling(configured_tracer):
163-
"""Test error handling in traced functions"""
159+
def test_function(x, y):
160+
return x + y
161+
162+
with tracer.trace("test_trace"):
163+
result = test_function(1, 2)
164+
165+
assert result == 3
166+
167+
def test_observe_decorator_with_error(tracer):
168+
"""Test decorator error handling"""
164169
@tracer.observe
165170
def failing_function():
166171
raise ValueError("Test error")
167172

168-
trace = configured_tracer.start_trace("test_trace")
173+
with tracer.trace("test_trace"):
174+
with pytest.raises(ValueError):
175+
failing_function()
176+
177+
@patch('requests.post')
178+
def test_wrap_openai(mock_post, tracer):
179+
"""Test wrapping OpenAI client"""
180+
client = OpenAI()
181+
mock_response = MagicMock()
182+
mock_response.choices = [MagicMock(message=MagicMock(content="test response"))]
183+
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
184+
client.chat.completions.create = MagicMock(return_value=mock_response)
185+
186+
wrapped_client = wrap(client)
187+
188+
with tracer.trace("test_trace"):
189+
response = wrapped_client.chat.completions.create(
190+
model="gpt-4",
191+
messages=[{"role": "user", "content": "test"}]
192+
)
193+
194+
assert response == mock_response
195+
196+
@patch('requests.post')
197+
def test_wrap_anthropic(mock_post, tracer):
198+
"""Test wrapping Anthropic client"""
199+
client = Anthropic()
200+
mock_response = MagicMock()
201+
mock_response.content = [MagicMock(text="test response")]
202+
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
203+
client.messages.create = MagicMock(return_value=mock_response)
204+
205+
wrapped_client = wrap(client)
206+
207+
with tracer.trace("test_trace"):
208+
response = wrapped_client.messages.create(
209+
model="claude-3",
210+
messages=[{"role": "user", "content": "test"}]
211+
)
212+
213+
assert response == mock_response
214+
215+
def test_wrap_unsupported_client(tracer):
216+
"""Test wrapping unsupported client type"""
217+
class UnsupportedClient:
218+
pass
169219

170220
with pytest.raises(ValueError):
171-
failing_function()
172-
173-
# Verify that exit entry was still recorded
174-
assert trace.entries[-1].type == "exit"
175-
assert tracer.depth == 0 # Depth should be reset even after error
221+
wrap(UnsupportedClient())

0 commit comments

Comments
 (0)