2
2
from .config import TrainerConfig , ModelConfig
3
3
from typing import Optional , Dict , Any , Callable
4
4
from .console import _model_spinner_progress , _print_model_progress
5
+ from judgeval .common .exceptions import JudgmentAPIError
5
6
6
7
7
8
class TrainableModel :
@@ -20,13 +21,18 @@ def __init__(self, config: TrainerConfig):
20
21
Args:
21
22
config: TrainerConfig instance with model configuration
22
23
"""
23
- self .config = config
24
- self .current_step = 0
25
- self ._current_model = None
26
- self ._tracer_wrapper_func = None
24
+ try :
25
+ self .config = config
26
+ self .current_step = 0
27
+ self ._current_model = None
28
+ self ._tracer_wrapper_func = None
27
29
28
- self ._base_model = self ._create_base_model ()
29
- self ._current_model = self ._base_model
30
+ self ._base_model = self ._create_base_model ()
31
+ self ._current_model = self ._base_model
32
+ except Exception as e :
33
+ raise JudgmentAPIError (
34
+ f"Failed to initialize TrainableModel: { str (e )} "
35
+ ) from e
30
36
31
37
@classmethod
32
38
def from_model_config (cls , model_config : ModelConfig ) -> "TrainableModel" :
@@ -58,38 +64,48 @@ def from_model_config(cls, model_config: ModelConfig) -> "TrainableModel":
58
64
59
65
def _create_base_model(self):
    """Create and deploy the base model described by ``self.config``.

    Builds an ``LLM`` with the "on-demand" deployment type from the
    configured base model name, deployment id and add-on flag, then calls
    ``apply()`` on it while reporting progress through the spinner helper.

    Returns:
        The ``LLM`` instance after ``apply()`` has been called.

    Raises:
        JudgmentAPIError: If any step of creation or deployment fails; the
            original exception is chained as ``__cause__``.
    """
    try:
        with _model_spinner_progress(
            "Creating and deploying base model..."
        ) as update_progress:
            update_progress("Creating base model instance...")
            base_model = LLM(
                model=self.config.base_model_name,
                deployment_type="on-demand",
                id=self.config.deployment_id,
                enable_addons=self.config.enable_addons,
            )
            update_progress("Applying deployment configuration...")
            base_model.apply()
            _print_model_progress("Base model deployment ready")
            return base_model
    except Exception as e:
        # Surface every failure as the package's API error so callers have
        # one exception type to handle; keep the root cause chained.
        raise JudgmentAPIError(
            f"Failed to create and deploy base model '{self.config.base_model_name}': {e}"
        ) from e
75
86
76
87
def _load_trained_model(self, model_name: str):
    """Load a trained model by name and make it the current model.

    Builds an ``LLM`` with the "on-demand-lora" deployment type on top of
    the configured base deployment and calls ``apply()`` on it. The new
    model replaces ``self._current_model`` only after ``apply()`` returns,
    so a failed deployment leaves the previous model in place. If a tracer
    wrapper function is registered it is then invoked on the new model.

    Args:
        model_name: Name of the trained model to load.

    Raises:
        JudgmentAPIError: If creation, deployment or tracer wrapping fails;
            the original exception is chained as ``__cause__``.
    """
    try:
        with _model_spinner_progress(
            f"Loading and deploying trained model: {model_name}"
        ) as update_progress:
            update_progress("Creating trained model instance...")
            trained_model = LLM(
                model=model_name,
                deployment_type="on-demand-lora",
                base_id=self.config.deployment_id,
            )
            update_progress("Applying deployment configuration...")
            trained_model.apply()
            _print_model_progress("Trained model deployment ready")

        # Commit only once deployment succeeded, so an error above does not
        # leave the instance pointing at a model that was never applied.
        self._current_model = trained_model

        if self._tracer_wrapper_func:
            self._tracer_wrapper_func(self._current_model)
    except Exception as e:
        raise JudgmentAPIError(
            f"Failed to load and deploy trained model '{model_name}': {e}"
        ) from e
93
109
94
110
def get_current_model(self):
    """Return the model object currently stored in ``self._current_model``."""
    return self._current_model
@@ -111,29 +127,32 @@ def advance_to_next_step(self, step: int):
111
127
Args:
112
128
step: The current training step number
113
129
"""
114
- self .current_step = step
115
-
116
- if step == 0 :
117
- self ._current_model = self ._base_model
118
- else :
119
- model_name = (
120
- f"accounts/{ self .config .user_id } /models/{ self .config .model_id } -v{ step } "
121
- )
122
- with _model_spinner_progress (
123
- f"Creating and deploying model snapshot: { model_name } "
124
- ) as update_progress :
125
- update_progress ("Creating model snapshot instance..." )
126
- self ._current_model = LLM (
127
- model = model_name ,
128
- deployment_type = "on-demand-lora" ,
129
- base_id = self .config .deployment_id ,
130
- )
131
- update_progress ("Applying deployment configuration..." )
132
- self ._current_model .apply ()
133
- _print_model_progress ("Model snapshot deployment ready" )
134
-
135
- if self ._tracer_wrapper_func :
136
- self ._tracer_wrapper_func (self ._current_model )
130
+ try :
131
+ self .current_step = step
132
+
133
+ if step == 0 :
134
+ self ._current_model = self ._base_model
135
+ else :
136
+ model_name = f"accounts/{ self .config .user_id } /models/{ self .config .model_id } -v{ step } "
137
+ with _model_spinner_progress (
138
+ f"Creating and deploying model snapshot: { model_name } "
139
+ ) as update_progress :
140
+ update_progress ("Creating model snapshot instance..." )
141
+ self ._current_model = LLM (
142
+ model = model_name ,
143
+ deployment_type = "on-demand-lora" ,
144
+ base_id = self .config .deployment_id ,
145
+ )
146
+ update_progress ("Applying deployment configuration..." )
147
+ self ._current_model .apply ()
148
+ _print_model_progress ("Model snapshot deployment ready" )
149
+
150
+ if self ._tracer_wrapper_func :
151
+ self ._tracer_wrapper_func (self ._current_model )
152
+ except Exception as e :
153
+ raise JudgmentAPIError (
154
+ f"Failed to advance to training step { step } : { str (e )} "
155
+ ) from e
137
156
138
157
def perform_reinforcement_step(self, dataset, step: int):
    """Start a reinforcement step on the current model.

    The output model is named ``"<model_id>-v<step + 1>"`` and the epochs,
    learning rate and accelerator settings are taken from ``self.config``.

    Args:
        dataset: Training dataset, passed through unchanged to the current
            model's ``reinforcement_step``.
        step: The current training step number; the output model is
            versioned as ``step + 1``.

    Returns:
        Training job object, i.e. whatever the current model's
        ``reinforcement_step`` returns.

    Raises:
        JudgmentAPIError: If starting the step fails; the original exception
            is chained as ``__cause__``.
    """
    try:
        model_name = f"{self.config.model_id}-v{step + 1}"
        return self._current_model.reinforcement_step(
            dataset=dataset,
            output_model=model_name,
            epochs=self.config.epochs,
            learning_rate=self.config.learning_rate,
            accelerator_count=self.config.accelerator_count,
            accelerator_type=self.config.accelerator_type,
        )
    except Exception as e:
        raise JudgmentAPIError(
            f"Failed to start reinforcement learning step {step + 1}: {e}"
        ) from e
158
182
159
183
def get_model_config (
160
184
self , training_params : Optional [Dict [str , Any ]] = None
0 commit comments