@@ -158,9 +158,11 @@ def get_model(self, model_name: str, run_id: str, force_creation: bool = True):
158
158
check_path = run_id_data .get ('check_path' )
159
159
if os .path .exists (check_path ) and not force_creation :
160
160
print (f'Loading model from { check_path } ' )
161
- filepath = run_id_data .get ('filepath' ).replace ('/' , os .path .sep )
162
- run_id_data .pop ('filepath' )
163
- model = tf .keras .models .load_model (filepath , ** run_id_data .get ('load_model_params' ))
161
+ params = {
162
+ ** run_id_data .get ('load_model_params' , {}),
163
+ 'filepath' : check_path ,
164
+ }
165
+ model = tf .keras .models .load_model (** params )
164
166
return model , run_id_data
165
167
elif force_creation :
166
168
print (f'Force creation is { force_creation } . Deleting old logs and model' )
@@ -195,7 +197,6 @@ def add_dataset(
195
197
):
196
198
if name in self .datasets :
197
199
print (f'Dataset { name } already exists' )
198
- return False
199
200
200
201
self .datasets [name ] = {
201
202
'train_ds' : train_ds ,
@@ -209,7 +210,6 @@ def add_dataset(
209
210
self .datasets [name ]['test_ds' ] = test_ds
210
211
211
212
print (f'Dataset { name } was added' )
212
- return True
213
213
214
214
def update_dataset (
215
215
self ,
@@ -223,7 +223,6 @@ def update_dataset(
223
223
):
224
224
if name not in self .datasets :
225
225
print (f'Dataset { name } not found' )
226
- return False
227
226
228
227
if batch_size is not None :
229
228
self .datasets [name ]['batch_size' ] = batch_size
@@ -239,16 +238,13 @@ def update_dataset(
239
238
self .datasets [name ]['test_ds' ] = test_ds
240
239
241
240
print (f'Dataset { name } was updated' )
242
- return True
243
241
244
242
def delete_dataset (self , name : str ):
245
243
if name not in self .datasets :
246
244
print (f'Dataset { name } not found' )
247
- return False
248
245
249
246
del self .datasets [name ]
250
247
print (f'Dataset { name } was deleted' )
251
- return True
252
248
253
249
def add_model (
254
250
self ,
@@ -263,7 +259,7 @@ def add_model(
263
259
model_params_str = '_' .join (map (str , model_params .values ()))
264
260
model_ident = '_' .join ([ model_name , model_params_str ])
265
261
266
- runs_model = self .models .get (model_name )
262
+ runs_model = self .models .get (model_name , {} )
267
263
if run_id is not None and run_id in runs_model :
268
264
raise ValueError (f'Model { model_name } /{ run_id } already exists' )
269
265
@@ -337,12 +333,10 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
337
333
model = self .models .get (model_name )
338
334
if model is None :
339
335
print (f'Model { model_name } not found' )
340
- return False
341
336
342
337
run_id_data = model .get (run_id )
343
338
if run_id_data is None :
344
339
print (f'Run { run_id } not found' )
345
- return False
346
340
347
341
if delete_folder :
348
342
self .__delete_by_path (run_id_data .get ('path_model' ), is_dir = True )
@@ -353,13 +347,12 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
353
347
354
348
self .save ()
355
349
print (f'Model { model_name } /{ run_id } was deleted' )
356
- return True
357
350
358
351
def train (
359
352
self ,
360
353
model_data : T .Dict ,
361
354
dataset_name : str ,
362
- batch_size : int = 32 ,
355
+ batch_size : int = None ,
363
356
initial_epoch : int = 0 ,
364
357
shuffle_buffer : int = None ,
365
358
force_creation : bool = False ,
@@ -381,6 +374,12 @@ def train(
381
374
run_id = model_data .get ('run_id' )
382
375
383
376
self .clear_session ()
377
+ print (f'Training { model_name } /{ run_id } ...' )
378
+ print (f'Epochs: { epochs } ' )
379
+ print (f'Batch size: { batch_size } ' )
380
+ print (f'Shuffle buffer: { shuffle_buffer } ' )
381
+ print (f'Start time: { time .strftime ("%Y-%m-%d %H:%M:%S" )} ' )
382
+
384
383
model , run_data = self .get_model (model_name , run_id , force_creation )
385
384
386
385
if self .notifier :
@@ -403,11 +402,6 @@ def train(
403
402
val_ds = val_ds .shuffle (shuffle_buffer )
404
403
405
404
start_time = time .time ()
406
- print (f'Training { model_name } /{ run_id } ...' )
407
- print (f'Epochs: { epochs } ' )
408
- print (f'Batch size: { batch_size } ' )
409
- print (f'Shuffle buffer: { shuffle_buffer } ' )
410
- print (f'Start time: { time .strftime ("%Y-%m-%d %H:%M:%S" )} ' )
411
405
412
406
model .fit (
413
407
train_ds ,
0 commit comments