Skip to content

Commit 76092d9

Browse files
authored
Merge pull request #16 from lab-v2/debug
Simplified NeuralPyEDCR implementation
2 parents 6b772e4 + 92ac1e4 commit 76092d9

22 files changed

+1611
-1581
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ src_files/helper_functions/__pycache__/bn_fusion.cpython-311.pyc
1313
jupyter/
1414
test_PYEDCR/
1515
config.py
16-
google_sheets_api.py
16+
google_sheets_api.py
17+
data/ImageNet100
18+
data/OpenImage

LTN/LTN_error_detetection_rule_learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import utils
1313
from PyEDCR import EDCR
14-
import data_preprocessing
14+
import datasets
1515
import backbone_pipeline
1616
import typing
1717
import config

LTN/ltn_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import ltn
33
import numpy as np
44
import torch
5-
import data_preprocessing
5+
import datasets
66
import condition
77
import rule
88
import typing

NeuralPyEDCR.py

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66

77
import numpy as np
88
import typing
9-
# from tqdm.contrib.concurrent import process_map
10-
# import itertools
119

12-
import data_preprocessing
10+
import experiment_config
11+
import data_preprocessor
1312
import PyEDCR
14-
# import google_sheets_api
15-
# import plotting
1613

1714

1815
class NeuralPyEDCR(PyEDCR.EDCR):
@@ -29,7 +26,6 @@ def __init__(self,
2926
sheet_index: int = None,
3027
K_train: typing.Union[typing.List[typing.Tuple[int]], np.ndarray] = None,
3128
K_test: typing.List[typing.Tuple[int]] = None,
32-
include_inconsistency_constraint: bool = False,
3329
secondary_model_name: str = None,
3430
secondary_model_loss: str = None,
3531
secondary_num_epochs: int = None,
@@ -52,7 +48,6 @@ def __init__(self,
5248
sheet_index=sheet_index,
5349
K_train=K_train,
5450
K_test=K_test,
55-
include_inconsistency_constraint=include_inconsistency_constraint,
5651
secondary_model_name=secondary_model_name,
5752
secondary_model_loss=secondary_model_loss,
5853
secondary_num_epochs=secondary_num_epochs,
@@ -79,7 +74,7 @@ def run_learning_pipeline(self,
7974
# self.print_metrics(test=False, prior=True)
8075

8176
for EDCR_epoch in range(self.EDCR_num_epochs):
82-
for g in data_preprocessing.FineCoarseDataPreprocessor.granularities.values():
77+
for g in data_preprocessor.FineCoarseDataPreprocessor.granularities.values():
8378
self.learn_detection_rules(g=g,
8479
multi_processing=multi_processing)
8580
# self.apply_detection_rules(test=False,
@@ -131,7 +126,6 @@ def work_on_value(args):
131126
loss='BCE',
132127
lr=main_lr,
133128
original_num_epochs=original_num_epochs,
134-
include_inconsistency_constraint=False,
135129
secondary_model_name=secondary_model_name,
136130
secondary_model_loss=secondary_model_loss,
137131
secondary_num_epochs=secondary_num_epochs,
@@ -205,37 +199,10 @@ def simulate_for_values(data_str: str,
205199
work_on_value(data)
206200

207201

208-
def main():
209-
data_str = 'military_vehicles'
210-
main_model_name = binary_model_name = 'vit_b_16'
211-
secondary_model_name = 'vit_l_16'
212-
main_lr = secondary_lr = binary_lr = 0.0001
213-
original_num_epochs = 10
214-
secondary_num_epochs = 20
215-
binary_num_epochs = 10
216-
217-
# data_str = 'imagenet'
218-
# main_model_name = binary_model_name = 'dinov2_vits14'
219-
# secondary_model_name = 'dinov2_vitl14'
220-
# # main_lr = 0.00001
221-
# main_lr = secondary_lr = binary_lr = 0.000001
222-
# original_num_epochs = 8
223-
# secondary_num_epochs = 2
224-
# binary_num_epochs = 5
225-
226-
# data_str = 'openimage'
227-
# main_model_name = 'vit_b_16'
228-
# secondary_model_name = binary_model_name = 'dinov2_vits14'
229-
# main_lr = 0.0001
230-
# binary_lr = 0.000001
231-
# secondary_lr = 0.000001
232-
# original_num_epochs = 20
233-
# secondary_num_epochs = 20
234-
# binary_num_epochs = 4
235-
236-
binary_l_strs = list({f.split(f'e{binary_num_epochs - 1}_')[-1].replace('.npy', '')
202+
def run_experiment(config: experiment_config.ExperimentConfig):
203+
binary_l_strs = list({f.split(f'e{config.binary_num_epochs - 1}_')[-1].replace('.npy', '')
237204
for f in os.listdir('binary_results')
238-
if f.startswith(f'{data_str}_{binary_model_name}')})
205+
if f.startswith(f'{config.data_str}_{config.binary_model_name}')})
239206

240207
# print(google_sheets_api.get_maximal_epsilon(tab_name=sheet_tab))
241208

@@ -251,11 +218,11 @@ def main():
251218
# lists_of_fine_labels_to_take_out = [list(range(number_of_fine_classes-1))]
252219

253220
for (curr_secondary_model_name, curr_secondary_model_loss, curr_secondary_num_epochs, curr_secondary_lr) in \
254-
[(secondary_model_name, 'BCE', secondary_num_epochs, secondary_lr),
221+
[(config.secondary_model_name, 'BCE', config.secondary_num_epochs, config.secondary_lr),
255222
# [None] * 4
256223
]:
257224
for (curr_binary_l_strs, curr_binary_lr, curr_binary_num_epochs) in \
258-
[(binary_l_strs, binary_lr, binary_num_epochs),
225+
[(binary_l_strs, config.binary_lr, config.binary_num_epochs),
259226
# ([], None, None)
260227
]:
261228
for (lists_of_fine_labels_to_take_out, maximize_ratio, multi_processing) in \
@@ -264,11 +231,11 @@ def main():
264231
# ([list(range(i)) for i in range(int(number_of_fine_classes / 2) + 1)], True, True)
265232
]:
266233
simulate_for_values(
267-
data_str=data_str,
268-
main_model_name=main_model_name,
269-
main_lr=main_lr,
270-
original_num_epochs=original_num_epochs,
271-
binary_model_name=binary_model_name,
234+
data_str=config.data_str,
235+
main_model_name=config.main_model_name,
236+
main_lr=config.main_lr,
237+
original_num_epochs=config.original_num_epochs,
238+
binary_model_name=config.binary_model_name,
272239
binary_l_strs=curr_binary_l_strs,
273240
binary_lr=curr_binary_lr,
274241
binary_num_epochs=curr_binary_num_epochs,
@@ -320,5 +287,48 @@ def main():
320287
# fontsize=24)
321288

322289

290+
def main():
291+
military_vehicles_config = experiment_config.ExperimentConfig(
292+
data_str='military_vehicles',
293+
main_model_name='vit_b_16',
294+
secondary_model_name='vit_l_16',
295+
main_lr=0.0001,
296+
secondary_lr=0.0001,
297+
binary_lr=0.0001,
298+
original_num_epochs=10,
299+
secondary_num_epochs=20,
300+
binary_num_epochs=10
301+
)
302+
303+
# imagenet_config = data_preprocessing.ExperimentConfig(
304+
# data_str='imagenet',
305+
# main_model_name='dinov2_vits14',
306+
# secondary_model_name='dinov2_vitl14',
307+
# main_lr=0.000001,
308+
# secondary_lr=0.000001,
309+
# binary_lr=0.000001,
310+
# original_num_epochs=8,
311+
# secondary_num_epochs=2,
312+
# binary_num_epochs=5
313+
# )
314+
#
315+
# openimage_config = data_preprocessing.ExperimentConfig(
316+
# data_str='openimage',
317+
# main_model_name='vit_b_16',
318+
# secondary_model_name='dinov2_vits14',
319+
# main_lr=0.0001,
320+
# secondary_lr=0.000001,
321+
# binary_lr=0.000001,
322+
# original_num_epochs=20,
323+
# secondary_num_epochs=20,
324+
# binary_num_epochs=4
325+
# )
326+
327+
run_experiment(config=military_vehicles_config)
328+
329+
330+
331+
332+
323333
if __name__ == '__main__':
324334
main()

0 commit comments

Comments
 (0)