diff --git a/src/demo/simple_trace.py b/src/demo/simple_trace.py
index 11afc069..c6187596 100644
--- a/src/demo/simple_trace.py
+++ b/src/demo/simple_trace.py
@@ -38,12 +38,12 @@ async def gather_information(city: str):
     weather = await get_weather(city)
     attractions = await get_attractions(city)
 
-    # judgment.async_evaluate(
-    #     scorers=[AnswerRelevancyScorer(threshold=0.5)],
-    #     input="What is the weather in Paris?",
-    #     actual_output=weather,
-    #     model="gpt-4",
-    # )
+    judgment.async_evaluate(
+        scorers=[AnswerRelevancyScorer(threshold=0.5)],
+        input="What is the weather in Paris?",
+        actual_output=weather,
+        model="gpt-4",
+    )
 
     return {
         "weather": weather,
diff --git a/src/e2etests/test_judgee_traces_update.py b/src/e2etests/test_judgee_traces_update.py
index 118f55c7..52a3769f 100644
--- a/src/e2etests/test_judgee_traces_update.py
+++ b/src/e2etests/test_judgee_traces_update.py
@@ -181,7 +181,7 @@ async def test_trace_save_increment(client, cleanup_traces):
         "project_name": "test_project",
         "trace_id": trace_id,
         "created_at": datetime.fromtimestamp(timestamp).isoformat(),
-        "entries": [
+        "trace_spans": [
             {
                 "timestamp": datetime.fromtimestamp(timestamp).isoformat(),
                 "type": "span",
@@ -272,7 +272,7 @@ async def save_trace(index):
         "project_name": "test_project",
         "trace_id": trace_id,
         "created_at": datetime.fromtimestamp(timestamp).isoformat(),
-        "entries": [
+        "trace_spans": [
             {
                 "timestamp": datetime.fromtimestamp(timestamp).isoformat(),
                 "type": "span",
@@ -354,7 +354,7 @@ async def test_failed_trace_counting(client):
         "project_name": "test_project",
         "trace_id": str(uuid4()),
         "created_at": str(timestamp),  # Convert to string
-        # Missing entries, which should cause a validation error
+        # Missing trace_spans, which should cause a validation error
         "duration": 0.1,
         "token_counts": {"total": 10},
         "empty_save": False,
@@ -463,7 +463,7 @@ async def test_burst_request_handling(client):
         "project_name": "test_project",
         "trace_id": trace_id,
         "created_at": datetime.fromtimestamp(timestamp).isoformat(),
-        "entries": [
+        "trace_spans": [
             {
                 "timestamp": datetime.fromtimestamp(timestamp).isoformat(),
                 "type": "span",
@@ -488,8 +488,8 @@ async def save_trace():
         # Create a unique trace ID for each request
         local_trace_data = trace_data.copy()
         local_trace_data["trace_id"] = str(uuid4())
-        local_trace_data["entries"][0]["span_id"] = str(uuid4())
-        local_trace_data["entries"][0]["trace_id"] = local_trace_data["trace_id"]
+        local_trace_data["trace_spans"][0]["span_id"] = str(uuid4())
+        local_trace_data["trace_spans"][0]["trace_id"] = local_trace_data["trace_id"]
 
         response = await client.post(
             f"{SERVER_URL}/traces/save/",
diff --git a/src/e2etests/test_tracer.py b/src/e2etests/test_tracer.py
index 1f9cbf65..cb3e2059 100644
--- a/src/e2etests/test_tracer.py
+++ b/src/e2etests/test_tracer.py
@@ -590,17 +590,6 @@ async def run_async_stream(prompt):
     return result
 
 # --- END NEW TESTS ---
-
-# Helper function to print trace hierarchy
-def print_trace_hierarchy(entries):
-    """Print a hierarchical representation of the trace for debugging."""
-    # First, organize entries by parent_span_id
-    entries_by_parent = {}
-    for entry in entries:
-        parent_id = entry["parent_span_id"]
-        if parent_id not in entries_by_parent:
-            entries_by_parent[parent_id] = []
-        entries_by_parent[parent_id].append(entry)
 
 
 # --- NEW COMPREHENSIVE TOKEN COUNTING TEST ---
diff --git a/src/judgeval/common/tracer.py b/src/judgeval/common/tracer.py
index 2c7285af..ce4d79e4 100644
--- a/src/judgeval/common/tracer.py
+++ b/src/judgeval/common/tracer.py
@@ -596,7 +596,7 @@ def save(self, overwrite: bool = False) -> Tuple[str, dict]:
             "project_name": self.project_name,
             "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
             "duration": total_duration,
-            "entries": [span.model_dump() for span in self.trace_spans],
+            "trace_spans": [span.model_dump() for span in self.trace_spans],
             "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
             "overwrite": overwrite,
             "offline_mode": self.tracer.offline_mode,
diff --git a/src/judgeval/data/datasets/dataset.py b/src/judgeval/data/datasets/dataset.py
index 9759ac17..ffbf503d 100644
--- a/src/judgeval/data/datasets/dataset.py
+++ b/src/judgeval/data/datasets/dataset.py
@@ -5,14 +5,15 @@
 import os
 import yaml
 from dataclasses import dataclass, field
-from typing import List, Union, Literal
+from typing import List, Union, Literal, Optional
 
-from judgeval.data import Example
+from judgeval.data import Example, Trace
 from judgeval.common.logger import debug, error, warning, info
 
 @dataclass
 class EvalDataset:
     examples: List[Example]
+    traces: List[Trace]
     _alias: Union[str, None] = field(default=None)
     _id: Union[str, None] = field(default=None)
     judgment_api_key: str = field(default="")
@@ -20,12 +21,13 @@ class EvalDataset:
     def __init__(self,
                  judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
                  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
-                 examples: List[Example] = [],
+                 examples: Optional[List[Example]] = None,
+                 traces: Optional[List[Trace]] = None
                  ):
-        debug(f"Initializing EvalDataset with {len(examples)} examples")
         if not judgment_api_key:
             warning("No judgment_api_key provided")
-        self.examples = examples
+        self.examples = examples or []
+        self.traces = traces or []
         self._alias = None
         self._id = None
         self.judgment_api_key = judgment_api_key
@@ -218,8 +220,11 @@ def add_from_yaml(self, file_path: str) -> None:
             self.add_example(e)
 
     def add_example(self, e: Example) -> None:
-        self.examples = self.examples + [e]
+        self.examples.append(e)
         # TODO if we need to add rank, then we need to do it here
+
+    def add_trace(self, t: Trace) -> None:
+        self.traces.append(t)
 
     def save_as(self, file_type: Literal["json", "csv", "yaml"], dir_path: str, save_name: str = None) -> None:
         """
@@ -307,6 +312,7 @@ def __str__(self):
         return (
             f"{self.__class__.__name__}("
             f"examples={self.examples}, "
+            f"traces={self.traces}, "
             f"_alias={self._alias}, "
             f"_id={self._id}"
             f")"
diff --git a/src/judgeval/data/datasets/eval_dataset_client.py b/src/judgeval/data/datasets/eval_dataset_client.py
index a84eae9e..4e91f692 100644
--- a/src/judgeval/data/datasets/eval_dataset_client.py
+++ b/src/judgeval/data/datasets/eval_dataset_client.py
@@ -13,7 +13,7 @@
     JUDGMENT_DATASETS_INSERT_API_URL,
     JUDGMENT_DATASETS_EXPORT_JSONL_API_URL
 )
-from judgeval.data import Example
+from judgeval.data import Example, Trace
 from judgeval.data.datasets import EvalDataset
 
 
@@ -58,6 +58,7 @@ def push(self, dataset: EvalDataset, alias: str, project_name: str, overwrite: O
                 "dataset_alias": alias,
                 "project_name": project_name,
                 "examples": [e.to_dict() for e in dataset.examples],
+                "traces": [t.model_dump() for t in dataset.traces],
                 "overwrite": overwrite,
             }
             try:
@@ -202,6 +203,7 @@ def pull(self, alias: str, project_name: str) -> EvalDataset:
             info(f"Successfully pulled dataset with alias '{alias}'")
             payload = response.json()
             dataset.examples = [Example(**e) for e in payload.get("examples", [])]
+            dataset.traces = [Trace(**t) for t in payload.get("traces", [])]
             dataset._alias = payload.get("alias")
             dataset._id = payload.get("id")
             progress.update(
diff --git a/src/judgeval/data/trace.py b/src/judgeval/data/trace.py
index 42bcdaf3..2ed8e6a7 100644
--- a/src/judgeval/data/trace.py
+++ b/src/judgeval/data/trace.py
@@ -117,7 +117,7 @@ class Trace(BaseModel):
     name: str
     created_at: str
     duration: float
-    entries: List[TraceSpan]
+    trace_spans: List[TraceSpan]
     overwrite: bool = False
     offline_mode: bool = False
     rules: Optional[Dict[str, Any]] = None
diff --git a/src/judgeval/run_evaluation.py b/src/judgeval/run_evaluation.py
index d2cca944..347ec117 100644
--- a/src/judgeval/run_evaluation.py
+++ b/src/judgeval/run_evaluation.py
@@ -420,7 +420,7 @@ def run_trace_eval(trace_run: TraceRun, override: bool = False, ignore_errors: b
         for i, trace in enumerate(tracer.traces):
             # We set the root-level trace span with the expected tools of the Trace
             trace = Trace(**trace)
-            trace.entries[0].expected_tools = examples[i].expected_tools
+            trace.trace_spans[0].expected_tools = examples[i].expected_tools
             new_traces.append(trace)
         trace_run.traces = new_traces
         tracer.traces = []
diff --git a/src/tests/common/test_tracer.py b/src/tests/common/test_tracer.py
index 147d9902..96c3539a 100644
--- a/src/tests/common/test_tracer.py
+++ b/src/tests/common/test_tracer.py
@@ -135,7 +135,7 @@ def test_trace_client_span(trace_client):
     assert len(trace_client.trace_spans) == initial_spans_count + 1
 
 def test_trace_client_nested_spans(trace_client):
-    """Test nested spans maintain proper depth recorded in entries"""
+    """Test nested spans maintain proper depth recorded in trace_spans"""
     root_span_id = current_span_var.get()  # From the fixture
 
     with trace_client.span("outer") as outer_span:
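
Not part of the diff above: a minimal, hypothetical sketch of the surface this change creates, where EvalDataset carries traces alongside examples and trace payloads use the renamed trace_spans key. The Trace keyword values below are placeholders (the model may require fields not visible in the trace.py hunk), and the Example fields are assumed from their use elsewhere in this diff.

from judgeval.data import Example, Trace
from judgeval.data.datasets import EvalDataset

# Datasets can now hold traces next to examples (dataset.py above).
dataset = EvalDataset(
    examples=[
        Example(
            input="What is the weather in Paris?",
            actual_output="Sunny, 22C",  # assumed Example fields
        )
    ],
)

# add_trace mirrors add_example; push() serializes each trace with
# model_dump() under the "traces" key (eval_dataset_client.py above).
dataset.add_trace(
    Trace(
        trace_id="00000000-0000-0000-0000-000000000000",  # placeholder value
        name="demo_trace",
        created_at="2024-01-01T00:00:00",
        duration=0.1,
        trace_spans=[],  # renamed from `entries` throughout this diff
    )
)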