Skip to content

Trainer error handling #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/judgeval/common/trainer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TrainerConfig:
rft_provider: str = "fireworks"
num_steps: int = 5
num_generations_per_prompt: int = (
5 # Number of rollouts/generations per input prompt
4 # Number of rollouts/generations per input prompt
)
num_prompts_per_step: int = 4 # Number of input prompts to sample per training step
concurrency: int = 100
Expand Down
156 changes: 90 additions & 66 deletions src/judgeval/common/trainer/trainable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .config import TrainerConfig, ModelConfig
from typing import Optional, Dict, Any, Callable
from .console import _model_spinner_progress, _print_model_progress
from judgeval.common.exceptions import JudgmentAPIError


class TrainableModel:
Expand All @@ -20,13 +21,18 @@ def __init__(self, config: TrainerConfig):
Args:
config: TrainerConfig instance with model configuration
"""
self.config = config
self.current_step = 0
self._current_model = None
self._tracer_wrapper_func = None
try:
self.config = config
self.current_step = 0
self._current_model = None
self._tracer_wrapper_func = None

self._base_model = self._create_base_model()
self._current_model = self._base_model
self._base_model = self._create_base_model()
self._current_model = self._base_model
except Exception as e:
raise JudgmentAPIError(
f"Failed to initialize TrainableModel: {str(e)}"
) from e

@classmethod
def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":
Expand Down Expand Up @@ -58,38 +64,48 @@ def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":

def _create_base_model(self):
"""Create and configure the base model."""
with _model_spinner_progress(
"Creating and deploying base model..."
) as update_progress:
update_progress("Creating base model instance...")
base_model = LLM(
model=self.config.base_model_name,
deployment_type="on-demand",
id=self.config.deployment_id,
enable_addons=self.config.enable_addons,
)
update_progress("Applying deployment configuration...")
base_model.apply()
_print_model_progress("Base model deployment ready")
return base_model
try:
with _model_spinner_progress(
"Creating and deploying base model..."
) as update_progress:
update_progress("Creating base model instance...")
base_model = LLM(
model=self.config.base_model_name,
deployment_type="on-demand",
id=self.config.deployment_id,
enable_addons=self.config.enable_addons,
)
update_progress("Applying deployment configuration...")
base_model.apply()
_print_model_progress("Base model deployment ready")
return base_model
except Exception as e:
raise JudgmentAPIError(
f"Failed to create and deploy base model '{self.config.base_model_name}': {str(e)}"
) from e

def _load_trained_model(self, model_name: str):
"""Load a trained model by name."""
with _model_spinner_progress(
f"Loading and deploying trained model: {model_name}"
) as update_progress:
update_progress("Creating trained model instance...")
self._current_model = LLM(
model=model_name,
deployment_type="on-demand-lora",
base_id=self.config.deployment_id,
)
update_progress("Applying deployment configuration...")
self._current_model.apply()
_print_model_progress("Trained model deployment ready")
try:
with _model_spinner_progress(
f"Loading and deploying trained model: {model_name}"
) as update_progress:
update_progress("Creating trained model instance...")
self._current_model = LLM(
model=model_name,
deployment_type="on-demand-lora",
base_id=self.config.deployment_id,
)
update_progress("Applying deployment configuration...")
self._current_model.apply()
_print_model_progress("Trained model deployment ready")

if self._tracer_wrapper_func:
self._tracer_wrapper_func(self._current_model)
if self._tracer_wrapper_func:
self._tracer_wrapper_func(self._current_model)
except Exception as e:
raise JudgmentAPIError(
f"Failed to load and deploy trained model '{model_name}': {str(e)}"
) from e

def get_current_model(self):
return self._current_model
Expand All @@ -111,29 +127,32 @@ def advance_to_next_step(self, step: int):
Args:
step: The current training step number
"""
self.current_step = step

if step == 0:
self._current_model = self._base_model
else:
model_name = (
f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"
)
with _model_spinner_progress(
f"Creating and deploying model snapshot: {model_name}"
) as update_progress:
update_progress("Creating model snapshot instance...")
self._current_model = LLM(
model=model_name,
deployment_type="on-demand-lora",
base_id=self.config.deployment_id,
)
update_progress("Applying deployment configuration...")
self._current_model.apply()
_print_model_progress("Model snapshot deployment ready")

if self._tracer_wrapper_func:
self._tracer_wrapper_func(self._current_model)
try:
self.current_step = step

if step == 0:
self._current_model = self._base_model
else:
model_name = f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"
with _model_spinner_progress(
f"Creating and deploying model snapshot: {model_name}"
) as update_progress:
update_progress("Creating model snapshot instance...")
self._current_model = LLM(
model=model_name,
deployment_type="on-demand-lora",
base_id=self.config.deployment_id,
)
update_progress("Applying deployment configuration...")
self._current_model.apply()
_print_model_progress("Model snapshot deployment ready")

if self._tracer_wrapper_func:
self._tracer_wrapper_func(self._current_model)
except Exception as e:
raise JudgmentAPIError(
f"Failed to advance to training step {step}: {str(e)}"
) from e

def perform_reinforcement_step(self, dataset, step: int):
"""
Expand All @@ -146,15 +165,20 @@ def perform_reinforcement_step(self, dataset, step: int):
Returns:
Training job object
"""
model_name = f"{self.config.model_id}-v{step + 1}"
return self._current_model.reinforcement_step(
dataset=dataset,
output_model=model_name,
epochs=self.config.epochs,
learning_rate=self.config.learning_rate,
accelerator_count=self.config.accelerator_count,
accelerator_type=self.config.accelerator_type,
)
try:
model_name = f"{self.config.model_id}-v{step + 1}"
return self._current_model.reinforcement_step(
dataset=dataset,
output_model=model_name,
epochs=self.config.epochs,
learning_rate=self.config.learning_rate,
accelerator_count=self.config.accelerator_count,
accelerator_type=self.config.accelerator_type,
)
except Exception as e:
raise JudgmentAPIError(
f"Failed to start reinforcement learning step {step + 1}: {str(e)}"
) from e

def get_model_config(
self, training_params: Optional[Dict[str, Any]] = None
Expand Down
50 changes: 34 additions & 16 deletions src/judgeval/common/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from judgeval.scorers import BaseScorer, APIScorerConfig
from judgeval.data import Example
from .console import _spinner_progress, _print_progress, _print_progress_update
from judgeval.common.exceptions import JudgmentAPIError


class JudgmentTrainer:
Expand All @@ -35,17 +36,22 @@ def __init__(
trainable_model: Optional trainable model instance
project_name: Project name for organizing training runs and evaluations
"""
self.config = config
self.tracer = tracer
self.tracer.show_trace_urls = False
self.project_name = project_name or "judgment_training"

if trainable_model is None:
self.trainable_model = TrainableModel(self.config)
else:
self.trainable_model = trainable_model

self.judgment_client = JudgmentClient()
try:
self.config = config
self.tracer = tracer
self.tracer.show_trace_urls = False
self.project_name = project_name or "judgment_training"

if trainable_model is None:
self.trainable_model = TrainableModel(self.config)
else:
self.trainable_model = trainable_model

self.judgment_client = JudgmentClient()
except Exception as e:
raise JudgmentAPIError(
f"Failed to initialize JudgmentTrainer: {str(e)}"
) from e

async def generate_rollouts_and_rewards(
self,
Expand Down Expand Up @@ -97,7 +103,9 @@ async def generate_single_response(prompt_id, generation_id):
pass

example = Example(
input=prompt_input, messages=messages, actual_output=response_data
input=prompt_input,
messages=messages,
actual_output=response_data,
)

scoring_results = self.judgment_client.run_evaluation(
Expand Down Expand Up @@ -238,7 +246,9 @@ async def run_reinforcement_learning(
time.sleep(10)
job = job.get()
if job is None:
raise Exception("Job was deleted while waiting for completion")
raise JudgmentAPIError(
"Training job was deleted while waiting for completion"
)

_print_progress(
f"Training completed! New model: {job.output_model}",
Expand Down Expand Up @@ -277,7 +287,15 @@ async def train(
Returns:
ModelConfig: Configuration of the trained model for future loading
"""
if rft_provider is not None:
self.config.rft_provider = rft_provider
try:
if rft_provider is not None:
self.config.rft_provider = rft_provider

return await self.run_reinforcement_learning(agent_function, scorers, prompts)
return await self.run_reinforcement_learning(
agent_function, scorers, prompts
)
except JudgmentAPIError:
# Re-raise JudgmentAPIError as-is
raise
except Exception as e:
raise JudgmentAPIError(f"Training process failed: {str(e)}") from e
Loading