Skip to content

Commit 6a7b924

Browse files
committed
cumulative cost tracking
1 parent 2f4e1a2 commit 6a7b924

File tree

4 files changed

+152
-37
lines changed


src/judgeval/tracer/__init__.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,14 @@
6060
Cls = TypeVar("Cls", bound=Type)
6161
ApiClient = TypeVar("ApiClient", bound=Any)
6262

63-
_current_agent_context: ContextVar[Optional[Dict[str, str]]] = ContextVar(
63+
_current_agent_context: ContextVar[Optional[Dict[str, str | bool]]] = ContextVar(
6464
"current_agent_context", default=None
6565
)
6666

67+
_current_cost_context: ContextVar[Optional[Dict[str, float]]] = ContextVar(
68+
"current_cost_context", default=None
69+
)
70+
6771

6872
def resolve_project_id(
6973
api_key: str, organization_id: str, project_name: str
@@ -207,7 +211,27 @@ def get_current_span(self):
207211
def get_tracer(self):
208212
return self.tracer
209213

210-
def _add_agent_attributes_to_span(
214+
def get_current_agent_context(self):
    """Return the module-level ContextVar holding the current agent context.

    Note: returns the ContextVar itself, not its value — callers use
    ``.get()`` / ``.set()`` on the result.
    """
    agent_context_var = _current_agent_context
    return agent_context_var
217+
def get_current_cost_context(self):
    """Return the module-level ContextVar holding the cumulative-cost context.

    Note: returns the ContextVar itself, not its value — callers use
    ``.get()`` / ``.set()`` / ``.reset()`` on the result.
    """
    cost_context_var = _current_cost_context
    return cost_context_var
220+
def add_cost_to_current_context(self, cost: Optional[float]) -> None:
    """Add *cost* to the current cost context and mirror the running total
    onto the active span.

    Args:
        cost: Incremental LLM cost in USD. ``None`` is treated as ``0.0`` —
            callers pass ``usage.total_cost_usd``, which providers may leave
            unset, and the previous implementation raised ``TypeError``
            (``float + None``) in that case.

    Silently does nothing when no cost context has been seeded (i.e. we are
    not inside a span context manager that installed one).
    """
    if cost is None:
        # Some providers report no pricing; treat missing cost as zero
        # instead of crashing the traced call.
        cost = 0.0
    current_cost_context = _current_cost_context.get()
    if current_cost_context is not None:
        current_cumulative_cost = current_cost_context.get("cumulative_cost", 0.0)
        new_cumulative_cost = float(current_cumulative_cost) + cost
        current_cost_context["cumulative_cost"] = new_cumulative_cost

        # Keep the span attribute in sync so the cumulative cost is visible
        # on the trace even before the span closes.
        span = get_current_span()
        if span and span.is_recording():
            span.set_attribute(
                AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, new_cumulative_cost
            )
233+
234+
def add_agent_attributes_to_span(
211235
self, span, attributes: Optional[Dict[str, Any]] = None
212236
):
213237
"""Add agent ID, class name, and instance name to span if they exist in context"""
@@ -238,7 +262,7 @@ def _add_agent_attributes_to_span(
238262
current_agent_context["is_agent_entry_point"],
239263
)
240264
current_agent_context["is_agent_entry_point"] = (
241-
"false" # only true for entry point to agent
265+
False # only true for entry point to agent
242266
)
243267

244268
def _wrap_sync(
@@ -248,7 +272,7 @@ def _wrap_sync(
248272
def wrapper(*args, **kwargs):
249273
n = name or f.__qualname__
250274
with sync_span_context(self, n, attributes) as span:
251-
self._add_agent_attributes_to_span(span, attributes)
275+
self.add_agent_attributes_to_span(span, attributes)
252276
try:
253277
span.set_attribute(
254278
AttributeKeys.JUDGMENT_INPUT,
@@ -276,7 +300,7 @@ def _wrap_async(
276300
async def wrapper(*args, **kwargs):
277301
n = name or f.__qualname__
278302
with sync_span_context(self, n, attributes) as span:
279-
self._add_agent_attributes_to_span(span, attributes)
303+
self.add_agent_attributes_to_span(span, attributes)
280304
try:
281305
span.set_attribute(
282306
AttributeKeys.JUDGMENT_INPUT,
@@ -390,7 +414,7 @@ async def async_wrapper(*args, **kwargs):
390414
agent_context["parent_agent_id"] = current_agent_context[
391415
"agent_id"
392416
]
393-
agent_context["is_agent_entry_point"] = "true"
417+
agent_context["is_agent_entry_point"] = True
394418
token = _current_agent_context.set(agent_context)
395419
try:
396420
return await f(*args, **kwargs)
@@ -418,7 +442,7 @@ def sync_wrapper(*args, **kwargs):
418442
agent_context["parent_agent_id"] = current_agent_context[
419443
"agent_id"
420444
]
421-
agent_context["is_agent_entry_point"] = "true"
445+
agent_context["is_agent_entry_point"] = True
422446
token = _current_agent_context.set(agent_context)
423447
try:
424448
return f(*args, **kwargs)

src/judgeval/tracer/keys.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class AttributeKeys:
2020
JUDGMENT_AGENT_CLASS_NAME = "judgment.agent_class_name"
2121
JUDGMENT_AGENT_INSTANCE_NAME = "judgment.agent_instance_name"
2222
JUDGMENT_IS_AGENT_ENTRY_POINT = "judgment.is_agent_entry_point"
23+
JUDGMENT_CUMULATIVE_LLM_COST = "judgment.cumulative_llm_cost"
2324

2425
# GenAI-specific attributes (semantic conventions)
2526
GEN_AI_PROMPT = gen_ai_attributes.GEN_AI_PROMPT
@@ -34,6 +35,10 @@ class AttributeKeys:
3435
GEN_AI_REQUEST_MAX_TOKENS = gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS
3536
GEN_AI_RESPONSE_FINISH_REASONS = gen_ai_attributes.GEN_AI_RESPONSE_FINISH_REASONS
3637

38+
# GenAI-specific attributes (custom namespace)
39+
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
40+
GEN_AI_USAGE_TOTAL_COST = "gen_ai.usage.total_cost"
41+
3742

3843
class ResourceKeys:
3944
SERVICE_NAME = ResourceAttributes.SERVICE_NAME

src/judgeval/tracer/llm/__init__.py

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22
import functools
3-
import sys
43
from typing import Callable, Tuple, Optional, Any, TYPE_CHECKING
54
from functools import wraps
65
from judgeval.data.trace import TraceUsage
@@ -55,6 +54,9 @@ def wrapper(*args, **kwargs):
5554
with sync_span_context(
5655
tracer, span_name, {AttributeKeys.SPAN_TYPE: "llm"}
5756
) as span:
57+
tracer.add_agent_attributes_to_span(
58+
span, {AttributeKeys.SPAN_TYPE: "llm"}
59+
)
5860
span.set_attribute(AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs))
5961
try:
6062
response = function(*args, **kwargs)
@@ -76,6 +78,18 @@ def wrapper(*args, **kwargs):
7678
AttributeKeys.GEN_AI_USAGE_COMPLETION_TOKENS,
7779
usage.completion_tokens,
7880
)
81+
if usage.total_tokens:
82+
span.set_attribute(
83+
AttributeKeys.GEN_AI_USAGE_TOTAL_TOKENS,
84+
usage.total_tokens,
85+
)
86+
if usage.total_cost_usd:
87+
span.set_attribute(
88+
AttributeKeys.GEN_AI_USAGE_TOTAL_COST,
89+
usage.total_cost_usd,
90+
)
91+
# Add cost to cumulative context tracking
92+
tracer.add_cost_to_current_context(usage.total_cost_usd)
7993
return response
8094
except Exception as e:
8195
span.record_exception(e)
@@ -89,6 +103,9 @@ async def wrapper(*args, **kwargs):
89103
async with async_span_context(
90104
tracer, span_name, {AttributeKeys.SPAN_TYPE: "llm"}
91105
) as span:
106+
tracer.add_agent_attributes_to_span(
107+
span, {AttributeKeys.SPAN_TYPE: "llm"}
108+
)
92109
span.set_attribute(AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs))
93110
try:
94111
response = await function(*args, **kwargs)
@@ -110,6 +127,17 @@ async def wrapper(*args, **kwargs):
110127
AttributeKeys.GEN_AI_USAGE_COMPLETION_TOKENS,
111128
usage.completion_tokens,
112129
)
130+
if usage.total_tokens:
131+
span.set_attribute(
132+
AttributeKeys.GEN_AI_USAGE_TOTAL_TOKENS,
133+
usage.total_tokens,
134+
)
135+
if usage.total_cost_usd:
136+
span.set_attribute(
137+
AttributeKeys.GEN_AI_USAGE_TOTAL_COST,
138+
usage.total_cost_usd,
139+
)
140+
tracer.add_cost_to_current_context(usage.total_cost_usd)
113141
return response
114142
except Exception as e:
115143
span.record_exception(e)
@@ -160,9 +188,9 @@ async def wrapper(*args, **kwargs):
160188
)
161189

162190
assert google_genai_Client is not None, "Google GenAI client not found"
163-
assert (
164-
google_genai_AsyncClient is not None
165-
), "Google GenAI async client not found"
191+
assert google_genai_AsyncClient is not None, (
192+
"Google GenAI async client not found"
193+
)
166194
if isinstance(client, google_genai_Client):
167195
setattr(client.models, "generate_content", wrapped(original_create))
168196
elif isinstance(client, google_genai_AsyncClient):
@@ -225,9 +253,9 @@ def _get_client_config(client: ApiClient) -> tuple[str, Callable]:
225253
)
226254

227255
assert google_genai_Client is not None, "Google GenAI client not found"
228-
assert (
229-
google_genai_AsyncClient is not None
230-
), "Google GenAI async client not found"
256+
assert google_genai_AsyncClient is not None, (
257+
"Google GenAI async client not found"
258+
)
231259
if isinstance(client, google_genai_Client):
232260
return "GOOGLE_API_CALL", client.models.generate_content
233261
elif isinstance(client, google_genai_AsyncClient):
@@ -269,9 +297,9 @@ def _format_output_data(
269297
assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
270298
assert openai_ChatCompletion is not None, "OpenAI chat completion not found"
271299
assert openai_Response is not None, "OpenAI response not found"
272-
assert (
273-
openai_ParsedChatCompletion is not None
274-
), "OpenAI parsed chat completion not found"
300+
assert openai_ParsedChatCompletion is not None, (
301+
"OpenAI parsed chat completion not found"
302+
)
275303

276304
if isinstance(client, openai_OpenAI) or isinstance(client, openai_AsyncOpenAI):
277305
if isinstance(response, openai_ChatCompletion):
@@ -318,7 +346,11 @@ def _format_output_data(
318346
else 0
319347
)
320348
output0 = response.output[0]
321-
if hasattr(output0, "content") and output0.content and hasattr(output0.content, "__iter__"): # type: ignore[attr-defined]
349+
if (
350+
hasattr(output0, "content")
351+
and output0.content
352+
and hasattr(output0.content, "__iter__")
353+
): # type: ignore[attr-defined]
322354
message_content = "".join(
323355
seg.text # type: ignore[attr-defined]
324356
for seg in output0.content # type: ignore[attr-defined]
@@ -346,9 +378,23 @@ def _format_output_data(
346378
client, together_AsyncTogether
347379
):
348380
model_name = (response.model or "") if hasattr(response, "model") else ""
349-
prompt_tokens = response.usage.prompt_tokens if hasattr(response.usage, "prompt_tokens") and response.usage.prompt_tokens is not None else 0 # type: ignore[attr-defined]
350-
completion_tokens = response.usage.completion_tokens if hasattr(response.usage, "completion_tokens") and response.usage.completion_tokens is not None else 0 # type: ignore[attr-defined]
351-
message_content = response.choices[0].message.content if hasattr(response, "choices") else None # type: ignore[attr-defined]
381+
prompt_tokens = (
382+
response.usage.prompt_tokens
383+
if hasattr(response.usage, "prompt_tokens")
384+
and response.usage.prompt_tokens is not None
385+
else 0
386+
) # type: ignore[attr-defined]
387+
completion_tokens = (
388+
response.usage.completion_tokens
389+
if hasattr(response.usage, "completion_tokens")
390+
and response.usage.completion_tokens is not None
391+
else 0
392+
) # type: ignore[attr-defined]
393+
message_content = (
394+
response.choices[0].message.content
395+
if hasattr(response, "choices")
396+
else None
397+
) # type: ignore[attr-defined]
352398

353399
if model_name:
354400
return message_content, _create_usage(
@@ -366,9 +412,9 @@ def _format_output_data(
366412
)
367413

368414
assert google_genai_Client is not None, "Google GenAI client not found"
369-
assert (
370-
google_genai_AsyncClient is not None
371-
), "Google GenAI async client not found"
415+
assert google_genai_AsyncClient is not None, (
416+
"Google GenAI async client not found"
417+
)
372418
if isinstance(client, google_genai_Client) or isinstance(
373419
client, google_genai_AsyncClient
374420
):
@@ -467,9 +513,23 @@ def _format_output_data(
467513
assert groq_AsyncGroq is not None, "Groq async client not found"
468514
if isinstance(client, groq_Groq) or isinstance(client, groq_AsyncGroq):
469515
model_name = (response.model or "") if hasattr(response, "model") else ""
470-
prompt_tokens = response.usage.prompt_tokens if hasattr(response.usage, "prompt_tokens") and response.usage.prompt_tokens is not None else 0 # type: ignore[attr-defined]
471-
completion_tokens = response.usage.completion_tokens if hasattr(response.usage, "completion_tokens") and response.usage.completion_tokens is not None else 0 # type: ignore[attr-defined]
472-
message_content = response.choices[0].message.content if hasattr(response, "choices") else None # type: ignore[attr-defined]
516+
prompt_tokens = (
517+
response.usage.prompt_tokens
518+
if hasattr(response.usage, "prompt_tokens")
519+
and response.usage.prompt_tokens is not None
520+
else 0
521+
) # type: ignore[attr-defined]
522+
completion_tokens = (
523+
response.usage.completion_tokens
524+
if hasattr(response.usage, "completion_tokens")
525+
and response.usage.completion_tokens is not None
526+
else 0
527+
) # type: ignore[attr-defined]
528+
message_content = (
529+
response.choices[0].message.content
530+
if hasattr(response, "choices")
531+
else None
532+
) # type: ignore[attr-defined]
473533

474534
if model_name:
475535
return message_content, _create_usage(

src/judgeval/tracer/managers.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from contextlib import asynccontextmanager, contextmanager
44
from typing import TYPE_CHECKING, Dict, Optional
5+
from judgeval.tracer.keys import AttributeKeys
56

67
if TYPE_CHECKING:
78
from judgeval.tracer import Tracer
@@ -16,11 +17,24 @@ def sync_span_context(
1617
if span_attributes is None:
1718
span_attributes = {}
1819

19-
with tracer.get_tracer().start_as_current_span(
20-
name=name,
21-
attributes=span_attributes,
22-
) as span:
23-
yield span
20+
current_cost_context = tracer.get_current_cost_context()
21+
22+
cost_context = {"cumulative_cost": 0.0}
23+
24+
cost_token = current_cost_context.set(cost_context)
25+
26+
try:
27+
with tracer.get_tracer().start_as_current_span(
28+
name=name,
29+
attributes=span_attributes,
30+
) as span:
31+
# Set initial cumulative cost attribute
32+
span.set_attribute(AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, 0.0)
33+
yield span
34+
finally:
35+
current_cost_context.reset(cost_token)
36+
child_cost = float(cost_context.get("cumulative_cost", 0.0))
37+
tracer.add_cost_to_current_context(child_cost)
2438

2539

2640
@asynccontextmanager
@@ -30,8 +44,20 @@ async def async_span_context(
3044
if span_attributes is None:
3145
span_attributes = {}
3246

33-
with tracer.get_tracer().start_as_current_span(
34-
name=name,
35-
attributes=span_attributes,
36-
) as span:
37-
yield span
47+
current_cost_context = tracer.get_current_cost_context()
48+
49+
cost_context = {"cumulative_cost": 0.0}
50+
51+
cost_token = current_cost_context.set(cost_context)
52+
53+
try:
54+
with tracer.get_tracer().start_as_current_span(
55+
name=name,
56+
attributes=span_attributes,
57+
) as span:
58+
span.set_attribute(AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, 0.0)
59+
yield span
60+
finally:
61+
current_cost_context.reset(cost_token)
62+
child_cost = float(cost_context.get("cumulative_cost", 0.0))
63+
tracer.add_cost_to_current_context(child_cost)

0 commit comments

Comments
 (0)