|
2 | 2 | import os
|
3 | 3 | import time
|
4 | 4 | import asyncio
|
5 |
| -from typing import List |
| 5 | +from typing import Dict |
6 | 6 | import pytest
|
7 |
| -import re |
8 |
| -import sys |
9 |
| -from io import StringIO |
10 |
| -import json |
11 |
| -import inspect # Added for function signature inspection |
12 |
| -import pytest_asyncio # For async fixtures if needed later |
13 | 7 |
|
14 | 8 | # Third-party imports
|
15 | 9 | from openai import OpenAI, AsyncOpenAI
|
|
18 | 12 | from google import genai
|
19 | 13 |
|
20 | 14 | # Local imports
|
21 |
| -from judgeval.tracer import Tracer, wrap, TraceClient, TraceManagerClient |
22 |
| -from judgeval.constants import APIScorer |
| 15 | +from judgeval.tracer import Tracer, wrap, TraceManagerClient |
23 | 16 | from judgeval.scorers import FaithfulnessScorer, AnswerRelevancyScorer
|
24 | 17 | from judgeval.data import Example
|
25 |
| -# Import the utility functions from the new location |
26 |
| -from e2etests.utils import validate_trace_token_counts, validate_trace_tokens |
27 | 18 |
|
28 | 19 | # Initialize the tracer and clients
|
29 | 20 | # Ensure relevant API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY, TOGETHER_API_KEY, GOOGLE_API_KEY) are set
|
|
59 | 50 | else:
|
60 | 51 | print("Warning: GOOGLE_API_KEY not found. Skipping Google tests.")
|
61 | 52 |
|
# Helper function
def validate_trace_token_counts(trace_client) -> Dict[str, int]:
    """
    Validates token counts from trace spans and performs assertions.

    Walks the trace's spans, sums usage from known LLM API call spans,
    and asserts that prompt + completion counts are positive and add up
    to the reported total.

    Args:
        trace_client: The trace client instance containing trace spans
            (must expose a ``trace_spans`` iterable).

    Returns:
        Dict with calculated token counts (prompt_tokens, completion_tokens, total_tokens)

    Raises:
        AssertionError: If token count validations fail
    """
    if not trace_client:
        pytest.fail("Failed to get trace client for token count validation")

    # Get spans from the trace client
    trace_spans = trace_client.trace_spans

    # Manually calculate token counts from trace spans
    manual_prompt_tokens = 0
    manual_completion_tokens = 0
    manual_total_tokens = 0

    # Known LLM API call function names
    llm_span_names = {"OPENAI_API_CALL", "ANTHROPIC_API_CALL", "TOGETHER_API_CALL", "GOOGLE_API_CALL"}

    for span in trace_spans:
        if span.span_type == "llm" and span.function in llm_span_names:
            usage = span.usage
            # An "info" entry marks a placeholder rather than real usage data
            # (assumes usage supports both `in` and attribute access — TODO confirm).
            if usage and "info" not in usage:
                # Guard against providers that report a missing count as None
                # so accumulation never raises on `int += None`.
                manual_prompt_tokens += usage.prompt_tokens or 0
                manual_completion_tokens += usage.completion_tokens or 0
                manual_total_tokens += usage.total_tokens or 0

    assert manual_prompt_tokens > 0, "Prompt tokens should be counted"
    assert manual_completion_tokens > 0, "Completion tokens should be counted"
    assert manual_total_tokens > 0, "Total tokens should be counted"
    assert manual_total_tokens == (manual_prompt_tokens + manual_completion_tokens), \
        "Total tokens should equal prompt + completion"

    return {
        "prompt_tokens": manual_prompt_tokens,
        "completion_tokens": manual_completion_tokens,
        "total_tokens": manual_total_tokens
    }
| 107 | + |
# Helper function
def validate_trace_tokens(trace, fail_on_missing=True):
    """
    Helper function to validate token counts in a trace

    Delegates the actual counting and assertions to
    ``validate_trace_token_counts`` and logs the result.

    Args:
        trace: The trace client to validate
        fail_on_missing: Whether to fail the test if no trace is available

    Returns:
        The token counts if validation succeeded
    """
    if not trace:
        print("Warning: Could not get current trace to perform assertions.")
        if fail_on_missing:
            pytest.fail("Failed to get current trace within decorated function.")
        return None

    print("\nAttempting assertions on current trace state (before decorator save)...")

    # Delegate token-count validation (raises AssertionError on failure)
    token_counts = validate_trace_token_counts(trace)

    prompt = token_counts['prompt_tokens']
    completion = token_counts['completion_tokens']
    total = token_counts['total_tokens']
    print(f"Calculated token counts: P={prompt}, C={completion}, T={total}")

    return token_counts
62 | 134 |
|
63 | 135 | # --- Test Functions ---
|
64 | 136 |
|
|
0 commit comments