1
1
import pytest
2
2
import time
3
- from unittest .mock import Mock , patch
4
- from judgeval .common .tracer import tracer , TraceEntry , TraceClient
3
+ from unittest .mock import Mock , patch , MagicMock
4
+ from datetime import datetime
5
+ from openai import OpenAI
6
+ from together import Together
7
+ from anthropic import Anthropic
8
+
9
+ from judgeval .common .tracer import Tracer , TraceEntry , TraceClient , wrap
5
10
6
11
@pytest .fixture
7
- def reset_tracer ():
8
- """Reset tracer state between tests"""
9
- tracer .depth = 0
10
- tracer ._current_trace = None
11
- tracer .api_key = None
12
- yield
13
-
12
+ def tracer ():
13
+ """Provide a configured tracer instance"""
14
+ return Tracer (api_key = "test_api_key" )
15
+
14
16
@pytest .fixture
15
- def configured_tracer ( reset_tracer ):
16
- """Provide a configured tracer """
17
- tracer .configure ( "test_api_key" )
18
- return tracer
17
+ def trace_client ( tracer ):
18
+ """Provide a trace client instance """
19
+ with tracer .trace ( "test_trace" ) as client :
20
+ yield client
19
21
20
22
def test_tracer_singleton ():
21
23
"""Test that Tracer maintains singleton pattern"""
22
- from judgeval .common .tracer import Tracer
23
- tracer1 = Tracer ()
24
- tracer2 = Tracer ()
24
+ tracer1 = Tracer (api_key = "test1" )
25
+ tracer2 = Tracer (api_key = "test2" )
25
26
assert tracer1 is tracer2
27
+ assert tracer1 .api_key == "test2" # Should have new api_key
26
28
27
- def test_tracer_configuration (reset_tracer ):
28
- """Test tracer configuration"""
29
- assert tracer .api_key is None
30
- tracer .configure ("test_api_key" )
31
- assert tracer .api_key == "test_api_key"
29
+ def test_tracer_requires_api_key ():
30
+ """Test that Tracer requires an API key"""
31
+ # Clear any existing singleton instance first
32
+ Tracer ._instance = None
33
+
34
+ with pytest .raises (ValueError ):
35
+ tracer = Tracer (api_key = None )
36
+ print (tracer .api_key )
32
37
33
38
def test_trace_entry_print (capsys ):
34
39
"""Test TraceEntry print formatting"""
35
- # Test each type of entry
36
40
entries = [
37
- TraceEntry (type = "enter" , function = "test_func" , depth = 1 , message = "" , timestamp = 0 ),
38
- TraceEntry (type = "exit" , function = "test_func" , depth = 1 , message = "" , timestamp = 0 , duration = 0.5 ),
39
- TraceEntry (type = "output" , function = "test_func" , depth = 1 , message = "" , timestamp = 0 , output = "result" ),
40
- TraceEntry (type = "input" , function = "test_func" , depth = 1 , message = "" , timestamp = 0 , inputs = {"arg" : 1 }),
41
+ TraceEntry (type = "enter" , function = "test_func" , depth = 1 , message = "test " , timestamp = 0 ),
42
+ TraceEntry (type = "exit" , function = "test_func" , depth = 1 , message = "test " , timestamp = 0 , duration = 0.5 ),
43
+ TraceEntry (type = "output" , function = "test_func" , depth = 1 , message = "test " , timestamp = 0 , output = "result" ),
44
+ TraceEntry (type = "input" , function = "test_func" , depth = 1 , message = "test " , timestamp = 0 , inputs = {"arg" : 1 }),
41
45
]
42
46
43
47
expected_outputs = [
44
- " → test_func\n " ,
48
+ " → test_func (trace: test) \n " ,
45
49
" ← test_func (0.500s)\n " ,
46
50
" Output: result\n " ,
47
51
" Input: {'arg': 1}\n " ,
@@ -52,124 +56,166 @@ def test_trace_entry_print(capsys):
52
56
captured = capsys .readouterr ()
53
57
assert captured .out == expected
54
58
55
- @pytest .mark .asyncio
56
- async def test_trace_client_operations (configured_tracer ):
57
- """Test TraceClient basic operations"""
58
- trace = configured_tracer .start_trace ("test_trace" )
59
-
60
- # Test adding entries
59
+ def test_trace_entry_to_dict ():
60
+ """Test TraceEntry serialization"""
61
+ # Test basic serialization
61
62
entry = TraceEntry (
62
63
type = "enter" ,
63
64
function = "test_func" ,
64
- depth = 0 ,
65
+ depth = 1 ,
65
66
message = "test" ,
66
- timestamp = time . time ()
67
+ timestamp = 0
67
68
)
68
- trace .add_entry (entry )
69
- assert len (trace .entries ) == 1
70
- assert trace .entries [0 ] == entry
71
-
72
- # Test duration calculation
73
- time .sleep (0.1 )
74
- duration = trace .get_duration ()
75
- assert duration > 0
69
+ data = entry .to_dict ()
70
+ assert data ["type" ] == "enter"
71
+ assert data ["function" ] == "test_func"
76
72
77
- def test_observe_decorator ( configured_tracer ):
78
- """Test the @tracer.observe decorator"""
79
- results = []
73
+ # Test with non-serializable output
74
+ class NonSerializable :
75
+ pass
80
76
81
- @tracer .observe
82
- def test_function (x , y ):
83
- results .append (f"Function called with { x } , { y } " )
84
- return x + y
85
-
86
- with patch .object (TraceClient , 'add_entry' ) as mock_add_entry :
87
- trace = configured_tracer .start_trace ("test_trace" )
88
- result = test_function (1 , 2 )
89
-
90
- # Verify function execution
91
- assert result == 3
92
- assert results == ["Function called with 1, 2" ]
93
-
94
- # Verify trace entries
95
- assert mock_add_entry .call_count == 4 # enter, input, output, exit
96
-
97
- # Verify entry types
98
- entry_types = [call .args [0 ].type for call in mock_add_entry .call_args_list ]
99
- assert entry_types == ["enter" , "input" , "output" , "exit" ]
100
-
101
- @pytest .mark .asyncio
102
- async def test_save_trace (configured_tracer , mocker ):
103
- """Test saving trace data to API"""
104
- mock_post = mocker .patch ('requests.post' )
105
- mock_post .return_value .raise_for_status = Mock ()
106
-
107
- trace = configured_tracer .start_trace ("test_trace" )
108
-
109
- # Add some test entries
110
- @tracer .observe
111
- def test_function ():
112
- return "test_result"
77
+ entry = TraceEntry (
78
+ type = "output" ,
79
+ function = "test_func" ,
80
+ depth = 1 ,
81
+ message = "test" ,
82
+ timestamp = 0 ,
83
+ output = NonSerializable ()
84
+ )
113
85
114
- test_function ()
86
+ with pytest .warns (UserWarning ):
87
+ data = entry .to_dict ()
88
+ assert data ["output" ] is None
89
+
90
+ def test_trace_client_span (trace_client ):
91
+ """Test span context manager"""
92
+ initial_entries = len (trace_client .entries ) # Get initial count
115
93
116
- # Save trace
117
- trace_id , trace_data = trace .save_trace ()
94
+ with trace_client .span ("test_span" ) as span :
95
+ assert trace_client ._current_span == "test_span"
96
+ assert len (trace_client .entries ) == initial_entries + 1 # Compare to initial count
118
97
119
- # Verify API call
120
- mock_post .assert_called_once ()
121
- assert mock_post .call_args [1 ]['json' ]['trace_id' ] == trace_id
122
- assert mock_post .call_args [1 ]['json' ]['name' ] == "test_trace"
123
- assert len (mock_post .call_args [1 ]['json' ]['entries' ]) > 0
98
+ assert len (trace_client .entries ) == initial_entries + 2 # Account for both enter and exit
99
+ assert trace_client .entries [- 1 ].type == "exit"
100
+ assert trace_client ._current_span == "test_trace"
101
+
102
+ def test_trace_client_nested_spans (trace_client ):
103
+ """Test nested spans maintain proper depth"""
104
+ with trace_client .span ("outer" ):
105
+ assert trace_client .tracer .depth == 2 # 1 for trace + 1 for span
106
+ with trace_client .span ("inner" ):
107
+ assert trace_client .tracer .depth == 3
108
+ assert trace_client .tracer .depth == 2
109
+ assert trace_client .tracer .depth == 1
110
+
111
+ def test_record_input_output (trace_client ):
112
+ """Test recording inputs and outputs"""
113
+ with trace_client .span ("test_span" ):
114
+ trace_client .record_input ({"arg" : 1 })
115
+ trace_client .record_output ("result" )
116
+
117
+ # Filter entries to only include those for the current span
118
+ entries = [e .type for e in trace_client .entries if e .function == "test_span" ]
119
+ assert entries == ["enter" , "input" , "output" , "exit" ]
124
120
125
- def test_condense_trace (configured_tracer ):
121
+ def test_condense_trace (trace_client ):
126
122
"""Test trace condensing functionality"""
127
- trace = configured_tracer .start_trace ("test_trace" )
128
-
129
- # Create sample entries
130
123
entries = [
131
124
{"type" : "enter" , "function" : "test_func" , "depth" : 0 , "timestamp" : 1.0 },
132
- {"type" : "input" , "function" : "test_func" , "depth" : 0 , "timestamp" : 1.1 , "inputs" : {"x" : 1 }},
133
- {"type" : "output" , "function" : "test_func" , "depth" : 0 , "timestamp" : 1.2 , "output" : "result" },
134
- {"type" : "exit" , "function" : "test_func" , "depth" : 0 , "timestamp" : 1.3 },
125
+ {"type" : "input" , "function" : "test_func" , "depth" : 1 , "timestamp" : 1.1 , "inputs" : {"x" : 1 }},
126
+ {"type" : "output" , "function" : "test_func" , "depth" : 1 , "timestamp" : 1.2 , "output" : "result" },
127
+ {"type" : "exit" , "function" : "test_func" , "depth" : 0 , "timestamp" : 2.0 },
135
128
]
136
129
137
- condensed = trace .condense_trace (entries )
138
-
130
+ condensed = trace_client .condense_trace (entries )
139
131
assert len (condensed ) == 1
140
132
assert condensed [0 ]["function" ] == "test_func"
133
+ assert condensed [0 ]["depth" ] == 1
141
134
assert condensed [0 ]["inputs" ] == {"x" : 1 }
142
135
assert condensed [0 ]["output" ] == "result"
143
- assert condensed [0 ]["duration" ] == pytest . approx ( 0.3 )
136
+ assert condensed [0 ]["duration" ] == 1.0
144
137
145
- def test_nested_function_depth (configured_tracer ):
146
- """Test depth tracking for nested function calls"""
138
+ @patch ('requests.post' )
139
+ def test_save_trace (mock_post , trace_client ):
140
+ """Test saving trace data"""
141
+ mock_post .return_value .raise_for_status = Mock ()
142
+
143
+ with trace_client .span ("test_span" ):
144
+ trace_client .record_input ({"arg" : 1 })
145
+ trace_client .record_output ("result" )
146
+
147
+ trace_id , data = trace_client .save ()
148
+
149
+ assert mock_post .called
150
+ assert data ["trace_id" ] == trace_client .trace_id
151
+ assert data ["name" ] == "test_trace"
152
+ assert len (data ["entries" ]) > 0
153
+ assert isinstance (data ["created_at" ], str )
154
+ assert isinstance (data ["duration" ], float )
155
+
156
+ def test_observe_decorator (tracer ):
157
+ """Test the @tracer.observe decorator"""
147
158
@tracer .observe
148
- def outer ():
149
- @tracer .observe
150
- def inner ():
151
- pass
152
- inner ()
153
-
154
- trace = configured_tracer .start_trace ("test_trace" )
155
- outer ()
156
-
157
- # Verify depths
158
- print (f"{ trace .entries = } " )
159
- depths = [entry .depth for entry in trace .entries ]
160
- assert depths == [0 , 1 , 1 , 2 , 2 , 1 , 1 , 0 ] # outer(enter), inner(enter), inner(input), inner(output), inner(exit), outer(exit)
161
-
162
- def test_error_handling (configured_tracer ):
163
- """Test error handling in traced functions"""
159
+ def test_function (x , y ):
160
+ return x + y
161
+
162
+ with tracer .trace ("test_trace" ):
163
+ result = test_function (1 , 2 )
164
+
165
+ assert result == 3
166
+
167
+ def test_observe_decorator_with_error (tracer ):
168
+ """Test decorator error handling"""
164
169
@tracer .observe
165
170
def failing_function ():
166
171
raise ValueError ("Test error" )
167
172
168
- trace = configured_tracer .start_trace ("test_trace" )
173
+ with tracer .trace ("test_trace" ):
174
+ with pytest .raises (ValueError ):
175
+ failing_function ()
176
+
177
+ @patch ('requests.post' )
178
+ def test_wrap_openai (mock_post , tracer ):
179
+ """Test wrapping OpenAI client"""
180
+ client = OpenAI ()
181
+ mock_response = MagicMock ()
182
+ mock_response .choices = [MagicMock (message = MagicMock (content = "test response" ))]
183
+ mock_response .usage = MagicMock (prompt_tokens = 10 , completion_tokens = 20 , total_tokens = 30 )
184
+ client .chat .completions .create = MagicMock (return_value = mock_response )
185
+
186
+ wrapped_client = wrap (client )
187
+
188
+ with tracer .trace ("test_trace" ):
189
+ response = wrapped_client .chat .completions .create (
190
+ model = "gpt-4" ,
191
+ messages = [{"role" : "user" , "content" : "test" }]
192
+ )
193
+
194
+ assert response == mock_response
195
+
196
+ @patch ('requests.post' )
197
+ def test_wrap_anthropic (mock_post , tracer ):
198
+ """Test wrapping Anthropic client"""
199
+ client = Anthropic ()
200
+ mock_response = MagicMock ()
201
+ mock_response .content = [MagicMock (text = "test response" )]
202
+ mock_response .usage = MagicMock (input_tokens = 10 , output_tokens = 20 )
203
+ client .messages .create = MagicMock (return_value = mock_response )
204
+
205
+ wrapped_client = wrap (client )
206
+
207
+ with tracer .trace ("test_trace" ):
208
+ response = wrapped_client .messages .create (
209
+ model = "claude-3" ,
210
+ messages = [{"role" : "user" , "content" : "test" }]
211
+ )
212
+
213
+ assert response == mock_response
214
+
215
+ def test_wrap_unsupported_client (tracer ):
216
+ """Test wrapping unsupported client type"""
217
+ class UnsupportedClient :
218
+ pass
169
219
170
220
with pytest .raises (ValueError ):
171
- failing_function ()
172
-
173
- # Verify that exit entry was still recorded
174
- assert trace .entries [- 1 ].type == "exit"
175
- assert tracer .depth == 0 # Depth should be reset even after error
221
+ wrap (UnsupportedClient ())
0 commit comments