@@ -47,18 +47,16 @@ def save_configuration(
47
47
save_precision ,
48
48
seed ,
49
49
num_cpu_threads_per_process ,
50
- convert_to_safetensors ,
51
- convert_to_ckpt ,
52
50
cache_latent ,
53
51
caption_extention ,
54
- use_safetensors ,
55
52
enable_bucket ,
56
53
gradient_checkpointing ,
57
54
full_fp16 ,
58
55
no_token_padding ,
59
56
stop_text_encoder_training ,
60
57
use_8bit_adam ,
61
58
xformers ,
59
+ save_model_as
62
60
):
63
61
original_file_path = file_path
64
62
@@ -103,18 +101,16 @@ def save_configuration(
103
101
'save_precision' : save_precision ,
104
102
'seed' : seed ,
105
103
'num_cpu_threads_per_process' : num_cpu_threads_per_process ,
106
- 'convert_to_safetensors' : convert_to_safetensors ,
107
- 'convert_to_ckpt' : convert_to_ckpt ,
108
104
'cache_latent' : cache_latent ,
109
105
'caption_extention' : caption_extention ,
110
- 'use_safetensors' : use_safetensors ,
111
106
'enable_bucket' : enable_bucket ,
112
107
'gradient_checkpointing' : gradient_checkpointing ,
113
108
'full_fp16' : full_fp16 ,
114
109
'no_token_padding' : no_token_padding ,
115
110
'stop_text_encoder_training' : stop_text_encoder_training ,
116
111
'use_8bit_adam' : use_8bit_adam ,
117
112
'xformers' : xformers ,
113
+ 'save_model_as' : save_model_as
118
114
}
119
115
120
116
# Save the data to the selected file
@@ -144,18 +140,16 @@ def open_configuration(
144
140
save_precision ,
145
141
seed ,
146
142
num_cpu_threads_per_process ,
147
- convert_to_safetensors ,
148
- convert_to_ckpt ,
149
143
cache_latent ,
150
144
caption_extention ,
151
- use_safetensors ,
152
145
enable_bucket ,
153
146
gradient_checkpointing ,
154
147
full_fp16 ,
155
148
no_token_padding ,
156
149
stop_text_encoder_training ,
157
150
use_8bit_adam ,
158
151
xformers ,
152
+ save_model_as
159
153
):
160
154
161
155
original_file_path = file_path
@@ -195,18 +189,16 @@ def open_configuration(
195
189
my_data .get (
196
190
'num_cpu_threads_per_process' , num_cpu_threads_per_process
197
191
),
198
- my_data .get ('convert_to_safetensors' , convert_to_safetensors ),
199
- my_data .get ('convert_to_ckpt' , convert_to_ckpt ),
200
192
my_data .get ('cache_latent' , cache_latent ),
201
193
my_data .get ('caption_extention' , caption_extention ),
202
- my_data .get ('use_safetensors' , use_safetensors ),
203
194
my_data .get ('enable_bucket' , enable_bucket ),
204
195
my_data .get ('gradient_checkpointing' , gradient_checkpointing ),
205
196
my_data .get ('full_fp16' , full_fp16 ),
206
197
my_data .get ('no_token_padding' , no_token_padding ),
207
198
my_data .get ('stop_text_encoder_training' , stop_text_encoder_training ),
208
199
my_data .get ('use_8bit_adam' , use_8bit_adam ),
209
200
my_data .get ('xformers' , xformers ),
201
+ my_data .get ('save_model_as' , save_model_as )
210
202
)
211
203
212
204
@@ -229,18 +221,16 @@ def train_model(
229
221
save_precision ,
230
222
seed ,
231
223
num_cpu_threads_per_process ,
232
- convert_to_safetensors ,
233
- convert_to_ckpt ,
234
224
cache_latent ,
235
225
caption_extention ,
236
- use_safetensors ,
237
226
enable_bucket ,
238
227
gradient_checkpointing ,
239
228
full_fp16 ,
240
229
no_token_padding ,
241
230
stop_text_encoder_training_pct ,
242
231
use_8bit_adam ,
243
232
xformers ,
233
+ save_model_as
244
234
):
245
235
def save_inference_file (output_dir , v2 , v_parameterization ):
246
236
# Copy inference model for v2 if required
@@ -352,8 +342,6 @@ def save_inference_file(output_dir, v2, v_parameterization):
352
342
run_cmd += ' --v_parameterization'
353
343
if cache_latent :
354
344
run_cmd += ' --cache_latents'
355
- if use_safetensors :
356
- run_cmd += ' --use_safetensors'
357
345
if enable_bucket :
358
346
run_cmd += ' --enable_bucket'
359
347
if gradient_checkpointing :
@@ -388,39 +376,20 @@ def save_inference_file(output_dir, v2, v_parameterization):
388
376
run_cmd += f' --logging_dir={ logging_dir } '
389
377
run_cmd += f' --caption_extention={ caption_extention } '
390
378
run_cmd += f' --stop_text_encoder_training={ stop_text_encoder_training } '
379
+ if not save_model_as == 'same as source model' :
380
+ run_cmd += f' --save_model_as={ save_model_as } '
391
381
392
382
print (run_cmd )
393
383
# Run the command
394
384
subprocess .run (run_cmd )
395
385
396
- # check if output_dir/last is a directory ... therefore it is a diffuser model
386
+ # check if output_dir/last is a folder ... therefore it is a diffuser model
397
387
last_dir = pathlib .Path (f'{ output_dir } /last' )
398
- print (last_dir )
399
- if last_dir .is_dir ():
400
- if convert_to_ckpt :
401
- print (f'Converting diffuser model { last_dir } to { last_dir } .ckpt' )
402
- os .system (
403
- f'python ./tools/convert_diffusers20_original_sd.py { last_dir } { last_dir } .ckpt --{ save_precision } '
404
- )
405
-
406
- save_inference_file (output_dir , v2 , v_parameterization )
407
-
408
- if convert_to_safetensors :
409
- print (
410
- f'Converting diffuser model { last_dir } to { last_dir } .safetensors'
411
- )
412
- os .system (
413
- f'python ./tools/convert_diffusers20_original_sd.py { last_dir } { last_dir } .safetensors --{ save_precision } '
414
- )
415
-
416
- save_inference_file (output_dir , v2 , v_parameterization )
417
- else :
388
+
389
+ if not last_dir .is_dir ():
418
390
# Copy inference model for v2 if required
419
391
save_inference_file (output_dir , v2 , v_parameterization )
420
392
421
- # Return the values of the variables as a dictionary
422
- # return
423
-
424
393
425
394
def set_pretrained_model_name_or_path_input (value , v2 , v_parameterization ):
426
395
# define a list of substrings to search for
@@ -533,6 +502,17 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
533
502
'CompVis/stable-diffusion-v1-4' ,
534
503
],
535
504
)
505
+ save_model_as_dropdown = gr .Dropdown (
506
+ label = 'Save trained model as' ,
507
+ choices = [
508
+ 'same as source model' ,
509
+ 'ckpt' ,
510
+ 'diffusers' ,
511
+ "diffusers_safetensors" ,
512
+ 'safetensors' ,
513
+ ],
514
+ value = 'same as source model'
515
+ )
536
516
with gr .Row ():
537
517
v2_input = gr .Checkbox (label = 'v2' , value = True )
538
518
v_parameterization_input = gr .Checkbox (
@@ -557,7 +537,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
557
537
with gr .Row ():
558
538
train_data_dir_input = gr .Textbox (
559
539
label = 'Image folder' ,
560
- placeholder = 'Directory where the training folders containing the images are located' ,
540
+ placeholder = 'Folder where the training folders containing the images are located' ,
561
541
)
562
542
train_data_dir_input_folder = gr .Button (
563
543
'📂' , elem_id = 'open_folder_small'
@@ -567,7 +547,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
567
547
)
568
548
reg_data_dir_input = gr .Textbox (
569
549
label = 'Regularisation folder' ,
570
- placeholder = '(Optional) Directory where where the regularization folders containing the images are located' ,
550
+ placeholder = '(Optional) Folder where where the regularization folders containing the images are located' ,
571
551
)
572
552
reg_data_dir_input_folder = gr .Button (
573
553
'📂' , elem_id = 'open_folder_small'
@@ -577,8 +557,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
577
557
)
578
558
with gr .Row ():
579
559
output_dir_input = gr .Textbox (
580
- label = 'Output directory ' ,
581
- placeholder = 'Directory to output trained model' ,
560
+ label = 'Output folder ' ,
561
+ placeholder = 'Folder to output trained model' ,
582
562
)
583
563
output_dir_input_folder = gr .Button (
584
564
'📂' , elem_id = 'open_folder_small'
@@ -587,8 +567,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
587
567
get_folder_path , outputs = output_dir_input
588
568
)
589
569
logging_dir_input = gr .Textbox (
590
- label = 'Logging directory ' ,
591
- placeholder = 'Optional: enable logging and output TensorBoard log to this directory ' ,
570
+ label = 'Logging folder ' ,
571
+ placeholder = 'Optional: enable logging and output TensorBoard log to this folder ' ,
592
572
)
593
573
logging_dir_input_folder = gr .Button (
594
574
'📂' , elem_id = 'open_folder_small'
@@ -694,9 +674,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
694
674
no_token_padding_input = gr .Checkbox (
695
675
label = 'No token padding' , value = False
696
676
)
697
- use_safetensors_input = gr .Checkbox (
698
- label = 'Use safetensor when saving' , value = False
699
- )
700
677
701
678
gradient_checkpointing_input = gr .Checkbox (
702
679
label = 'Gradient checkpointing' , value = False
@@ -711,13 +688,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
711
688
)
712
689
xformers_input = gr .Checkbox (label = 'Use xformers' , value = True )
713
690
714
- with gr .Tab ('Model conversion' ):
715
- convert_to_safetensors_input = gr .Checkbox (
716
- label = 'Convert to SafeTensors' , value = True
717
- )
718
- convert_to_ckpt_input = gr .Checkbox (
719
- label = 'Convert to CKPT' , value = False
720
- )
721
691
with gr .Tab ('Utilities' ):
722
692
# Dreambooth folder creation tab
723
693
gradio_dreambooth_folder_creation_tab (
@@ -729,6 +699,13 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
729
699
# Captionning tab
730
700
gradio_caption_gui_tab ()
731
701
gradio_dataset_balancing_tab ()
702
+ # with gr.Tab('Model conversion'):
703
+ # convert_to_safetensors_input = gr.Checkbox(
704
+ # label='Convert to SafeTensors', value=True
705
+ # )
706
+ # convert_to_ckpt_input = gr.Checkbox(
707
+ # label='Convert to CKPT', value=False
708
+ # )
732
709
733
710
button_run = gr .Button ('Train model' )
734
711
@@ -754,18 +731,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
754
731
save_precision_input ,
755
732
seed_input ,
756
733
num_cpu_threads_per_process_input ,
757
- convert_to_safetensors_input ,
758
- convert_to_ckpt_input ,
759
734
cache_latent_input ,
760
735
caption_extention_input ,
761
- use_safetensors_input ,
762
736
enable_bucket_input ,
763
737
gradient_checkpointing_input ,
764
738
full_fp16_input ,
765
739
no_token_padding_input ,
766
740
stop_text_encoder_training_input ,
767
741
use_8bit_adam_input ,
768
742
xformers_input ,
743
+ save_model_as_dropdown
769
744
],
770
745
outputs = [
771
746
config_file_name ,
@@ -787,18 +762,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
787
762
save_precision_input ,
788
763
seed_input ,
789
764
num_cpu_threads_per_process_input ,
790
- convert_to_safetensors_input ,
791
- convert_to_ckpt_input ,
792
765
cache_latent_input ,
793
766
caption_extention_input ,
794
- use_safetensors_input ,
795
767
enable_bucket_input ,
796
768
gradient_checkpointing_input ,
797
769
full_fp16_input ,
798
770
no_token_padding_input ,
799
771
stop_text_encoder_training_input ,
800
772
use_8bit_adam_input ,
801
773
xformers_input ,
774
+ save_model_as_dropdown
802
775
],
803
776
)
804
777
@@ -827,18 +800,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
827
800
save_precision_input ,
828
801
seed_input ,
829
802
num_cpu_threads_per_process_input ,
830
- convert_to_safetensors_input ,
831
- convert_to_ckpt_input ,
832
803
cache_latent_input ,
833
804
caption_extention_input ,
834
- use_safetensors_input ,
835
805
enable_bucket_input ,
836
806
gradient_checkpointing_input ,
837
807
full_fp16_input ,
838
808
no_token_padding_input ,
839
809
stop_text_encoder_training_input ,
840
810
use_8bit_adam_input ,
841
811
xformers_input ,
812
+ save_model_as_dropdown
842
813
],
843
814
outputs = [config_file_name ],
844
815
)
@@ -866,18 +837,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
866
837
save_precision_input ,
867
838
seed_input ,
868
839
num_cpu_threads_per_process_input ,
869
- convert_to_safetensors_input ,
870
- convert_to_ckpt_input ,
871
840
cache_latent_input ,
872
841
caption_extention_input ,
873
- use_safetensors_input ,
874
842
enable_bucket_input ,
875
843
gradient_checkpointing_input ,
876
844
full_fp16_input ,
877
845
no_token_padding_input ,
878
846
stop_text_encoder_training_input ,
879
847
use_8bit_adam_input ,
880
848
xformers_input ,
849
+ save_model_as_dropdown
881
850
],
882
851
outputs = [config_file_name ],
883
852
)
@@ -903,18 +872,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
903
872
save_precision_input ,
904
873
seed_input ,
905
874
num_cpu_threads_per_process_input ,
906
- convert_to_safetensors_input ,
907
- convert_to_ckpt_input ,
908
875
cache_latent_input ,
909
876
caption_extention_input ,
910
- use_safetensors_input ,
911
877
enable_bucket_input ,
912
878
gradient_checkpointing_input ,
913
879
full_fp16_input ,
914
880
no_token_padding_input ,
915
881
stop_text_encoder_training_input ,
916
882
use_8bit_adam_input ,
917
883
xformers_input ,
884
+ save_model_as_dropdown
918
885
],
919
886
)
920
887
0 commit comments