Skip to content

Commit 88b6509

Browse files
authored
Trainer error handling (#513)
* add try except
* cleanup try except
* fix config
1 parent aa3c40c commit 88b6509

File tree

3 files changed: +125 -83 lines changed

src/judgeval/common/trainer/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class TrainerConfig:
1414
rft_provider: str = "fireworks"
1515
num_steps: int = 5
1616
num_generations_per_prompt: int = (
17-
5 # Number of rollouts/generations per input prompt
17+
4 # Number of rollouts/generations per input prompt
1818
)
1919
num_prompts_per_step: int = 4 # Number of input prompts to sample per training step
2020
concurrency: int = 100

src/judgeval/common/trainer/trainable_model.py

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .config import TrainerConfig, ModelConfig
33
from typing import Optional, Dict, Any, Callable
44
from .console import _model_spinner_progress, _print_model_progress
5+
from judgeval.common.exceptions import JudgmentAPIError
56

67

78
class TrainableModel:
@@ -20,13 +21,18 @@ def __init__(self, config: TrainerConfig):
2021
Args:
2122
config: TrainerConfig instance with model configuration
2223
"""
23-
self.config = config
24-
self.current_step = 0
25-
self._current_model = None
26-
self._tracer_wrapper_func = None
24+
try:
25+
self.config = config
26+
self.current_step = 0
27+
self._current_model = None
28+
self._tracer_wrapper_func = None
2729

28-
self._base_model = self._create_base_model()
29-
self._current_model = self._base_model
30+
self._base_model = self._create_base_model()
31+
self._current_model = self._base_model
32+
except Exception as e:
33+
raise JudgmentAPIError(
34+
f"Failed to initialize TrainableModel: {str(e)}"
35+
) from e
3036

3137
@classmethod
3238
def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":
@@ -58,38 +64,48 @@ def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":
5864

5965
def _create_base_model(self):
6066
"""Create and configure the base model."""
61-
with _model_spinner_progress(
62-
"Creating and deploying base model..."
63-
) as update_progress:
64-
update_progress("Creating base model instance...")
65-
base_model = LLM(
66-
model=self.config.base_model_name,
67-
deployment_type="on-demand",
68-
id=self.config.deployment_id,
69-
enable_addons=self.config.enable_addons,
70-
)
71-
update_progress("Applying deployment configuration...")
72-
base_model.apply()
73-
_print_model_progress("Base model deployment ready")
74-
return base_model
67+
try:
68+
with _model_spinner_progress(
69+
"Creating and deploying base model..."
70+
) as update_progress:
71+
update_progress("Creating base model instance...")
72+
base_model = LLM(
73+
model=self.config.base_model_name,
74+
deployment_type="on-demand",
75+
id=self.config.deployment_id,
76+
enable_addons=self.config.enable_addons,
77+
)
78+
update_progress("Applying deployment configuration...")
79+
base_model.apply()
80+
_print_model_progress("Base model deployment ready")
81+
return base_model
82+
except Exception as e:
83+
raise JudgmentAPIError(
84+
f"Failed to create and deploy base model '{self.config.base_model_name}': {str(e)}"
85+
) from e
7586

7687
def _load_trained_model(self, model_name: str):
7788
"""Load a trained model by name."""
78-
with _model_spinner_progress(
79-
f"Loading and deploying trained model: {model_name}"
80-
) as update_progress:
81-
update_progress("Creating trained model instance...")
82-
self._current_model = LLM(
83-
model=model_name,
84-
deployment_type="on-demand-lora",
85-
base_id=self.config.deployment_id,
86-
)
87-
update_progress("Applying deployment configuration...")
88-
self._current_model.apply()
89-
_print_model_progress("Trained model deployment ready")
89+
try:
90+
with _model_spinner_progress(
91+
f"Loading and deploying trained model: {model_name}"
92+
) as update_progress:
93+
update_progress("Creating trained model instance...")
94+
self._current_model = LLM(
95+
model=model_name,
96+
deployment_type="on-demand-lora",
97+
base_id=self.config.deployment_id,
98+
)
99+
update_progress("Applying deployment configuration...")
100+
self._current_model.apply()
101+
_print_model_progress("Trained model deployment ready")
90102

91-
if self._tracer_wrapper_func:
92-
self._tracer_wrapper_func(self._current_model)
103+
if self._tracer_wrapper_func:
104+
self._tracer_wrapper_func(self._current_model)
105+
except Exception as e:
106+
raise JudgmentAPIError(
107+
f"Failed to load and deploy trained model '{model_name}': {str(e)}"
108+
) from e
93109

94110
def get_current_model(self):
95111
return self._current_model
@@ -111,29 +127,32 @@ def advance_to_next_step(self, step: int):
111127
Args:
112128
step: The current training step number
113129
"""
114-
self.current_step = step
115-
116-
if step == 0:
117-
self._current_model = self._base_model
118-
else:
119-
model_name = (
120-
f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"
121-
)
122-
with _model_spinner_progress(
123-
f"Creating and deploying model snapshot: {model_name}"
124-
) as update_progress:
125-
update_progress("Creating model snapshot instance...")
126-
self._current_model = LLM(
127-
model=model_name,
128-
deployment_type="on-demand-lora",
129-
base_id=self.config.deployment_id,
130-
)
131-
update_progress("Applying deployment configuration...")
132-
self._current_model.apply()
133-
_print_model_progress("Model snapshot deployment ready")
134-
135-
if self._tracer_wrapper_func:
136-
self._tracer_wrapper_func(self._current_model)
130+
try:
131+
self.current_step = step
132+
133+
if step == 0:
134+
self._current_model = self._base_model
135+
else:
136+
model_name = f"accounts/{self.config.user_id}/models/{self.config.model_id}-v{step}"
137+
with _model_spinner_progress(
138+
f"Creating and deploying model snapshot: {model_name}"
139+
) as update_progress:
140+
update_progress("Creating model snapshot instance...")
141+
self._current_model = LLM(
142+
model=model_name,
143+
deployment_type="on-demand-lora",
144+
base_id=self.config.deployment_id,
145+
)
146+
update_progress("Applying deployment configuration...")
147+
self._current_model.apply()
148+
_print_model_progress("Model snapshot deployment ready")
149+
150+
if self._tracer_wrapper_func:
151+
self._tracer_wrapper_func(self._current_model)
152+
except Exception as e:
153+
raise JudgmentAPIError(
154+
f"Failed to advance to training step {step}: {str(e)}"
155+
) from e
137156

138157
def perform_reinforcement_step(self, dataset, step: int):
139158
"""
@@ -146,15 +165,20 @@ def perform_reinforcement_step(self, dataset, step: int):
146165
Returns:
147166
Training job object
148167
"""
149-
model_name = f"{self.config.model_id}-v{step + 1}"
150-
return self._current_model.reinforcement_step(
151-
dataset=dataset,
152-
output_model=model_name,
153-
epochs=self.config.epochs,
154-
learning_rate=self.config.learning_rate,
155-
accelerator_count=self.config.accelerator_count,
156-
accelerator_type=self.config.accelerator_type,
157-
)
168+
try:
169+
model_name = f"{self.config.model_id}-v{step + 1}"
170+
return self._current_model.reinforcement_step(
171+
dataset=dataset,
172+
output_model=model_name,
173+
epochs=self.config.epochs,
174+
learning_rate=self.config.learning_rate,
175+
accelerator_count=self.config.accelerator_count,
176+
accelerator_type=self.config.accelerator_type,
177+
)
178+
except Exception as e:
179+
raise JudgmentAPIError(
180+
f"Failed to start reinforcement learning step {step + 1}: {str(e)}"
181+
) from e
158182

159183
def get_model_config(
160184
self, training_params: Optional[Dict[str, Any]] = None

src/judgeval/common/trainer/trainer.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from judgeval.scorers import BaseScorer, APIScorerConfig
1010
from judgeval.data import Example
1111
from .console import _spinner_progress, _print_progress, _print_progress_update
12+
from judgeval.common.exceptions import JudgmentAPIError
1213

1314

1415
class JudgmentTrainer:
@@ -35,17 +36,22 @@ def __init__(
3536
trainable_model: Optional trainable model instance
3637
project_name: Project name for organizing training runs and evaluations
3738
"""
38-
self.config = config
39-
self.tracer = tracer
40-
self.tracer.show_trace_urls = False
41-
self.project_name = project_name or "judgment_training"
42-
43-
if trainable_model is None:
44-
self.trainable_model = TrainableModel(self.config)
45-
else:
46-
self.trainable_model = trainable_model
47-
48-
self.judgment_client = JudgmentClient()
39+
try:
40+
self.config = config
41+
self.tracer = tracer
42+
self.tracer.show_trace_urls = False
43+
self.project_name = project_name or "judgment_training"
44+
45+
if trainable_model is None:
46+
self.trainable_model = TrainableModel(self.config)
47+
else:
48+
self.trainable_model = trainable_model
49+
50+
self.judgment_client = JudgmentClient()
51+
except Exception as e:
52+
raise JudgmentAPIError(
53+
f"Failed to initialize JudgmentTrainer: {str(e)}"
54+
) from e
4955

5056
async def generate_rollouts_and_rewards(
5157
self,
@@ -97,7 +103,9 @@ async def generate_single_response(prompt_id, generation_id):
97103
pass
98104

99105
example = Example(
100-
input=prompt_input, messages=messages, actual_output=response_data
106+
input=prompt_input,
107+
messages=messages,
108+
actual_output=response_data,
101109
)
102110

103111
scoring_results = self.judgment_client.run_evaluation(
@@ -238,7 +246,9 @@ async def run_reinforcement_learning(
238246
time.sleep(10)
239247
job = job.get()
240248
if job is None:
241-
raise Exception("Job was deleted while waiting for completion")
249+
raise JudgmentAPIError(
250+
"Training job was deleted while waiting for completion"
251+
)
242252

243253
_print_progress(
244254
f"Training completed! New model: {job.output_model}",
@@ -277,7 +287,15 @@ async def train(
277287
Returns:
278288
ModelConfig: Configuration of the trained model for future loading
279289
"""
280-
if rft_provider is not None:
281-
self.config.rft_provider = rft_provider
290+
try:
291+
if rft_provider is not None:
292+
self.config.rft_provider = rft_provider
282293

283-
return await self.run_reinforcement_learning(agent_function, scorers, prompts)
294+
return await self.run_reinforcement_learning(
295+
agent_function, scorers, prompts
296+
)
297+
except JudgmentAPIError:
298+
# Re-raise JudgmentAPIError as-is
299+
raise
300+
except Exception as e:
301+
raise JudgmentAPIError(f"Training process failed: {str(e)}") from e

0 commit comments

Comments
 (0)