@@ -383,6 +383,24 @@ def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str,
383
383
raw_entries = [entry .to_dict () for entry in self .entries ]
384
384
condensed_entries = self .condense_trace (raw_entries )
385
385
386
+ # Calculate total token counts from LLM API calls
387
+ total_prompt_tokens = 0
388
+ total_completion_tokens = 0
389
+ total_tokens = 0
390
+
391
+ for entry in condensed_entries :
392
+ if entry .get ("span_type" ) == "llm" and isinstance (entry .get ("output" ), dict ):
393
+ usage = entry ["output" ].get ("usage" , {})
394
+ # Handle OpenAI/Together format
395
+ if "prompt_tokens" in usage :
396
+ total_prompt_tokens += usage .get ("prompt_tokens" , 0 )
397
+ total_completion_tokens += usage .get ("completion_tokens" , 0 )
398
+ # Handle Anthropic format
399
+ elif "input_tokens" in usage :
400
+ total_prompt_tokens += usage .get ("input_tokens" , 0 )
401
+ total_completion_tokens += usage .get ("output_tokens" , 0 )
402
+ total_tokens += usage .get ("total_tokens" , 0 )
403
+
386
404
# Create trace document
387
405
trace_data = {
388
406
"trace_id" : self .trace_id ,
@@ -392,10 +410,10 @@ def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str,
392
410
"created_at" : datetime .fromtimestamp (self .start_time ).isoformat (),
393
411
"duration" : total_duration ,
394
412
"token_counts" : {
395
- "prompt_tokens" : 0 , # Dummy value
396
- "completion_tokens" : 0 , # Dummy value
397
- "total_tokens" : 0 , # Dummy value
398
- }, # TODO: Add token counts
413
+ "prompt_tokens" : total_prompt_tokens ,
414
+ "completion_tokens" : total_completion_tokens ,
415
+ "total_tokens" : total_tokens ,
416
+ },
399
417
"entries" : condensed_entries ,
400
418
"empty_save" : empty_save ,
401
419
"overwrite" : overwrite
0 commit comments