Skip to content

Commit 0ba04cb

Browse files
committed
state before/after
1 parent 6a7b924 commit 0ba04cb

File tree

2 files changed

+106
-9
lines changed

2 files changed

+106
-9
lines changed

src/judgeval/tracer/__init__.py

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Type,
1717
TypeVar,
1818
overload,
19+
Literal,
1920
)
2021
from functools import partial
2122
from warnings import warn
@@ -60,9 +61,9 @@
6061
Cls = TypeVar("Cls", bound=Type)
6162
ApiClient = TypeVar("ApiClient", bound=Any)
6263

63-
_current_agent_context: ContextVar[Optional[Dict[str, str | bool]]] = ContextVar(
64-
"current_agent_context", default=None
65-
)
64+
_current_agent_context: ContextVar[
65+
Optional[Dict[str, str | bool | Dict | List | None]]
66+
] = ContextVar("current_agent_context", default=None)
6667

6768
_current_cost_context: ContextVar[Optional[Dict[str, float]]] = ContextVar(
6869
"current_cost_context", default=None
@@ -265,6 +266,51 @@ def add_agent_attributes_to_span(
265266
False # only true for entry point to agent
266267
)
267268

269+
def record_instance_state(self, record_point: Literal["before", "after"], span):
270+
current_agent_context = _current_agent_context.get()
271+
272+
if current_agent_context and current_agent_context.get("track_state"):
273+
instance_untyped = current_agent_context.get("instance", None)
274+
instance = (
275+
instance_untyped if isinstance(instance_untyped, object) else None
276+
)
277+
track_attributes_untyped = current_agent_context.get(
278+
"track_attributes", None
279+
)
280+
track_attributes: List | None = (
281+
track_attributes_untyped
282+
if isinstance(track_attributes_untyped, list)
283+
else None
284+
)
285+
field_mappings_untyped = current_agent_context.get("field_mappings", {})
286+
field_mappings: Dict[str, str] = (
287+
field_mappings_untyped
288+
if isinstance(field_mappings_untyped, dict)
289+
else {}
290+
)
291+
if track_attributes is not None:
292+
attributes = {
293+
field_mappings.get(attr, attr): getattr(instance, attr, None)
294+
for attr in track_attributes
295+
}
296+
else:
297+
attributes = {
298+
field_mappings.get(k, k): v
299+
for k, v in instance.__dict__.items()
300+
if not k.startswith("_")
301+
}
302+
303+
if record_point == "before":
304+
span.set_attribute(
305+
AttributeKeys.JUDGMENT_STATE_BEFORE,
306+
safe_serialize(attributes),
307+
)
308+
else:
309+
span.set_attribute(
310+
AttributeKeys.JUDGMENT_STATE_AFTER,
311+
safe_serialize(attributes),
312+
)
313+
268314
def _wrap_sync(
269315
self, f: Callable, name: Optional[str], attributes: Optional[Dict[str, Any]]
270316
):
@@ -273,6 +319,7 @@ def wrapper(*args, **kwargs):
273319
n = name or f.__qualname__
274320
with sync_span_context(self, n, attributes) as span:
275321
self.add_agent_attributes_to_span(span, attributes)
322+
self.record_instance_state("before", span)
276323
try:
277324
span.set_attribute(
278325
AttributeKeys.JUDGMENT_INPUT,
@@ -289,6 +336,7 @@ def wrapper(*args, **kwargs):
289336
AttributeKeys.JUDGMENT_OUTPUT,
290337
safe_serialize(result),
291338
)
339+
self.record_instance_state("after", span)
292340
return result
293341

294342
return wrapper
@@ -301,6 +349,7 @@ async def wrapper(*args, **kwargs):
301349
n = name or f.__qualname__
302350
with sync_span_context(self, n, attributes) as span:
303351
self.add_agent_attributes_to_span(span, attributes)
352+
self.record_instance_state("before", span)
304353
try:
305354
span.set_attribute(
306355
AttributeKeys.JUDGMENT_INPUT,
@@ -316,6 +365,7 @@ async def wrapper(*args, **kwargs):
316365
AttributeKeys.JUDGMENT_OUTPUT,
317366
safe_serialize(result),
318367
)
368+
self.record_instance_state("after", span)
319369
return result
320370

321371
return wrapper
@@ -348,15 +398,38 @@ def observe(
348398
return self._wrap_sync(func, name, attributes)
349399

350400
@overload
351-
def agent(self, func: C, /, *, identifier: str | None = None) -> C: ...
401+
def agent(
402+
self,
403+
func: C,
404+
/,
405+
*,
406+
identifier: str | None = None,
407+
track_state: bool = False,
408+
track_attributes: List[str] | None = None,
409+
field_mappings: Dict[str, str] = {},
410+
) -> C: ...
352411

353412
@overload
354413
def agent(
355-
self, func: None = None, /, *, identifier: str | None = None
414+
self,
415+
func: None = None,
416+
/,
417+
*,
418+
identifier: str | None = None,
419+
track_state: bool = False,
420+
track_attributes: List[str] | None = None,
421+
field_mappings: Dict[str, str] = {},
356422
) -> Callable[[C], C]: ...
357423

358424
def agent(
359-
self, func: Callable | None = None, /, *, identifier: str | None = None
425+
self,
426+
func: Callable | None = None,
427+
/,
428+
*,
429+
identifier: str | None = None,
430+
track_state: bool = False,
431+
track_attributes: List[str] | None = None,
432+
field_mappings: Dict[str, str] = {},
360433
) -> Callable | None:
361434
"""
362435
Agent decorator that creates an agent ID and propagates it to child spans.
@@ -382,7 +455,13 @@ def my_agent_method(self):
382455
identifier: Name of the instance attribute to use as the instance name
383456
"""
384457
if func is None:
385-
return partial(self.agent, identifier=identifier)
458+
return partial(
459+
self.agent,
460+
identifier=identifier,
461+
track_state=track_state,
462+
track_attributes=track_attributes,
463+
field_mappings=field_mappings,
464+
)
386465

387466
if not self.enable_monitoring:
388467
return func
@@ -403,9 +482,17 @@ async def async_wrapper(*args, **kwargs):
403482
if class_name:
404483
agent_context["class_name"] = class_name
405484

406-
if identifier and args and hasattr(args[0], identifier):
485+
agent_context["track_state"] = track_state
486+
agent_context["track_attributes"] = track_attributes
487+
agent_context["field_mappings"] = field_mappings
488+
489+
instance = args[0] if args else None
490+
491+
agent_context["instance"] = instance
492+
493+
if identifier and instance and hasattr(instance, identifier):
407494
try:
408-
instance_name = str(getattr(args[0], identifier))
495+
instance_name = str(getattr(instance, identifier))
409496
agent_context["instance_name"] = instance_name
410497
except Exception:
411498
pass
@@ -431,6 +518,14 @@ def sync_wrapper(*args, **kwargs):
431518
if class_name:
432519
agent_context["class_name"] = class_name
433520

521+
agent_context["track_state"] = track_state
522+
agent_context["track_attributes"] = track_attributes
523+
agent_context["field_mappings"] = field_mappings
524+
525+
instance = args[0] if args else None
526+
527+
agent_context["instance"] = instance
528+
434529
if identifier and args and hasattr(args[0], identifier):
435530
try:
436531
instance_name = str(getattr(args[0], identifier))

src/judgeval/tracer/keys.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class AttributeKeys:
2121
JUDGMENT_AGENT_INSTANCE_NAME = "judgment.agent_instance_name"
2222
JUDGMENT_IS_AGENT_ENTRY_POINT = "judgment.is_agent_entry_point"
2323
JUDGMENT_CUMULATIVE_LLM_COST = "judgment.cumulative_llm_cost"
24+
JUDGMENT_STATE_BEFORE = "judgment.state_before"
25+
JUDGMENT_STATE_AFTER = "judgment.state_after"
2426

2527
# GenAI-specific attributes (semantic conventions)
2628
GEN_AI_PROMPT = gen_ai_attributes.GEN_AI_PROMPT

0 commit comments

Comments
 (0)