6
6
7
7
import numpy as np
8
8
import typing
9
- # from tqdm.contrib.concurrent import process_map
10
- # import itertools
11
9
12
- import data_preprocessing
10
+ import experiment_config
11
+ import data_preprocessor
13
12
import PyEDCR
14
- # import google_sheets_api
15
- # import plotting
16
13
17
14
18
15
class NeuralPyEDCR (PyEDCR .EDCR ):
@@ -29,7 +26,6 @@ def __init__(self,
29
26
sheet_index : int = None ,
30
27
K_train : typing .Union [typing .List [typing .Tuple [int ]], np .ndarray ] = None ,
31
28
K_test : typing .List [typing .Tuple [int ]] = None ,
32
- include_inconsistency_constraint : bool = False ,
33
29
secondary_model_name : str = None ,
34
30
secondary_model_loss : str = None ,
35
31
secondary_num_epochs : int = None ,
@@ -52,7 +48,6 @@ def __init__(self,
52
48
sheet_index = sheet_index ,
53
49
K_train = K_train ,
54
50
K_test = K_test ,
55
- include_inconsistency_constraint = include_inconsistency_constraint ,
56
51
secondary_model_name = secondary_model_name ,
57
52
secondary_model_loss = secondary_model_loss ,
58
53
secondary_num_epochs = secondary_num_epochs ,
@@ -79,7 +74,7 @@ def run_learning_pipeline(self,
79
74
# self.print_metrics(test=False, prior=True)
80
75
81
76
for EDCR_epoch in range (self .EDCR_num_epochs ):
82
- for g in data_preprocessing .FineCoarseDataPreprocessor .granularities .values ():
77
+ for g in data_preprocessor .FineCoarseDataPreprocessor .granularities .values ():
83
78
self .learn_detection_rules (g = g ,
84
79
multi_processing = multi_processing )
85
80
# self.apply_detection_rules(test=False,
@@ -131,7 +126,6 @@ def work_on_value(args):
131
126
loss = 'BCE' ,
132
127
lr = main_lr ,
133
128
original_num_epochs = original_num_epochs ,
134
- include_inconsistency_constraint = False ,
135
129
secondary_model_name = secondary_model_name ,
136
130
secondary_model_loss = secondary_model_loss ,
137
131
secondary_num_epochs = secondary_num_epochs ,
@@ -205,37 +199,10 @@ def simulate_for_values(data_str: str,
205
199
work_on_value (data )
206
200
207
201
208
- def main ():
209
- data_str = 'military_vehicles'
210
- main_model_name = binary_model_name = 'vit_b_16'
211
- secondary_model_name = 'vit_l_16'
212
- main_lr = secondary_lr = binary_lr = 0.0001
213
- original_num_epochs = 10
214
- secondary_num_epochs = 20
215
- binary_num_epochs = 10
216
-
217
- # data_str = 'imagenet'
218
- # main_model_name = binary_model_name = 'dinov2_vits14'
219
- # secondary_model_name = 'dinov2_vitl14'
220
- # # main_lr = 0.00001
221
- # main_lr = secondary_lr = binary_lr = 0.000001
222
- # original_num_epochs = 8
223
- # secondary_num_epochs = 2
224
- # binary_num_epochs = 5
225
-
226
- # data_str = 'openimage'
227
- # main_model_name = 'vit_b_16'
228
- # secondary_model_name = binary_model_name = 'dinov2_vits14'
229
- # main_lr = 0.0001
230
- # binary_lr = 0.000001
231
- # secondary_lr = 0.000001
232
- # original_num_epochs = 20
233
- # secondary_num_epochs = 20
234
- # binary_num_epochs = 4
235
-
236
- binary_l_strs = list ({f .split (f'e{ binary_num_epochs - 1 } _' )[- 1 ].replace ('.npy' , '' )
202
+ def run_experiment (config : experiment_config .ExperimentConfig ):
203
+ binary_l_strs = list ({f .split (f'e{ config .binary_num_epochs - 1 } _' )[- 1 ].replace ('.npy' , '' )
237
204
for f in os .listdir ('binary_results' )
238
- if f .startswith (f'{ data_str } _{ binary_model_name } ' )})
205
+ if f .startswith (f'{ config . data_str } _{ config . binary_model_name } ' )})
239
206
240
207
# print(google_sheets_api.get_maximal_epsilon(tab_name=sheet_tab))
241
208
@@ -251,11 +218,11 @@ def main():
251
218
# lists_of_fine_labels_to_take_out = [list(range(number_of_fine_classes-1))]
252
219
253
220
for (curr_secondary_model_name , curr_secondary_model_loss , curr_secondary_num_epochs , curr_secondary_lr ) in \
254
- [(secondary_model_name , 'BCE' , secondary_num_epochs , secondary_lr ),
221
+ [(config . secondary_model_name , 'BCE' , config . secondary_num_epochs , config . secondary_lr ),
255
222
# [None] * 4
256
223
]:
257
224
for (curr_binary_l_strs , curr_binary_lr , curr_binary_num_epochs ) in \
258
- [(binary_l_strs , binary_lr , binary_num_epochs ),
225
+ [(binary_l_strs , config . binary_lr , config . binary_num_epochs ),
259
226
# ([], None, None)
260
227
]:
261
228
for (lists_of_fine_labels_to_take_out , maximize_ratio , multi_processing ) in \
@@ -264,11 +231,11 @@ def main():
264
231
# ([list(range(i)) for i in range(int(number_of_fine_classes / 2) + 1)], True, True)
265
232
]:
266
233
simulate_for_values (
267
- data_str = data_str ,
268
- main_model_name = main_model_name ,
269
- main_lr = main_lr ,
270
- original_num_epochs = original_num_epochs ,
271
- binary_model_name = binary_model_name ,
234
+ data_str = config . data_str ,
235
+ main_model_name = config . main_model_name ,
236
+ main_lr = config . main_lr ,
237
+ original_num_epochs = config . original_num_epochs ,
238
+ binary_model_name = config . binary_model_name ,
272
239
binary_l_strs = curr_binary_l_strs ,
273
240
binary_lr = curr_binary_lr ,
274
241
binary_num_epochs = curr_binary_num_epochs ,
@@ -320,5 +287,48 @@ def main():
320
287
# fontsize=24)
321
288
322
289
290
+ def main ():
291
+ military_vehicles_config = experiment_config .ExperimentConfig (
292
+ data_str = 'military_vehicles' ,
293
+ main_model_name = 'vit_b_16' ,
294
+ secondary_model_name = 'vit_l_16' ,
295
+ main_lr = 0.0001 ,
296
+ secondary_lr = 0.0001 ,
297
+ binary_lr = 0.0001 ,
298
+ original_num_epochs = 10 ,
299
+ secondary_num_epochs = 20 ,
300
+ binary_num_epochs = 10
301
+ )
302
+
303
+ # imagenet_config = data_preprocessing.ExperimentConfig(
304
+ # data_str='imagenet',
305
+ # main_model_name='dinov2_vits14',
306
+ # secondary_model_name='dinov2_vitl14',
307
+ # main_lr=0.000001,
308
+ # secondary_lr=0.000001,
309
+ # binary_lr=0.000001,
310
+ # original_num_epochs=8,
311
+ # secondary_num_epochs=2,
312
+ # binary_num_epochs=5
313
+ # )
314
+ #
315
+ # openimage_config = data_preprocessing.ExperimentConfig(
316
+ # data_str='openimage',
317
+ # main_model_name='vit_b_16',
318
+ # secondary_model_name='dinov2_vits14',
319
+ # main_lr=0.0001,
320
+ # secondary_lr=0.000001,
321
+ # binary_lr=0.000001,
322
+ # original_num_epochs=20,
323
+ # secondary_num_epochs=20,
324
+ # binary_num_epochs=4
325
+ # )
326
+
327
+ run_experiment (config = military_vehicles_config )
328
+
329
+
330
+
331
+
332
+
323
333
if __name__ == '__main__' :
324
334
main ()
0 commit comments