Skip to content

Commit aa82bfa

Browse files
committed
Hot fixes
1 parent b93d1f3 commit aa82bfa

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

iaflow/__init__.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def __init__(
8383

8484
self.notifier = NotifierCallback(**params_notifier) if params_notifier else None
8585

86+
try:
87+
gpu_info = tf.config.experimental.list_physical_devices('GPU')
88+
except:
89+
gpu_info = ''
90+
91+
info_env = f'GPU connected: {gpu_info}' if len(gpu_info) > 0 else 'Not connected to a GPU'
92+
print(f'GPU info: {info_env}')
93+
8694
def __delete_by_path(self, path: str, is_dir: bool = False):
8795
if is_dir:
8896
shutil.rmtree(path, ignore_errors=True)
@@ -150,12 +158,15 @@ def get_model(self, model_name: str, run_id: str, force_creation: bool = True):
150158
check_path = run_id_data.get('check_path')
151159
if os.path.exists(check_path) and not force_creation:
152160
print(f'Loading model from {check_path}')
153-
model = tf.keras.models.load_model(**run_id_data.get('load_model_params'))
154-
else:
161+
filepath = run_id_data.get('filepath').replace('/', os.path.sep)
162+
run_id_data.pop('filepath')
163+
model = tf.keras.models.load_model(filepath, **run_id_data.get('load_model_params'))
164+
return model, run_id_data
165+
elif force_creation:
155166
print(f'Force creation is {force_creation}. Deleting old logs and model')
156167
self.__delete_by_path(run_id_data.get('log_dir'), is_dir=True)
157168
self.__delete_by_path(run_id_data.get('check_path'), is_dir=False)
158-
169+
159170
print('Creating model')
160171
model = self.builder_function(**run_id_data.get('model_params'))
161172
model.compile(**run_id_data.get('compile_params'))
@@ -197,7 +208,7 @@ def add_dataset(
197208
if test_ds is not None:
198209
self.datasets[name]['test_ds'] = test_ds
199210

200-
print(f'Dataset {name} added')
211+
print(f'Dataset {name} was added')
201212
return True
202213

203214
def update_dataset(
@@ -227,7 +238,7 @@ def update_dataset(
227238
if test_ds is not None:
228239
self.datasets[name]['test_ds'] = test_ds
229240

230-
print(f'Dataset {name} updated')
241+
print(f'Dataset {name} was updated')
231242
return True
232243

233244
def delete_dataset(self, name: str):
@@ -236,6 +247,7 @@ def delete_dataset(self, name: str):
236247
return False
237248

238249
del self.datasets[name]
250+
print(f'Dataset {name} was deleted')
239251
return True
240252

241253
def add_model(
@@ -267,16 +279,8 @@ def add_model(
267279
path_params = f'{path_model}/{model_name}_params.json'
268280
self.__create_file(path_params, model_params, is_json=True)
269281

270-
try:
271-
gpu_info = tf.config.experimental.list_physical_devices('GPU')
272-
except:
273-
gpu_info = ''
274-
275-
info_env = gpu_info[0].name if len(gpu_info) > 0 else 'Not connected to a GPU'
276-
print(f'GPU info: {info_env}\n\n')
277-
278282
model_data = self.models.get(model_name, {})
279-
283+
280284
load_model_params['filepath'] = check_path
281285
model_data[run_id] = {
282286
'run_id': run_id,
@@ -292,7 +296,7 @@ def add_model(
292296

293297
self.models[model_name] = model_data
294298
self.save()
295-
print(f'Model {model_name}/{run_id} added')
299+
print(f'Model {model_name}/{run_id} was added')
296300
return model_data[run_id]
297301

298302
def update_model(
@@ -323,7 +327,7 @@ def update_model(
323327

324328
model_data[run_id] = model
325329
self.models[model_name] = model_data
326-
print(f'Model {model_name}/{run_id} updated')
330+
print(f'Model {model_name}/{run_id} was updated')
327331
return model_data[run_id]
328332

329333
def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
@@ -348,18 +352,18 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
348352
del self.models[model_name]
349353

350354
self.save()
351-
print(f'Model {model_name}/{run_id} deleted')
355+
print(f'Model {model_name}/{run_id} was deleted')
352356
return True
353357

354358
def train(
355359
self,
356360
model_data: T.Dict,
357361
dataset_name: str,
358-
epochs: int = 100,
359362
batch_size: int = 32,
360363
initial_epoch: int = 0,
361364
shuffle_buffer: int = None,
362365
force_creation: bool = False,
366+
epochs: T.Union[int, None] = None,
363367
train_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
364368
val_ds: T.Union[tf.data.Dataset, T.List[T.Any], T.Any] = None,
365369
):
@@ -399,6 +403,12 @@ def train(
399403
val_ds = val_ds.shuffle(shuffle_buffer)
400404

401405
start_time = time.time()
406+
print(f'Training {model_name}/{run_id}...')
407+
print(f'Epochs: {epochs}')
408+
print(f'Batch size: {batch_size}')
409+
print(f'Shuffle buffer: {shuffle_buffer}')
410+
print(f'Start time: {time.strftime("%Y-%m-%d %H:%M:%S")}')
411+
402412
model.fit(
403413
train_ds,
404414
epochs=epochs,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setuptools.setup(
88
name="iaflow",
9-
version='2.1.2',
9+
version='2.1.3',
1010
author="Enmanuel Magallanes Pinargote",
1111
author_email="enmanuelmag@cardor.dev",
1212
description="This library helps to create models with identifiers, checkpoints, logs and metadata automatically, in order to make the training process more efficient and traceable.",

0 commit comments

Comments
 (0)