Skip to content

Commit acb968f

Browse files
committed
Remove typing
1 parent bf46b89 commit acb968f

File tree

2 files changed

+29
-37
lines changed

2 files changed

+29
-37
lines changed

iaflow/__init__.py

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import time
55
import copy
66
import shutil
7-
import typing as T
87
import pickle as pkl
98
import tensorflow as tf
109
import subprocess as sp
@@ -38,7 +37,7 @@ def __init__(self,
3837

3938
def on_epoch_end(self, batch, logs={}):
4039
self.epoch_count += 1
41-
if self.epoch_count % self.frequency_epoch != 0:
40+
if self.epoch_count % self.frequency_epoch != 0 or self.epoch_count == 0:
4241
return
4342

4443
try:
@@ -50,13 +49,6 @@ def on_epoch_end(self, batch, logs={}):
5049
except Exception as e:
5150
print('There was an error sending the notification:', e)
5251

53-
class ParamsNotifier(T.TypedDict):
54-
title: T.Optional[str]
55-
email: T.Optional[str]
56-
chat_id: T.Optional[str]
57-
api_token: T.Optional[str]
58-
webhook_url: T.Optional[str]
59-
6052
def NoImplementedError(message: str):
6153
raise NotImplementedError(message)
6254

@@ -66,11 +58,11 @@ class IAFlow(object):
6658
def __init__(
6759
self,
6860
models_folder: str,
69-
builder_function: T.Callable = lambda **kwargs: NoImplementedError('Builder function not implemented'),
70-
callbacks: T.Union[T.List[T.Any], T.Any] = [],
71-
checkpoint_params: T.Dict = {},
72-
tensorboard_params: T.Dict = {},
73-
params_notifier: ParamsNotifier = None,
61+
builder_function = lambda **kwargs: NoImplementedError('Builder function not implemented'),
62+
callbacks = [],
63+
checkpoint_params = {},
64+
tensorboard_params = {},
65+
params_notifier = None,
7466
):
7567
self.models = {}
7668
self.datasets = {}
@@ -110,7 +102,7 @@ def __find_endwith(self, path: str, endwith: str):
110102
return os.path.join(path, filename)
111103
return None
112104

113-
def __get_params_models(self, load_model: bool, path_model: str, model_params: T.Dict):
105+
def __get_params_models(self, load_model: bool, path_model: str, model_params):
114106
if not load_model:
115107
return model_params
116108

@@ -125,7 +117,7 @@ def __get_params_models(self, load_model: bool, path_model: str, model_params: T
125117

126118
return model_params
127119

128-
def __create_file(self, path: str, content: T.Any, mode: str = 'w', is_json: bool = False):
120+
def __create_file(self, path: str, content, mode: str = 'w', is_json: bool = False):
129121
if not os.path.exists(path):
130122
with open(path, mode) as file:
131123
if is_json:
@@ -136,7 +128,7 @@ def __create_file(self, path: str, content: T.Any, mode: str = 'w', is_json: boo
136128
def __get_config(self):
137129
pass
138130

139-
def set_builder_function(self, builder_function: T.Callable):
131+
def set_builder_function(self, builder_function):
140132
self.builder_function = builder_function
141133

142134
def set_notifier_parameters(self, params: ParamsNotifier):
@@ -189,11 +181,11 @@ def add_dataset(
189181
self,
190182
name: str,
191183
epochs: int,
192-
train_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any],
184+
train_ds,
193185
batch_size: int = None,
194186
shuffle_buffer: int = None,
195-
val_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
196-
test_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None
187+
val_ds = None,
188+
test_ds = None
197189
):
198190
if name in self.datasets:
199191
print(f'Dataset {name} already exists')
@@ -217,9 +209,9 @@ def update_dataset(
217209
epochs: int = None,
218210
batch_size: int = None,
219211
shuffle_buffer: int = None,
220-
train_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
221-
val_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
222-
test_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None
212+
train_ds = None,
213+
val_ds = None,
214+
test_ds = None
223215
):
224216
if name not in self.datasets:
225217
print(f'Dataset {name} not found')
@@ -250,10 +242,10 @@ def add_model(
250242
self,
251243
model_name: str,
252244
run_id: str = None,
253-
model_params: T.Dict = {},
254-
compile_params: T.Dict = {},
255-
load_model_params: T.Union[T.Dict, None] = {}
256-
) -> T.Tuple[tf.keras.Model, str, str, str]:
245+
model_params = {},
246+
compile_params = {},
247+
load_model_params = {}
248+
):
257249

258250
models_folder = self.models_folder
259251
model_params_str = '_'.join(map(str, model_params.values()))
@@ -299,10 +291,10 @@ def update_model(
299291
self,
300292
model_name: str,
301293
run_id: str,
302-
model_params: T.Dict = {},
303-
compile_params: T.Dict = {},
304-
load_model_params: T.Union[T.Dict, None] = {}
305-
) -> T.Tuple[tf.keras.Model, str, str, str]:
294+
model_params = {},
295+
compile_params = {},
296+
load_model_params = {}
297+
):
306298

307299
models_folder = self.models_folder
308300
model_params_str = '_'.join(map(str, model_params.values()))
@@ -326,7 +318,7 @@ def update_model(
326318
print(f'Model {model_name}/{run_id} was updated')
327319
return model_data[run_id]
328320

329-
def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
321+
def delete_model(self, model_data, delete_folder: bool = False):
330322
model_name = model_data.get('model_name')
331323
run_id = model_data.get('run_id')
332324

@@ -350,15 +342,15 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
350342

351343
def train(
352344
self,
353-
model_data: T.Dict,
345+
model_data,
354346
dataset_name: str,
355347
batch_size: int = None,
356348
initial_epoch: int = 0,
357349
shuffle_buffer: int = None,
358350
force_creation: bool = False,
359-
epochs: T.Union[int, None] = None,
360-
train_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
361-
val_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
351+
epochs = None,
352+
train_ds = None,
353+
val_ds = None,
362354
):
363355
epochs = epochs or self.datasets.get(dataset_name, {}).get('epochs', 100)
364356
batch_size = batch_size or self.datasets.get(dataset_name, {}).get('batch_size', None)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setuptools.setup(
88
name="iaflow",
9-
version='2.1.4',
9+
version='2.1.5',
1010
author="Enmanuel Magallanes Pinargote",
1111
author_email="enmanuelmag@cardor.dev",
1212
description="This library help to create models with identifiers, checkpoints, logs and metadata automatically, in order to make the training process more efficient and traceable.",

0 commit comments

Comments
 (0)