Skip to content

Commit 6b09627

Browse files
committed
fixes
1 parent 0ba04cb commit 6b09627

File tree

2 files changed

+50
-63
lines changed

2 files changed

+50
-63
lines changed

src/judgeval/tracer/__init__.py

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from functools import partial
2222
from warnings import warn
23+
from contextvars import Token
2324

2425
from opentelemetry.context import Context
2526
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
@@ -232,9 +233,7 @@ def add_cost_to_current_context(self, cost: float) -> None:
232233
AttributeKeys.JUDGMENT_CUMULATIVE_LLM_COST, new_cumulative_cost
233234
)
234235

235-
def add_agent_attributes_to_span(
236-
self, span, attributes: Optional[Dict[str, Any]] = None
237-
):
236+
def add_agent_attributes_to_span(self, span):
238237
"""Add agent ID, class name, and instance name to span if they exist in context"""
239238
current_agent_context = _current_agent_context.get()
240239
if current_agent_context:
@@ -318,7 +317,7 @@ def _wrap_sync(
318317
def wrapper(*args, **kwargs):
319318
n = name or f.__qualname__
320319
with sync_span_context(self, n, attributes) as span:
321-
self.add_agent_attributes_to_span(span, attributes)
320+
self.add_agent_attributes_to_span(span)
322321
self.record_instance_state("before", span)
323322
try:
324323
span.set_attribute(
@@ -348,7 +347,7 @@ def _wrap_async(
348347
async def wrapper(*args, **kwargs):
349348
n = name or f.__qualname__
350349
with sync_span_context(self, n, attributes) as span:
351-
self.add_agent_attributes_to_span(span, attributes)
350+
self.add_agent_attributes_to_span(span)
352351
self.record_instance_state("before", span)
353352
try:
354353
span.set_attribute(
@@ -472,37 +471,54 @@ def my_agent_method(self):
472471
if len(parts) >= 2:
473472
class_name = parts[-2]
474473

474+
def _create_agent_context(args) -> Token:
475+
"""Create agent context and return token"""
476+
agent_id = str(uuid.uuid4())
477+
agent_context: Dict[str, str | bool | Dict | List | Any] = {
478+
"agent_id": agent_id
479+
}
480+
481+
if class_name:
482+
agent_context["class_name"] = class_name
483+
484+
agent_context["track_state"] = track_state
485+
agent_context["track_attributes"] = track_attributes
486+
agent_context["field_mappings"] = field_mappings
487+
488+
instance = args[0] if args else None
489+
agent_context["instance"] = instance
490+
491+
if identifier:
492+
if not class_name or not instance or not isinstance(instance, object):
493+
raise Exception(
494+
"'identifier' is set but no class name or instance is available. 'identifier' can only be specified when using the agent() decorator on a class method."
495+
)
496+
if (
497+
instance
498+
and hasattr(instance, identifier)
499+
and not callable(getattr(instance, identifier))
500+
):
501+
instance_name = str(getattr(instance, identifier))
502+
agent_context["instance_name"] = instance_name
503+
else:
504+
raise Exception(
505+
f"Attribute {identifier} does not exist for {class_name}. Check your agent() decorator."
506+
)
507+
508+
current_agent_context = _current_agent_context.get()
509+
if current_agent_context and "agent_id" in current_agent_context:
510+
agent_context["parent_agent_id"] = current_agent_context["agent_id"]
511+
512+
agent_context["is_agent_entry_point"] = True
513+
token = _current_agent_context.set(agent_context)
514+
return token
515+
475516
def _wrap_with_agent_context(f: Callable):
476517
if inspect.iscoroutinefunction(f):
477518

478519
@functools.wraps(f)
479520
async def async_wrapper(*args, **kwargs):
480-
agent_id = str(uuid.uuid4())
481-
agent_context = {"agent_id": agent_id}
482-
if class_name:
483-
agent_context["class_name"] = class_name
484-
485-
agent_context["track_state"] = track_state
486-
agent_context["track_attributes"] = track_attributes
487-
agent_context["field_mappings"] = field_mappings
488-
489-
instance = args[0] if args else None
490-
491-
agent_context["instance"] = instance
492-
493-
if identifier and instance and hasattr(instance, identifier):
494-
try:
495-
instance_name = str(getattr(instance, identifier))
496-
agent_context["instance_name"] = instance_name
497-
except Exception:
498-
pass
499-
current_agent_context = _current_agent_context.get()
500-
if current_agent_context and "agent_id" in current_agent_context:
501-
agent_context["parent_agent_id"] = current_agent_context[
502-
"agent_id"
503-
]
504-
agent_context["is_agent_entry_point"] = True
505-
token = _current_agent_context.set(agent_context)
521+
token = _create_agent_context(args)
506522
try:
507523
return await f(*args, **kwargs)
508524
finally:
@@ -513,32 +529,7 @@ async def async_wrapper(*args, **kwargs):
513529

514530
@functools.wraps(f)
515531
def sync_wrapper(*args, **kwargs):
516-
agent_id = str(uuid.uuid4())
517-
agent_context = {"agent_id": agent_id}
518-
if class_name:
519-
agent_context["class_name"] = class_name
520-
521-
agent_context["track_state"] = track_state
522-
agent_context["track_attributes"] = track_attributes
523-
agent_context["field_mappings"] = field_mappings
524-
525-
instance = args[0] if args else None
526-
527-
agent_context["instance"] = instance
528-
529-
if identifier and args and hasattr(args[0], identifier):
530-
try:
531-
instance_name = str(getattr(args[0], identifier))
532-
agent_context["instance_name"] = instance_name
533-
except Exception:
534-
pass
535-
current_agent_context = _current_agent_context.get()
536-
if current_agent_context and "agent_id" in current_agent_context:
537-
agent_context["parent_agent_id"] = current_agent_context[
538-
"agent_id"
539-
]
540-
agent_context["is_agent_entry_point"] = True
541-
token = _current_agent_context.set(agent_context)
532+
token = _create_agent_context(args)
542533
try:
543534
return f(*args, **kwargs)
544535
finally:

src/judgeval/tracer/llm/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def wrapper(*args, **kwargs):
5454
with sync_span_context(
5555
tracer, span_name, {AttributeKeys.SPAN_TYPE: "llm"}
5656
) as span:
57-
tracer.add_agent_attributes_to_span(
58-
span, {AttributeKeys.SPAN_TYPE: "llm"}
59-
)
57+
tracer.add_agent_attributes_to_span(span)
6058
span.set_attribute(AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs))
6159
try:
6260
response = function(*args, **kwargs)
@@ -103,9 +101,7 @@ async def wrapper(*args, **kwargs):
103101
async with async_span_context(
104102
tracer, span_name, {AttributeKeys.SPAN_TYPE: "llm"}
105103
) as span:
106-
tracer.add_agent_attributes_to_span(
107-
span, {AttributeKeys.SPAN_TYPE: "llm"}
108-
)
104+
tracer.add_agent_attributes_to_span(span)
109105
span.set_attribute(AttributeKeys.GEN_AI_PROMPT, safe_serialize(kwargs))
110106
try:
111107
response = await function(*args, **kwargs)

0 commit comments

Comments
 (0)