Skip to content

Commit bf46b89

Browse files
committed
Hot fixes
1 parent aa82bfa commit bf46b89

File tree

2 files changed

+14
-20
lines changed

2 files changed

+14
-20
lines changed

iaflow/__init__.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,11 @@ def get_model(self, model_name: str, run_id: str, force_creation: bool = True):
158158
check_path = run_id_data.get('check_path')
159159
if os.path.exists(check_path) and not force_creation:
160160
print(f'Loading model from {check_path}')
161-
filepath = run_id_data.get('filepath').replace('/', os.path.sep)
162-
run_id_data.pop('filepath')
163-
model = tf.keras.models.load_model(filepath, **run_id_data.get('load_model_params'))
161+
params = {
162+
**run_id_data.get('load_model_params', {}),
163+
'filepath': check_path,
164+
}
165+
model = tf.keras.models.load_model(**params)
164166
return model, run_id_data
165167
elif force_creation:
166168
print(f'Force creation is {force_creation}. Deleting old logs and model')
@@ -195,7 +197,6 @@ def add_dataset(
195197
):
196198
if name in self.datasets:
197199
print(f'Dataset {name} already exists')
198-
return False
199200

200201
self.datasets[name] = {
201202
'train_ds': train_ds,
@@ -209,7 +210,6 @@ def add_dataset(
209210
self.datasets[name]['test_ds'] = test_ds
210211

211212
print(f'Dataset {name} was added')
212-
return True
213213

214214
def update_dataset(
215215
self,
@@ -223,7 +223,6 @@ def update_dataset(
223223
):
224224
if name not in self.datasets:
225225
print(f'Dataset {name} not found')
226-
return False
227226

228227
if batch_size is not None:
229228
self.datasets[name]['batch_size'] = batch_size
@@ -239,16 +238,13 @@ def update_dataset(
239238
self.datasets[name]['test_ds'] = test_ds
240239

241240
print(f'Dataset {name} was updated')
242-
return True
243241

244242
def delete_dataset(self, name: str):
245243
if name not in self.datasets:
246244
print(f'Dataset {name} not found')
247-
return False
248245

249246
del self.datasets[name]
250247
print(f'Dataset {name} was deleted')
251-
return True
252248

253249
def add_model(
254250
self,
@@ -263,7 +259,7 @@ def add_model(
263259
model_params_str = '_'.join(map(str, model_params.values()))
264260
model_ident = '_'.join([ model_name, model_params_str ])
265261

266-
runs_model = self.models.get(model_name)
262+
runs_model = self.models.get(model_name, {})
267263
if run_id is not None and run_id in runs_model:
268264
raise ValueError(f'Model {model_name}/{run_id} already exists')
269265

@@ -337,12 +333,10 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
337333
model = self.models.get(model_name)
338334
if model is None:
339335
print(f'Model {model_name} not found')
340-
return False
341336

342337
run_id_data = model.get(run_id)
343338
if run_id_data is None:
344339
print(f'Run {run_id} not found')
345-
return False
346340

347341
if delete_folder:
348342
self.__delete_by_path(run_id_data.get('path_model'), is_dir=True)
@@ -353,13 +347,12 @@ def delete_model(self, model_data: T.Dict, delete_folder: bool = False):
353347

354348
self.save()
355349
print(f'Model {model_name}/{run_id} was deleted')
356-
return True
357350

358351
def train(
359352
self,
360353
model_data: T.Dict,
361354
dataset_name: str,
362-
batch_size: int = 32,
355+
batch_size: int = None,
363356
initial_epoch: int = 0,
364357
shuffle_buffer: int = None,
365358
force_creation: bool = False,
@@ -381,6 +374,12 @@ def train(
381374
run_id = model_data.get('run_id')
382375

383376
self.clear_session()
377+
print(f'Training {model_name}/{run_id}...')
378+
print(f'Epochs: {epochs}')
379+
print(f'Batch size: {batch_size}')
380+
print(f'Shuffle buffer: {shuffle_buffer}')
381+
print(f'Start time: {time.strftime("%Y-%m-%d %H:%M:%S")}')
382+
384383
model, run_data = self.get_model(model_name, run_id, force_creation)
385384

386385
if self.notifier:
@@ -403,11 +402,6 @@ def train(
403402
val_ds = val_ds.shuffle(shuffle_buffer)
404403

405404
start_time = time.time()
406-
print(f'Training {model_name}/{run_id}...')
407-
print(f'Epochs: {epochs}')
408-
print(f'Batch size: {batch_size}')
409-
print(f'Shuffle buffer: {shuffle_buffer}')
410-
print(f'Start time: {time.strftime("%Y-%m-%d %H:%M:%S")}')
411405

412406
model.fit(
413407
train_ds,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setuptools.setup(
88
name="iaflow",
9-
version='2.1.3',
9+
version='2.1.4',
1010
author="Enmanuel Magallanes Pinargote",
1111
author_email="enmanuelmag@cardor.dev",
1212
description="This library help to create models with identifiers, checkpoints, logs and metadata automatically, in order to make the training process more efficient and traceable.",

0 commit comments

Comments
 (0)