Add Custom Judge Models for Custom Scorers #38

Closed · wants to merge 4 commits
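This PR lets a `judgevalJudge` subclass be passed as the `model` argument to `JudgmentClient.run_evaluation`, with validation restricting that combination to `CustomScorer`-based scorers. Below is a minimal sketch of the resulting call pattern, based on the VertexAI test added in this diff; `MyJudge` is a hypothetical stand-in for any concrete judge, and the no-argument `JudgmentClient()` construction assumes credentials are picked up from the environment the same way the existing e2e tests do.

```python
from judgeval.data import Example
from judgeval.judges import judgevalJudge
from judgeval.judgment_client import JudgmentClient
from judgeval.playground import CustomFaithfulnessMetric


class MyJudge(judgevalJudge):
    """Hypothetical judge for illustration; wrap any LLM backend the same way
    the VertexAIJudge in the new test wraps Gemini."""

    def __init__(self, model_name: str = "my-custom-judge"):
        self.model_name = model_name

    def load_model(self):
        return self

    def generate(self, prompt) -> str:
        # prompt arrives as a List[dict] conversation history; return the judge's text verdict
        return "The actual output is faithful to the retrieval context."

    async def a_generate(self, prompt) -> str:
        return self.generate(prompt)

    def get_model_name(self) -> str:
        return self.model_name


client = JudgmentClient()  # assumes API keys are configured in the environment, as in the e2e tests

example = Example(
    input="What is the largest animal in the world?",
    actual_output="The blue whale is the largest known animal.",
    retrieval_context=["The blue whale is the largest known animal."],
)

# A judgevalJudge instance is only accepted when every scorer is a CustomScorer;
# mixing in JudgmentScorer instances is rejected by EvaluationRun validation.
results = client.run_evaluation(
    examples=[example],
    scorers=[CustomFaithfulnessMetric()],
    model=MyJudge(),
)
print(results)
```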
61 changes: 59 additions & 2 deletions e2etests/judgment_client_test.py
@@ -9,7 +9,7 @@
FaithfulnessScorer,
HallucinationScorer,
)
from judgeval.judges import TogetherJudge
from judgeval.judges import TogetherJudge, judgevalJudge
from judgeval.playground import CustomFaithfulnessMetric
from judgeval.data.datasets.dataset import EvalDataset
from dotenv import load_dotenv
@@ -76,6 +76,7 @@ def test_run_eval(client: JudgmentClient):
results = client.pull_eval(project_name=PROJECT_NAME, eval_run_name=EVAL_RUN_NAME)
print(f"Evaluation results for {EVAL_RUN_NAME} from database:", results)


def test_override_eval(client: JudgmentClient):
example1 = Example(
input="What if these shoes don't fit?",
@@ -147,7 +148,6 @@ def test_override_eval(client: JudgmentClient):
raise
print(f"Successfully caught expected error: {e}")



def test_evaluate_dataset(client: JudgmentClient):

@@ -180,6 +180,7 @@ def test_evaluate_dataset(client: JudgmentClient):

print(res)


def test_classifier_scorer(client: JudgmentClient):
classifier_scorer = client.fetch_classifier_scorer("tonescorer-72gl")
faithfulness_scorer = FaithfulnessScorer(threshold=0.5)
@@ -197,6 +198,57 @@ def test_classifier_scorer(client: JudgmentClient):
)
print(res)


def test_custom_judge_vertexai(client: JudgmentClient):

import vertexai
from vertexai.generative_models import GenerativeModel

PROJECT_ID = "judgment-labs"
vertexai.init(project=PROJECT_ID, location="us-west1")

class VertexAIJudge(judgevalJudge):

def __init__(self, model_name: str = "gemini-1.5-flash-002"):
self.model_name = model_name
self.model = GenerativeModel(self.model_name)

def load_model(self):
return self.model

def generate(self, prompt) -> str:
# prompt is a List[dict] (conversation history)
# For models that don't support conversation history, we need to convert to string
# If you're using a model that supports chat history, you can just pass the prompt directly
response = self.model.generate_content(str(prompt))
return response.text

async def a_generate(self, prompt) -> str:
# prompt is a List[dict] (conversation history)
# For models that don't support conversation history, we need to convert to string
# If you're using a model that supports chat history, you can just pass the prompt directly
response = await self.model.generate_content_async(str(prompt))
return response.text

def get_model_name(self) -> str:
return self.model_name

example = Example(
input="What is the largest animal in the world?",
actual_output="The blue whale is the largest known animal.",
retrieval_context=["The blue whale is the largest known animal."],
)

judge = VertexAIJudge()

res = client.run_evaluation(
examples=[example],
scorers=[CustomFaithfulnessMetric()],
model=judge,
)
print(res)


if __name__ == "__main__":
# Test client functionality
client = get_client()
@@ -229,4 +281,9 @@ def test_classifier_scorer(client: JudgmentClient):
print("Classifier scorer test successful")
print("*" * 40)

print("Testing custom judge")
test_custom_judge_vertexai(ui_client)
print("Custom judge test successful")
print("*" * 40)

print("All tests passed successfully")
123 changes: 0 additions & 123 deletions judgeval/common/telemetry.py

This file was deleted.

32 changes: 24 additions & 8 deletions judgeval/evaluation_run.py
@@ -2,10 +2,11 @@
from pydantic import BaseModel, field_validator

from judgeval.data import Example
from judgeval.data.datasets import EvalDataset
from judgeval.scorers import CustomScorer, JudgmentScorer
from judgeval.constants import ACCEPTABLE_MODELS
from judgeval.common.logger import debug, error
from judgeval.judges import judgevalJudge

class EvaluationRun(BaseModel):
"""
Stores example and evaluation scorers together for running an eval task
@@ -27,7 +28,7 @@ class EvaluationRun(BaseModel):
eval_name: Optional[str] = None
examples: List[Example]
scorers: List[Union[JudgmentScorer, CustomScorer]]
model: Union[str, List[str]]
model: Union[str, List[str], judgevalJudge]
aggregator: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
# API Key will be "" until user calls client.run_eval(), then API Key will be set
@@ -74,18 +75,33 @@ def validate_scorers(cls, v):
return v

@field_validator('model')
def validate_model(cls, v):
def validate_model(cls, v, values):
if not v:
raise ValueError("Model cannot be empty.")
if not isinstance(v, str) and not isinstance(v, list):
raise ValueError("Model must be a string or a list of strings.")
if isinstance(v, str) and v not in ACCEPTABLE_MODELS:
raise ValueError(f"Model name {v} not recognized.")

# Check if model is a judgevalJudge
if isinstance(v, judgevalJudge):
# Verify all scorers are CustomScorer when using judgevalJudge
scorers = values.data.get('scorers', [])
if not all(isinstance(s, CustomScorer) for s in scorers):
raise ValueError("When using a judgevalJudge model, all scorers must be CustomScorer type")
return v

# Check if model is string or list of strings
if isinstance(v, str):
if v not in ACCEPTABLE_MODELS:
raise ValueError(f"Model name {v} not recognized.")
return v

if isinstance(v, list):
if not all(isinstance(m, str) for m in v):
raise ValueError("When providing a list of models, all elements must be strings")
for m in v:
if m not in ACCEPTABLE_MODELS:
raise ValueError(f"Model name {m} not recognized.")
return v
return v

raise ValueError("Model must be one of: string, list of strings, or judgevalJudge instance")

@field_validator('aggregator', mode='before')
def validate_aggregator(cls, v, values):
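For reference, a sketch of how the reworked `validate_model` check behaves, assuming the remaining `EvaluationRun` fields keep their defaults; `StubJudge` is a hypothetical minimal judge, and `FaithfulnessScorer` is used on the assumption that it is a `JudgmentScorer`-backed scorer, as in the existing tests.

```python
from pydantic import ValidationError

from judgeval.data import Example
from judgeval.evaluation_run import EvaluationRun
from judgeval.judges import judgevalJudge
from judgeval.playground import CustomFaithfulnessMetric
from judgeval.scorers import FaithfulnessScorer


class StubJudge(judgevalJudge):
    """Hypothetical no-op judge, just enough to exercise the validator."""

    def __init__(self, model_name: str = "stub-judge"):
        self.model_name = model_name

    def load_model(self):
        return self

    def generate(self, prompt) -> str:
        return "stub verdict"

    async def a_generate(self, prompt) -> str:
        return "stub verdict"

    def get_model_name(self) -> str:
        return self.model_name


example = Example(
    input="What is the largest animal in the world?",
    actual_output="The blue whale is the largest known animal.",
)

# Accepted: a judgevalJudge instance paired exclusively with CustomScorer scorers.
EvaluationRun(examples=[example], scorers=[CustomFaithfulnessMetric()], model=StubJudge())

# Rejected: a custom judge mixed with a JudgmentScorer, and a string model name
# that is not in ACCEPTABLE_MODELS, both fail validation.
for bad in (
    {"scorers": [FaithfulnessScorer(threshold=0.5)], "model": StubJudge()},
    {"scorers": [CustomFaithfulnessMetric()], "model": "not-a-real-model"},
):
    try:
        EvaluationRun(examples=[example], **bad)
    except ValidationError as err:
        print(err)
```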
3 changes: 2 additions & 1 deletion judgeval/judgment_client.py
@@ -11,6 +11,7 @@
from judgeval.scorers import JudgmentScorer, CustomScorer, ClassifierScorer
from judgeval.evaluation_run import EvaluationRun
from judgeval.run_evaluation import run_eval
from judgeval.judges import judgevalJudge
from judgeval.constants import JUDGMENT_EVAL_FETCH_API_URL
from judgeval.common.exceptions import JudgmentAPIError
from pydantic import BaseModel
@@ -38,7 +39,7 @@ def run_evaluation(
self,
examples: List[Example],
scorers: List[Union[JudgmentScorer, CustomScorer]],
model: Union[str, List[str]],
model: Union[str, List[str], judgevalJudge],
aggregator: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
log_results: bool = False,
31 changes: 15 additions & 16 deletions judgeval/playground.py
@@ -15,7 +15,6 @@
from judgeval.judges.utils import create_judge
from judgeval.scorers.custom_scorer import CustomScorer
from judgeval.scorers.score import *
from judgeval.common.telemetry import capture_metric_type

"""
Testing implementation of CustomFaithfulness
@@ -195,22 +194,22 @@ def metric_progress_indicator(
total: int = 9999,
transient: bool = True,
):
with capture_metric_type(metric.__name__):
console = Console(file=sys.stderr) # Direct output to standard error
if _show_indicator:
with Progress(
SpinnerColumn(style="rgb(106,0,255)"),
TextColumn("[progress.description]{task.description}"),
console=console, # Use the custom console
transient=transient,
) as progress:
progress.add_task(
description=scorer_console_msg(metric, async_mode),
total=total,
)
yield
else:

console = Console(file=sys.stderr) # Direct output to standard error
if _show_indicator:
with Progress(
SpinnerColumn(style="rgb(106,0,255)"),
TextColumn("[progress.description]{task.description}"),
console=console, # Use the custom console
transient=transient,
) as progress:
progress.add_task(
description=scorer_console_msg(metric, async_mode),
total=total,
)
yield
else:
yield


def prettify_list(lst: List[Any]):
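For context, a sketch of how the slimmed-down progress indicator is typically used now that the `capture_metric_type` wrapper (and `judgeval/common/telemetry.py`) is gone. This assumes `metric_progress_indicator` keeps its context-manager interface and that its remaining keyword arguments use their defaults.

```python
from judgeval.playground import CustomFaithfulnessMetric, metric_progress_indicator

metric = CustomFaithfulnessMetric()
with metric_progress_indicator(metric):
    # Long-running scoring work goes here; the Rich spinner renders on stderr
    # and nothing is reported to telemetry anymore.
    ...
```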
1 change: 0 additions & 1 deletion judgeval/run_evaluation.py
@@ -265,7 +265,6 @@ def run_eval(evaluation_run: EvaluationRun, override: bool = False) -> List[ScoringResult]:
else:
custom_scorers.append(scorer)
debug(f"Added custom scorer: {type(scorer).__name__}")

debug(f"Found {len(judgment_scorers)} judgment scorers and {len(custom_scorers)} custom scorers")

api_results: List[ScoringResult] = []