16
16
Type ,
17
17
TypeVar ,
18
18
overload ,
19
+ Literal ,
19
20
)
20
21
from functools import partial
21
22
from warnings import warn
60
61
Cls = TypeVar ("Cls" , bound = Type )
61
62
ApiClient = TypeVar ("ApiClient" , bound = Any )
62
63
63
- _current_agent_context : ContextVar [Optional [ Dict [ str , str | bool ]]] = ContextVar (
64
- "current_agent_context" , default = None
65
- )
64
+ _current_agent_context : ContextVar [
65
+ Optional [ Dict [ str , str | bool | Dict | List | None ]]
66
+ ] = ContextVar ( "current_agent_context" , default = None )
66
67
67
68
_current_cost_context : ContextVar [Optional [Dict [str , float ]]] = ContextVar (
68
69
"current_cost_context" , default = None
@@ -265,6 +266,51 @@ def add_agent_attributes_to_span(
265
266
False # only true for entry point to agent
266
267
)
267
268
269
+ def record_instance_state (self , record_point : Literal ["before" , "after" ], span ):
270
+ current_agent_context = _current_agent_context .get ()
271
+
272
+ if current_agent_context and current_agent_context .get ("track_state" ):
273
+ instance_untyped = current_agent_context .get ("instance" , None )
274
+ instance = (
275
+ instance_untyped if isinstance (instance_untyped , object ) else None
276
+ )
277
+ track_attributes_untyped = current_agent_context .get (
278
+ "track_attributes" , None
279
+ )
280
+ track_attributes : List | None = (
281
+ track_attributes_untyped
282
+ if isinstance (track_attributes_untyped , list )
283
+ else None
284
+ )
285
+ field_mappings_untyped = current_agent_context .get ("field_mappings" , {})
286
+ field_mappings : Dict [str , str ] = (
287
+ field_mappings_untyped
288
+ if isinstance (field_mappings_untyped , dict )
289
+ else {}
290
+ )
291
+ if track_attributes is not None :
292
+ attributes = {
293
+ field_mappings .get (attr , attr ): getattr (instance , attr , None )
294
+ for attr in track_attributes
295
+ }
296
+ else :
297
+ attributes = {
298
+ field_mappings .get (k , k ): v
299
+ for k , v in instance .__dict__ .items ()
300
+ if not k .startswith ("_" )
301
+ }
302
+
303
+ if record_point == "before" :
304
+ span .set_attribute (
305
+ AttributeKeys .JUDGMENT_STATE_BEFORE ,
306
+ safe_serialize (attributes ),
307
+ )
308
+ else :
309
+ span .set_attribute (
310
+ AttributeKeys .JUDGMENT_STATE_AFTER ,
311
+ safe_serialize (attributes ),
312
+ )
313
+
268
314
def _wrap_sync (
269
315
self , f : Callable , name : Optional [str ], attributes : Optional [Dict [str , Any ]]
270
316
):
@@ -273,6 +319,7 @@ def wrapper(*args, **kwargs):
273
319
n = name or f .__qualname__
274
320
with sync_span_context (self , n , attributes ) as span :
275
321
self .add_agent_attributes_to_span (span , attributes )
322
+ self .record_instance_state ("before" , span )
276
323
try :
277
324
span .set_attribute (
278
325
AttributeKeys .JUDGMENT_INPUT ,
@@ -289,6 +336,7 @@ def wrapper(*args, **kwargs):
289
336
AttributeKeys .JUDGMENT_OUTPUT ,
290
337
safe_serialize (result ),
291
338
)
339
+ self .record_instance_state ("after" , span )
292
340
return result
293
341
294
342
return wrapper
@@ -301,6 +349,7 @@ async def wrapper(*args, **kwargs):
301
349
n = name or f .__qualname__
302
350
with sync_span_context (self , n , attributes ) as span :
303
351
self .add_agent_attributes_to_span (span , attributes )
352
+ self .record_instance_state ("before" , span )
304
353
try :
305
354
span .set_attribute (
306
355
AttributeKeys .JUDGMENT_INPUT ,
@@ -316,6 +365,7 @@ async def wrapper(*args, **kwargs):
316
365
AttributeKeys .JUDGMENT_OUTPUT ,
317
366
safe_serialize (result ),
318
367
)
368
+ self .record_instance_state ("after" , span )
319
369
return result
320
370
321
371
return wrapper
@@ -348,15 +398,38 @@ def observe(
348
398
return self ._wrap_sync (func , name , attributes )
349
399
350
400
@overload
351
- def agent (self , func : C , / , * , identifier : str | None = None ) -> C : ...
401
+ def agent (
402
+ self ,
403
+ func : C ,
404
+ / ,
405
+ * ,
406
+ identifier : str | None = None ,
407
+ track_state : bool = False ,
408
+ track_attributes : List [str ] | None = None ,
409
+ field_mappings : Dict [str , str ] = {},
410
+ ) -> C : ...
352
411
353
412
@overload
354
413
def agent (
355
- self , func : None = None , / , * , identifier : str | None = None
414
+ self ,
415
+ func : None = None ,
416
+ / ,
417
+ * ,
418
+ identifier : str | None = None ,
419
+ track_state : bool = False ,
420
+ track_attributes : List [str ] | None = None ,
421
+ field_mappings : Dict [str , str ] = {},
356
422
) -> Callable [[C ], C ]: ...
357
423
358
424
def agent (
359
- self , func : Callable | None = None , / , * , identifier : str | None = None
425
+ self ,
426
+ func : Callable | None = None ,
427
+ / ,
428
+ * ,
429
+ identifier : str | None = None ,
430
+ track_state : bool = False ,
431
+ track_attributes : List [str ] | None = None ,
432
+ field_mappings : Dict [str , str ] = {},
360
433
) -> Callable | None :
361
434
"""
362
435
Agent decorator that creates an agent ID and propagates it to child spans.
@@ -382,7 +455,13 @@ def my_agent_method(self):
382
455
identifier: Name of the instance attribute to use as the instance name
383
456
"""
384
457
if func is None :
385
- return partial (self .agent , identifier = identifier )
458
+ return partial (
459
+ self .agent ,
460
+ identifier = identifier ,
461
+ track_state = track_state ,
462
+ track_attributes = track_attributes ,
463
+ field_mappings = field_mappings ,
464
+ )
386
465
387
466
if not self .enable_monitoring :
388
467
return func
@@ -403,9 +482,17 @@ async def async_wrapper(*args, **kwargs):
403
482
if class_name :
404
483
agent_context ["class_name" ] = class_name
405
484
406
- if identifier and args and hasattr (args [0 ], identifier ):
485
+ agent_context ["track_state" ] = track_state
486
+ agent_context ["track_attributes" ] = track_attributes
487
+ agent_context ["field_mappings" ] = field_mappings
488
+
489
+ instance = args [0 ] if args else None
490
+
491
+ agent_context ["instance" ] = instance
492
+
493
+ if identifier and instance and hasattr (instance , identifier ):
407
494
try :
408
- instance_name = str (getattr (args [ 0 ] , identifier ))
495
+ instance_name = str (getattr (instance , identifier ))
409
496
agent_context ["instance_name" ] = instance_name
410
497
except Exception :
411
498
pass
@@ -431,6 +518,14 @@ def sync_wrapper(*args, **kwargs):
431
518
if class_name :
432
519
agent_context ["class_name" ] = class_name
433
520
521
+ agent_context ["track_state" ] = track_state
522
+ agent_context ["track_attributes" ] = track_attributes
523
+ agent_context ["field_mappings" ] = field_mappings
524
+
525
+ instance = args [0 ] if args else None
526
+
527
+ agent_context ["instance" ] = instance
528
+
434
529
if identifier and args and hasattr (args [0 ], identifier ):
435
530
try :
436
531
instance_name = str (getattr (args [0 ], identifier ))
0 commit comments