20
20
)
21
21
from functools import partial
22
22
from warnings import warn
23
+ from contextvars import Token
23
24
24
25
from opentelemetry .context import Context
25
26
from opentelemetry .sdk .trace import SpanProcessor , TracerProvider
@@ -232,9 +233,7 @@ def add_cost_to_current_context(self, cost: float) -> None:
232
233
AttributeKeys .JUDGMENT_CUMULATIVE_LLM_COST , new_cumulative_cost
233
234
)
234
235
235
- def add_agent_attributes_to_span (
236
- self , span , attributes : Optional [Dict [str , Any ]] = None
237
- ):
236
+ def add_agent_attributes_to_span (self , span ):
238
237
"""Add agent ID, class name, and instance name to span if they exist in context"""
239
238
current_agent_context = _current_agent_context .get ()
240
239
if current_agent_context :
@@ -318,7 +317,7 @@ def _wrap_sync(
318
317
def wrapper (* args , ** kwargs ):
319
318
n = name or f .__qualname__
320
319
with sync_span_context (self , n , attributes ) as span :
321
- self .add_agent_attributes_to_span (span , attributes )
320
+ self .add_agent_attributes_to_span (span )
322
321
self .record_instance_state ("before" , span )
323
322
try :
324
323
span .set_attribute (
@@ -348,7 +347,7 @@ def _wrap_async(
348
347
async def wrapper (* args , ** kwargs ):
349
348
n = name or f .__qualname__
350
349
with sync_span_context (self , n , attributes ) as span :
351
- self .add_agent_attributes_to_span (span , attributes )
350
+ self .add_agent_attributes_to_span (span )
352
351
self .record_instance_state ("before" , span )
353
352
try :
354
353
span .set_attribute (
@@ -472,37 +471,54 @@ def my_agent_method(self):
472
471
if len (parts ) >= 2 :
473
472
class_name = parts [- 2 ]
474
473
474
+ def _create_agent_context (args ) -> Token :
475
+ """Create agent context and return token"""
476
+ agent_id = str (uuid .uuid4 ())
477
+ agent_context : Dict [str , str | bool | Dict | List | Any ] = {
478
+ "agent_id" : agent_id
479
+ }
480
+
481
+ if class_name :
482
+ agent_context ["class_name" ] = class_name
483
+
484
+ agent_context ["track_state" ] = track_state
485
+ agent_context ["track_attributes" ] = track_attributes
486
+ agent_context ["field_mappings" ] = field_mappings
487
+
488
+ instance = args [0 ] if args else None
489
+ agent_context ["instance" ] = instance
490
+
491
+ if identifier :
492
+ if not class_name or not instance or not isinstance (instance , object ):
493
+ raise Exception (
494
+ "'identifier' is set but no class name or instance is available. 'identifier' can only be specified when using the agent() decorator on a class method."
495
+ )
496
+ if (
497
+ instance
498
+ and hasattr (instance , identifier )
499
+ and not callable (getattr (instance , identifier ))
500
+ ):
501
+ instance_name = str (getattr (instance , identifier ))
502
+ agent_context ["instance_name" ] = instance_name
503
+ else :
504
+ raise Exception (
505
+ f"Attribute { identifier } does not exist for { class_name } . Check your agent() decorator."
506
+ )
507
+
508
+ current_agent_context = _current_agent_context .get ()
509
+ if current_agent_context and "agent_id" in current_agent_context :
510
+ agent_context ["parent_agent_id" ] = current_agent_context ["agent_id" ]
511
+
512
+ agent_context ["is_agent_entry_point" ] = True
513
+ token = _current_agent_context .set (agent_context )
514
+ return token
515
+
475
516
def _wrap_with_agent_context (f : Callable ):
476
517
if inspect .iscoroutinefunction (f ):
477
518
478
519
@functools .wraps (f )
479
520
async def async_wrapper (* args , ** kwargs ):
480
- agent_id = str (uuid .uuid4 ())
481
- agent_context = {"agent_id" : agent_id }
482
- if class_name :
483
- agent_context ["class_name" ] = class_name
484
-
485
- agent_context ["track_state" ] = track_state
486
- agent_context ["track_attributes" ] = track_attributes
487
- agent_context ["field_mappings" ] = field_mappings
488
-
489
- instance = args [0 ] if args else None
490
-
491
- agent_context ["instance" ] = instance
492
-
493
- if identifier and instance and hasattr (instance , identifier ):
494
- try :
495
- instance_name = str (getattr (instance , identifier ))
496
- agent_context ["instance_name" ] = instance_name
497
- except Exception :
498
- pass
499
- current_agent_context = _current_agent_context .get ()
500
- if current_agent_context and "agent_id" in current_agent_context :
501
- agent_context ["parent_agent_id" ] = current_agent_context [
502
- "agent_id"
503
- ]
504
- agent_context ["is_agent_entry_point" ] = True
505
- token = _current_agent_context .set (agent_context )
521
+ token = _create_agent_context (args )
506
522
try :
507
523
return await f (* args , ** kwargs )
508
524
finally :
@@ -513,32 +529,7 @@ async def async_wrapper(*args, **kwargs):
513
529
514
530
@functools .wraps (f )
515
531
def sync_wrapper (* args , ** kwargs ):
516
- agent_id = str (uuid .uuid4 ())
517
- agent_context = {"agent_id" : agent_id }
518
- if class_name :
519
- agent_context ["class_name" ] = class_name
520
-
521
- agent_context ["track_state" ] = track_state
522
- agent_context ["track_attributes" ] = track_attributes
523
- agent_context ["field_mappings" ] = field_mappings
524
-
525
- instance = args [0 ] if args else None
526
-
527
- agent_context ["instance" ] = instance
528
-
529
- if identifier and args and hasattr (args [0 ], identifier ):
530
- try :
531
- instance_name = str (getattr (args [0 ], identifier ))
532
- agent_context ["instance_name" ] = instance_name
533
- except Exception :
534
- pass
535
- current_agent_context = _current_agent_context .get ()
536
- if current_agent_context and "agent_id" in current_agent_context :
537
- agent_context ["parent_agent_id" ] = current_agent_context [
538
- "agent_id"
539
- ]
540
- agent_context ["is_agent_entry_point" ] = True
541
- token = _current_agent_context .set (agent_context )
532
+ token = _create_agent_context (args )
542
533
try :
543
534
return f (* args , ** kwargs )
544
535
finally :
0 commit comments