@@ -83,6 +83,14 @@ def __init__(
83
83
84
84
self .notifier = NotifierCallback (** params_notifier ) if params_notifier else None
85
85
86
+ try :
87
+ gpu_info = tf .config .experimental .list_physical_devices ('GPU' )
88
+ except :
89
+ gpu_info = ''
90
+
91
+ info_env = f'GPU connected: { gpu_info } ' if len (gpu_info ) > 0 else 'Not connected to a GPU'
92
+ print (f'GPU info: { info_env } ' )
93
+
86
94
def __delete_by_path (self , path : str , is_dir : bool = False ):
87
95
if is_dir :
88
96
shutil .rmtree (path , ignore_errors = True )
@@ -150,12 +158,15 @@ def get_model(self, model_name: str, run_id: str, force_creation: bool = True):
150
158
check_path = run_id_data .get ('check_path' )
151
159
if os .path .exists (check_path ) and not force_creation :
152
160
print (f'Loading model from { check_path } ' )
153
- model = tf .keras .models .load_model (** run_id_data .get ('load_model_params' ))
154
- else :
161
+ filepath = run_id_data .get ('filepath' ).replace ('/' , os .path .sep )
162
+ run_id_data .pop ('filepath' )
163
+ model = tf .keras .models .load_model (filepath , ** run_id_data .get ('load_model_params' ))
164
+ return model , run_id_data
165
+ elif force_creation :
155
166
print (f'Force creation is { force_creation } . Deleting old logs and model' )
156
167
self .__delete_by_path (run_id_data .get ('log_dir' ), is_dir = True )
157
168
self .__delete_by_path (run_id_data .get ('check_path' ), is_dir = False )
158
-
169
+
159
170
print ('Creating model' )
160
171
model = self .builder_function (** run_id_data .get ('model_params' ))
161
172
model .compile (** run_id_data .get ('compile_params' ))
@@ -197,7 +208,7 @@ def add_dataset(
197
208
if test_ds is not None :
198
209
self .datasets [name ]['test_ds' ] = test_ds
199
210
200
- print (f'Dataset { name } added' )
211
+ print (f'Dataset { name } was added' )
201
212
return True
202
213
203
214
def update_dataset (
@@ -227,7 +238,7 @@ def update_dataset(
227
238
if test_ds is not None :
228
239
self .datasets [name ]['test_ds' ] = test_ds
229
240
230
- print (f'Dataset { name } updated' )
241
+ print (f'Dataset { name } was updated' )
231
242
return True
232
243
233
244
def delete_dataset (self , name : str ):
@@ -236,6 +247,7 @@ def delete_dataset(self, name: str):
236
247
return False
237
248
238
249
del self .datasets [name ]
250
+ print (f'Dataset { name } was deleted' )
239
251
return True
240
252
241
253
def add_model (
@@ -267,16 +279,8 @@ def add_model(
267
279
path_params = f'{ path_model } /{ model_name } _params.json'
268
280
self .__create_file (path_params , model_params , is_json = True )
269
281
270
- try :
271
- gpu_info = tf .config .experimental .list_physical_devices ('GPU' )
272
- except :
273
- gpu_info = ''
274
-
275
- info_env = gpu_info [0 ].name if len (gpu_info ) > 0 else 'Not connected to a GPU'
276
- print (f'GPU info: { info_env } \n \n ' )
277
-
278
282
model_data = self .models .get (model_name , {})
279
-
283
+
280
284
load_model_params ['filepath' ] = check_path
281
285
model_data [run_id ] = {
282
286
'run_id' : run_id ,
@@ -292,7 +296,7 @@ def add_model(
292
296
293
297
self .models [model_name ] = model_data
294
298
self .save ()
295
- print (f'Model { model_name } /{ run_id } added' )
299
+ print (f'Model { model_name } /{ run_id } was added' )
296
300
return model_data [run_id ]
297
301
298
302
def update_model (
@@ -323,7 +327,7 @@ def update_model(
323
327
324
328
model_data [run_id ] = model
325
329
self .models [model_name ] = model_data
326
- print (f'Model { model_name } /{ run_id } updated' )
330
+ print (f'Model { model_name } /{ run_id } was updated' )
327
331
return model_data [run_id ]
328
332
329
333
def delete_model (self , model_data : T .Dict , delete_folder : bool = False ):
@@ -348,18 +352,18 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
348
352
del self .models [model_name ]
349
353
350
354
self .save ()
351
- print (f'Model { model_name } /{ run_id } deleted' )
355
+ print (f'Model { model_name } /{ run_id } was deleted' )
352
356
return True
353
357
354
358
def train (
355
359
self ,
356
360
model_data : T .Dict ,
357
361
dataset_name : str ,
358
- epochs : int = 100 ,
359
362
batch_size : int = 32 ,
360
363
initial_epoch : int = 0 ,
361
364
shuffle_buffer : int = None ,
362
365
force_creation : bool = False ,
366
+ epochs : T .Union [int , None ] = None ,
363
367
train_ds : T .Union [tf .data .Dataset , T .List [T .Any ], T .Any ] = None ,
364
368
val_ds : T .Union [tf .data .Dataset , T .List [T .Any ], T .Any ] = None ,
365
369
):
@@ -399,6 +403,12 @@ def train(
399
403
val_ds = val_ds .shuffle (shuffle_buffer )
400
404
401
405
start_time = time .time ()
406
+ print (f'Training { model_name } /{ run_id } ...' )
407
+ print (f'Epochs: { epochs } ' )
408
+ print (f'Batch size: { batch_size } ' )
409
+ print (f'Shuffle buffer: { shuffle_buffer } ' )
410
+ print (f'Start time: { time .strftime ("%Y-%m-%d %H:%M:%S" )} ' )
411
+
402
412
model .fit (
403
413
train_ds ,
404
414
epochs = epochs ,
0 commit comments