Skip to content

Commit d7cc3a9

Browse files
authored
Merge pull request #65 from JudgmentLabs/az-all-user-db-endpoint
Retrieve all user datasets
2 parents 2040096 + 168180c commit d7cc3a9

File tree

8 files changed

+457
-1045
lines changed

8 files changed

+457
-1045
lines changed

Pipfile.lock

Lines changed: 192 additions & 898 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/e2etests/judgment_client_test.py

Lines changed: 35 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,8 @@
2222
)
2323
from judgeval.judges import TogetherJudge, JudgevalJudge
2424
from playground import CustomFaithfulnessMetric
25-
from judgeval.data.datasets.dataset import EvalDataset
25+
from judgeval.data.datasets.dataset import EvalDataset, GroundTruthExample
26+
from judgeval.data.datasets.eval_dataset_client import EvalDatasetClient
2627
from judgeval.scorers.prompt_scorer import ClassifierScorer
2728

2829
# Configure logging
@@ -62,6 +63,30 @@ def test_dataset(self, client: JudgmentClient):
6263
dataset = client.pull_dataset(alias="test_dataset_5")
6364
assert dataset, "Failed to pull dataset"
6465

66+
def test_pull_all_user_dataset_stats(self, client: JudgmentClient):
67+
dataset: EvalDataset = client.create_dataset()
68+
dataset.add_example(Example(input="input 1", actual_output="output 1"))
69+
dataset.add_example(Example(input="input 2", actual_output="output 2"))
70+
dataset.add_example(Example(input="input 3", actual_output="output 3"))
71+
random_name1 = ''.join(random.choices(string.ascii_letters + string.digits, k=20))
72+
client.push_dataset(alias=random_name1, dataset=dataset, overwrite=False)
73+
74+
dataset: EvalDataset = client.create_dataset()
75+
dataset.add_example(Example(input="input 1", actual_output="output 1"))
76+
dataset.add_example(Example(input="input 2", actual_output="output 2"))
77+
dataset.add_ground_truth(GroundTruthExample(input="input 1", actual_output="output 1"))
78+
dataset.add_ground_truth(GroundTruthExample(input="input 2", actual_output="output 2"))
79+
random_name2 = ''.join(random.choices(string.ascii_letters + string.digits, k=20))
80+
client.push_dataset(alias=random_name2, dataset=dataset, overwrite=False)
81+
82+
all_datasets_stats = client.pull_all_user_dataset_stats()
83+
print(all_datasets_stats)
84+
assert all_datasets_stats, "Failed to pull dataset"
85+
assert all_datasets_stats[random_name1]["example_count"] == 3, f"{random_name1} should have 3 examples"
86+
assert all_datasets_stats[random_name1]["ground_truth_count"] == 0, f"{random_name1} should have 0 ground truths"
87+
assert all_datasets_stats[random_name2]["example_count"] == 2, f"{random_name2} should have 2 examples"
88+
assert all_datasets_stats[random_name2]["ground_truth_count"] == 2, f"{random_name2} should have 2 ground truths"
89+
6590
def test_run_eval(self, client: JudgmentClient):
6691
"""Test basic evaluation workflow."""
6792
# Single step in our workflow, an outreach Sales Agent
@@ -405,6 +430,7 @@ def run_selected_tests(client, test_names: list[str]):
405430

406431
test_map = {
407432
'dataset': test_basic_operations.test_dataset,
433+
'pull_all_user_dataset_stats': test_basic_operations.test_pull_all_user_dataset_stats,
408434
'run_eval': test_basic_operations.test_run_eval,
409435
'assert_test': test_basic_operations.test_assert_test,
410436
'json_scorer': test_advanced_features.test_json_scorer,
@@ -433,11 +459,12 @@ def run_selected_tests(client, test_names: list[str]):
433459

434460
run_selected_tests(client, [
435461
'dataset',
436-
'run_eval',
437-
'assert_test',
438-
'json_scorer',
439-
'override_eval',
440-
'evaluate_dataset',
441-
'classifier_scorer',
442-
'custom_judge_vertexai'
462+
'pull_all_user_dataset_stats',
463+
# 'run_eval',
464+
# 'assert_test',
465+
# 'json_scorer',
466+
# 'override_eval',
467+
# 'evaluate_dataset',
468+
# 'classifier_scorer',
469+
# 'custom_judge_vertexai'
443470
])

src/judgeval/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def _missing_(cls, value):
3636
JUDGMENT_EVAL_API_URL = f"{ROOT_API}/evaluate/"
3737
JUDGMENT_DATASETS_PUSH_API_URL = f"{ROOT_API}/datasets/push/"
3838
JUDGMENT_DATASETS_PULL_API_URL = f"{ROOT_API}/datasets/pull/"
39+
JUDGMENT_DATASETS_PULL_ALL_API_URL = f"{ROOT_API}/datasets/get_all_stats/"
3940
JUDGMENT_EVAL_LOG_API_URL = f"{ROOT_API}/log_eval_results/"
4041
JUDGMENT_EVAL_FETCH_API_URL = f"{ROOT_API}/fetch_eval_results/"
4142
JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"
Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from judgeval.data.datasets.dataset import EvalDataset
22
from judgeval.data.datasets.ground_truth import GroundTruthExample
3+
from judgeval.data.datasets.eval_dataset_client import EvalDatasetClient
34

4-
__all__ = ["EvalDataset", "GroundTruthExample"]
5+
__all__ = ["EvalDataset", "EvalDatasetClient", "GroundTruthExample"]

src/judgeval/data/datasets/dataset.py

Lines changed: 1 addition & 122 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,11 @@
22
import csv
33
import datetime
44
import json
5-
from rich.console import Console
6-
from rich.progress import Progress, SpinnerColumn, TextColumn
7-
import requests
85
from dataclasses import dataclass, field
96
import os
107
from typing import List, Optional, Union, Literal
118

12-
from judgeval.constants import JUDGMENT_DATASETS_PUSH_API_URL, JUDGMENT_DATASETS_PULL_API_URL
139
from judgeval.data.datasets.ground_truth import GroundTruthExample
14-
from judgeval.data.datasets.utils import ground_truths_to_examples, examples_to_ground_truths
1510
from judgeval.data import Example
1611
from judgeval.common.logger import debug, error, warning, info
1712

@@ -37,120 +32,6 @@ def __init__(self,
3732
self._id = None
3833
self.judgment_api_key = judgment_api_key
3934

40-
def push(self, alias: str, overwrite: Optional[bool] = False) -> bool:
41-
debug(f"Pushing dataset with alias '{alias}' (overwrite={overwrite})")
42-
if overwrite:
43-
warning(f"Overwrite enabled for alias '{alias}'")
44-
"""
45-
Pushes the dataset to Judgment platform
46-
47-
Mock request:
48-
{
49-
"alias": alias,
50-
"ground_truths": [...],
51-
"examples": [...],
52-
"overwrite": overwrite
53-
} ==>
54-
{
55-
"_alias": alias,
56-
"_id": "..." # ID of the dataset
57-
}
58-
"""
59-
with Progress(
60-
SpinnerColumn(style="rgb(106,0,255)"),
61-
TextColumn("[progress.description]{task.description}"),
62-
transient=False,
63-
) as progress:
64-
task_id = progress.add_task(
65-
f"Pushing [rgb(106,0,255)]'{alias}' to Judgment...",
66-
total=100,
67-
)
68-
content = {
69-
"alias": alias,
70-
"ground_truths": [g.to_dict() for g in self.ground_truths],
71-
"examples": [e.to_dict() for e in self.examples],
72-
"overwrite": overwrite,
73-
"judgment_api_key": self.judgment_api_key
74-
}
75-
try:
76-
response = requests.post(
77-
JUDGMENT_DATASETS_PUSH_API_URL,
78-
json=content
79-
)
80-
if response.status_code == 500:
81-
error(f"Server error during push: {content.get('message')}")
82-
return False
83-
response.raise_for_status()
84-
except requests.exceptions.HTTPError as err:
85-
if response.status_code == 422:
86-
error(f"Validation error during push: {err.response.json()}")
87-
else:
88-
error(f"HTTP error during push: {err}")
89-
90-
info(f"Successfully pushed dataset with alias '{alias}'")
91-
payload = response.json()
92-
self._alias = payload.get("_alias")
93-
self._id = payload.get("_id")
94-
progress.update(
95-
task_id,
96-
description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
97-
)
98-
return True
99-
100-
def pull(self, alias: str):
101-
debug(f"Pulling dataset with alias '{alias}'")
102-
"""
103-
Pulls the dataset from Judgment platform
104-
105-
Mock request:
106-
{
107-
"alias": alias,
108-
"user_id": user_id
109-
}
110-
==>
111-
{
112-
"ground_truths": [...],
113-
"examples": [...],
114-
"_alias": alias,
115-
"_id": "..." # ID of the dataset
116-
}
117-
"""
118-
# Make a POST request to the Judgment API to get the dataset
119-
120-
with Progress(
121-
SpinnerColumn(style="rgb(106,0,255)"),
122-
TextColumn("[progress.description]{task.description}"),
123-
transient=False,
124-
) as progress:
125-
task_id = progress.add_task(
126-
f"Pulling [rgb(106,0,255)]'{alias}'[/rgb(106,0,255)] from Judgment...",
127-
total=100,
128-
)
129-
request_body = {
130-
"alias": alias,
131-
"judgment_api_key": self.judgment_api_key
132-
}
133-
134-
try:
135-
response = requests.post(
136-
JUDGMENT_DATASETS_PULL_API_URL,
137-
json=request_body
138-
)
139-
response.raise_for_status()
140-
except requests.exceptions.RequestException as e:
141-
error(f"Error pulling dataset: {str(e)}")
142-
raise
143-
144-
info(f"Successfully pulled dataset with alias '{alias}'")
145-
payload = response.json()
146-
self.ground_truths = [GroundTruthExample(**g) for g in payload.get("ground_truths", [])]
147-
self.examples = [Example(**e) for e in payload.get("examples", [])]
148-
self._alias = payload.get("_alias")
149-
self._id = payload.get("_id")
150-
progress.update(
151-
task_id,
152-
description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
153-
)
15435

15536
def add_from_json(self, file_path: str) -> None:
15637
debug(f"Loading dataset from JSON file: {file_path}")
@@ -402,6 +283,4 @@ def __str__(self):
402283
f"_alias={self._alias}, "
403284
f"_id={self._id}"
404285
f")"
405-
)
406-
407-
286+
)

0 commit comments

Comments (0)