4
4
import time
5
5
import copy
6
6
import shutil
7
- import typing as T
8
7
import pickle as pkl
9
8
import tensorflow as tf
10
9
import subprocess as sp
@@ -38,7 +37,7 @@ def __init__(self,
38
37
39
38
def on_epoch_end (self , batch , logs = {}):
40
39
self .epoch_count += 1
41
- if self .epoch_count % self .frequency_epoch != 0 :
40
+ if self .epoch_count % self .frequency_epoch != 0 or self . epoch_count == 0 :
42
41
return
43
42
44
43
try :
@@ -50,13 +49,6 @@ def on_epoch_end(self, batch, logs={}):
50
49
except Exception as e :
51
50
print ('There was an error sending the notification:' , e )
52
51
53
- class ParamsNotifier (T .TypedDict ):
54
- title : T .Optional [str ]
55
- email : T .Optional [str ]
56
- chat_id : T .Optional [str ]
57
- api_token : T .Optional [str ]
58
- webhook_url : T .Optional [str ]
59
-
60
52
def NoImplementedError (message : str ):
61
53
raise NotImplementedError (message )
62
54
@@ -66,11 +58,11 @@ class IAFlow(object):
66
58
def __init__ (
67
59
self ,
68
60
models_folder : str ,
69
- builder_function : T . Callable = lambda ** kwargs : NoImplementedError ('Builder function not implemented' ),
70
- callbacks : T . Union [ T . List [ T . Any ], T . Any ] = [],
71
- checkpoint_params : T . Dict = {},
72
- tensorboard_params : T . Dict = {},
73
- params_notifier : ParamsNotifier = None ,
61
+ builder_function = lambda ** kwargs : NoImplementedError ('Builder function not implemented' ),
62
+ callbacks = [],
63
+ checkpoint_params = {},
64
+ tensorboard_params = {},
65
+ params_notifier = None ,
74
66
):
75
67
self .models = {}
76
68
self .datasets = {}
@@ -110,7 +102,7 @@ def __find_endwith(self, path: str, endwith: str):
110
102
return os .path .join (path , filename )
111
103
return None
112
104
113
- def __get_params_models (self , load_model : bool , path_model : str , model_params : T . Dict ):
105
+ def __get_params_models (self , load_model : bool , path_model : str , model_params ):
114
106
if not load_model :
115
107
return model_params
116
108
@@ -125,7 +117,7 @@ def __get_params_models(self, load_model: bool, path_model: str, model_params: T
125
117
126
118
return model_params
127
119
128
- def __create_file (self , path : str , content : T . Any , mode : str = 'w' , is_json : bool = False ):
120
+ def __create_file (self , path : str , content , mode : str = 'w' , is_json : bool = False ):
129
121
if not os .path .exists (path ):
130
122
with open (path , mode ) as file :
131
123
if is_json :
@@ -136,7 +128,7 @@ def __create_file(self, path: str, content: T.Any, mode: str = 'w', is_json: boo
136
128
def __get_config (self ):
137
129
pass
138
130
139
- def set_builder_function (self , builder_function : T . Callable ):
131
+ def set_builder_function (self , builder_function ):
140
132
self .builder_function = builder_function
141
133
142
134
def set_notifier_parameters (self , params : ParamsNotifier ):
@@ -189,11 +181,11 @@ def add_dataset(
189
181
self ,
190
182
name : str ,
191
183
epochs : int ,
192
- train_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] ,
184
+ train_ds ,
193
185
batch_size : int = None ,
194
186
shuffle_buffer : int = None ,
195
- val_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None ,
196
- test_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None
187
+ val_ds = None ,
188
+ test_ds = None
197
189
):
198
190
if name in self .datasets :
199
191
print (f'Dataset { name } already exists' )
@@ -217,9 +209,9 @@ def update_dataset(
217
209
epochs : int = None ,
218
210
batch_size : int = None ,
219
211
shuffle_buffer : int = None ,
220
- train_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None ,
221
- val_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None ,
222
- test_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None
212
+ train_ds = None ,
213
+ val_ds = None ,
214
+ test_ds = None
223
215
):
224
216
if name not in self .datasets :
225
217
print (f'Dataset { name } not found' )
@@ -250,10 +242,10 @@ def add_model(
250
242
self ,
251
243
model_name : str ,
252
244
run_id : str = None ,
253
- model_params : T . Dict = {},
254
- compile_params : T . Dict = {},
255
- load_model_params : T . Union [ T . Dict , None ] = {}
256
- ) -> T . Tuple [ tf . keras . Model , str , str , str ] :
245
+ model_params = {},
246
+ compile_params = {},
247
+ load_model_params = {}
248
+ ):
257
249
258
250
models_folder = self .models_folder
259
251
model_params_str = '_' .join (map (str , model_params .values ()))
@@ -299,10 +291,10 @@ def update_model(
299
291
self ,
300
292
model_name : str ,
301
293
run_id : str ,
302
- model_params : T . Dict = {},
303
- compile_params : T . Dict = {},
304
- load_model_params : T . Union [ T . Dict , None ] = {}
305
- ) -> T . Tuple [ tf . keras . Model , str , str , str ] :
294
+ model_params = {},
295
+ compile_params = {},
296
+ load_model_params = {}
297
+ ):
306
298
307
299
models_folder = self .models_folder
308
300
model_params_str = '_' .join (map (str , model_params .values ()))
@@ -326,7 +318,7 @@ def update_model(
326
318
print (f'Model { model_name } /{ run_id } was updated' )
327
319
return model_data [run_id ]
328
320
329
- def delete_model (self , model_data : T . Dict , delete_folder : bool = False ):
321
+ def delete_model (self , model_data , delete_folder : bool = False ):
330
322
model_name = model_data .get ('model_name' )
331
323
run_id = model_data .get ('run_id' )
332
324
@@ -350,15 +342,15 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
350
342
351
343
def train (
352
344
self ,
353
- model_data : T . Dict ,
345
+ model_data ,
354
346
dataset_name : str ,
355
347
batch_size : int = None ,
356
348
initial_epoch : int = 0 ,
357
349
shuffle_buffer : int = None ,
358
350
force_creation : bool = False ,
359
- epochs : T . Union [ int , None ] = None ,
360
- train_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None ,
361
- val_ds : T . Union [ tf . data . Dataset , T . List [ T . Any ], T . Any ] = None ,
351
+ epochs = None ,
352
+ train_ds = None ,
353
+ val_ds = None ,
362
354
):
363
355
epochs = epochs or self .datasets .get (dataset_name , {}).get ('epochs' , 100 )
364
356
batch_size = batch_size or self .datasets .get (dataset_name , {}).get ('batch_size' , None )
0 commit comments