Skip to content

Commit f459c32

Browse files
committed
v18: Save model as option added
1 parent fc22813 commit f459c32

File tree

4 files changed

+97
-93
lines changed

4 files changed

+97
-93
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
129129
- Lord of the universe - cacoe (twitter: @cac0e)
130130

131131
## Change history
132+
133+
* 12/17 (v18) update:
134+
- Save model as option added to train_db_fixed.py
135+
- Save model as option added to GUI
136+
- Retire "Model conversion" parameters that was essentially performing the same function as the new `--save_model_as` parameter
132137
* 12/17 (v17.2) update:
133138
- Adding new dataset balancing utility.
134139
* 12/17 (v17.1) update:

dreambooth_gui.py

Lines changed: 39 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,16 @@ def save_configuration(
4747
save_precision,
4848
seed,
4949
num_cpu_threads_per_process,
50-
convert_to_safetensors,
51-
convert_to_ckpt,
5250
cache_latent,
5351
caption_extention,
54-
use_safetensors,
5552
enable_bucket,
5653
gradient_checkpointing,
5754
full_fp16,
5855
no_token_padding,
5956
stop_text_encoder_training,
6057
use_8bit_adam,
6158
xformers,
59+
save_model_as
6260
):
6361
original_file_path = file_path
6462

@@ -103,18 +101,16 @@ def save_configuration(
103101
'save_precision': save_precision,
104102
'seed': seed,
105103
'num_cpu_threads_per_process': num_cpu_threads_per_process,
106-
'convert_to_safetensors': convert_to_safetensors,
107-
'convert_to_ckpt': convert_to_ckpt,
108104
'cache_latent': cache_latent,
109105
'caption_extention': caption_extention,
110-
'use_safetensors': use_safetensors,
111106
'enable_bucket': enable_bucket,
112107
'gradient_checkpointing': gradient_checkpointing,
113108
'full_fp16': full_fp16,
114109
'no_token_padding': no_token_padding,
115110
'stop_text_encoder_training': stop_text_encoder_training,
116111
'use_8bit_adam': use_8bit_adam,
117112
'xformers': xformers,
113+
'save_model_as': save_model_as
118114
}
119115

120116
# Save the data to the selected file
@@ -144,18 +140,16 @@ def open_configuration(
144140
save_precision,
145141
seed,
146142
num_cpu_threads_per_process,
147-
convert_to_safetensors,
148-
convert_to_ckpt,
149143
cache_latent,
150144
caption_extention,
151-
use_safetensors,
152145
enable_bucket,
153146
gradient_checkpointing,
154147
full_fp16,
155148
no_token_padding,
156149
stop_text_encoder_training,
157150
use_8bit_adam,
158151
xformers,
152+
save_model_as
159153
):
160154

161155
original_file_path = file_path
@@ -195,18 +189,16 @@ def open_configuration(
195189
my_data.get(
196190
'num_cpu_threads_per_process', num_cpu_threads_per_process
197191
),
198-
my_data.get('convert_to_safetensors', convert_to_safetensors),
199-
my_data.get('convert_to_ckpt', convert_to_ckpt),
200192
my_data.get('cache_latent', cache_latent),
201193
my_data.get('caption_extention', caption_extention),
202-
my_data.get('use_safetensors', use_safetensors),
203194
my_data.get('enable_bucket', enable_bucket),
204195
my_data.get('gradient_checkpointing', gradient_checkpointing),
205196
my_data.get('full_fp16', full_fp16),
206197
my_data.get('no_token_padding', no_token_padding),
207198
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
208199
my_data.get('use_8bit_adam', use_8bit_adam),
209200
my_data.get('xformers', xformers),
201+
my_data.get('save_model_as', save_model_as)
210202
)
211203

212204

@@ -229,18 +221,16 @@ def train_model(
229221
save_precision,
230222
seed,
231223
num_cpu_threads_per_process,
232-
convert_to_safetensors,
233-
convert_to_ckpt,
234224
cache_latent,
235225
caption_extention,
236-
use_safetensors,
237226
enable_bucket,
238227
gradient_checkpointing,
239228
full_fp16,
240229
no_token_padding,
241230
stop_text_encoder_training_pct,
242231
use_8bit_adam,
243232
xformers,
233+
save_model_as
244234
):
245235
def save_inference_file(output_dir, v2, v_parameterization):
246236
# Copy inference model for v2 if required
@@ -352,8 +342,6 @@ def save_inference_file(output_dir, v2, v_parameterization):
352342
run_cmd += ' --v_parameterization'
353343
if cache_latent:
354344
run_cmd += ' --cache_latents'
355-
if use_safetensors:
356-
run_cmd += ' --use_safetensors'
357345
if enable_bucket:
358346
run_cmd += ' --enable_bucket'
359347
if gradient_checkpointing:
@@ -388,39 +376,20 @@ def save_inference_file(output_dir, v2, v_parameterization):
388376
run_cmd += f' --logging_dir={logging_dir}'
389377
run_cmd += f' --caption_extention={caption_extention}'
390378
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
379+
if not save_model_as == 'same as source model':
380+
run_cmd += f' --save_model_as={save_model_as}'
391381

392382
print(run_cmd)
393383
# Run the command
394384
subprocess.run(run_cmd)
395385

396-
# check if output_dir/last is a directory... therefore it is a diffuser model
386+
# check if output_dir/last is a folder... therefore it is a diffuser model
397387
last_dir = pathlib.Path(f'{output_dir}/last')
398-
print(last_dir)
399-
if last_dir.is_dir():
400-
if convert_to_ckpt:
401-
print(f'Converting diffuser model {last_dir} to {last_dir}.ckpt')
402-
os.system(
403-
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}'
404-
)
405-
406-
save_inference_file(output_dir, v2, v_parameterization)
407-
408-
if convert_to_safetensors:
409-
print(
410-
f'Converting diffuser model {last_dir} to {last_dir}.safetensors'
411-
)
412-
os.system(
413-
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}'
414-
)
415-
416-
save_inference_file(output_dir, v2, v_parameterization)
417-
else:
388+
389+
if not last_dir.is_dir():
418390
# Copy inference model for v2 if required
419391
save_inference_file(output_dir, v2, v_parameterization)
420392

421-
# Return the values of the variables as a dictionary
422-
# return
423-
424393

425394
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
426395
# define a list of substrings to search for
@@ -533,6 +502,17 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
533502
'CompVis/stable-diffusion-v1-4',
534503
],
535504
)
505+
save_model_as_dropdown = gr.Dropdown(
506+
label='Save trained model as',
507+
choices=[
508+
'same as source model',
509+
'ckpt',
510+
'diffusers',
511+
"diffusers_safetensors",
512+
'safetensors',
513+
],
514+
value='same as source model'
515+
)
536516
with gr.Row():
537517
v2_input = gr.Checkbox(label='v2', value=True)
538518
v_parameterization_input = gr.Checkbox(
@@ -557,7 +537,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
557537
with gr.Row():
558538
train_data_dir_input = gr.Textbox(
559539
label='Image folder',
560-
placeholder='Directory where the training folders containing the images are located',
540+
placeholder='Folder where the training folders containing the images are located',
561541
)
562542
train_data_dir_input_folder = gr.Button(
563543
'📂', elem_id='open_folder_small'
@@ -567,7 +547,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
567547
)
568548
reg_data_dir_input = gr.Textbox(
569549
label='Regularisation folder',
570-
placeholder='(Optional) Directory where where the regularization folders containing the images are located',
550+
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
571551
)
572552
reg_data_dir_input_folder = gr.Button(
573553
'📂', elem_id='open_folder_small'
@@ -577,8 +557,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
577557
)
578558
with gr.Row():
579559
output_dir_input = gr.Textbox(
580-
label='Output directory',
581-
placeholder='Directory to output trained model',
560+
label='Output folder',
561+
placeholder='Folder to output trained model',
582562
)
583563
output_dir_input_folder = gr.Button(
584564
'📂', elem_id='open_folder_small'
@@ -587,8 +567,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
587567
get_folder_path, outputs=output_dir_input
588568
)
589569
logging_dir_input = gr.Textbox(
590-
label='Logging directory',
591-
placeholder='Optional: enable logging and output TensorBoard log to this directory',
570+
label='Logging folder',
571+
placeholder='Optional: enable logging and output TensorBoard log to this folder',
592572
)
593573
logging_dir_input_folder = gr.Button(
594574
'📂', elem_id='open_folder_small'
@@ -694,9 +674,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
694674
no_token_padding_input = gr.Checkbox(
695675
label='No token padding', value=False
696676
)
697-
use_safetensors_input = gr.Checkbox(
698-
label='Use safetensor when saving', value=False
699-
)
700677

701678
gradient_checkpointing_input = gr.Checkbox(
702679
label='Gradient checkpointing', value=False
@@ -711,13 +688,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
711688
)
712689
xformers_input = gr.Checkbox(label='Use xformers', value=True)
713690

714-
with gr.Tab('Model conversion'):
715-
convert_to_safetensors_input = gr.Checkbox(
716-
label='Convert to SafeTensors', value=True
717-
)
718-
convert_to_ckpt_input = gr.Checkbox(
719-
label='Convert to CKPT', value=False
720-
)
721691
with gr.Tab('Utilities'):
722692
# Dreambooth folder creation tab
723693
gradio_dreambooth_folder_creation_tab(
@@ -729,6 +699,13 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
729699
# Captionning tab
730700
gradio_caption_gui_tab()
731701
gradio_dataset_balancing_tab()
702+
# with gr.Tab('Model conversion'):
703+
# convert_to_safetensors_input = gr.Checkbox(
704+
# label='Convert to SafeTensors', value=True
705+
# )
706+
# convert_to_ckpt_input = gr.Checkbox(
707+
# label='Convert to CKPT', value=False
708+
# )
732709

733710
button_run = gr.Button('Train model')
734711

@@ -754,18 +731,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
754731
save_precision_input,
755732
seed_input,
756733
num_cpu_threads_per_process_input,
757-
convert_to_safetensors_input,
758-
convert_to_ckpt_input,
759734
cache_latent_input,
760735
caption_extention_input,
761-
use_safetensors_input,
762736
enable_bucket_input,
763737
gradient_checkpointing_input,
764738
full_fp16_input,
765739
no_token_padding_input,
766740
stop_text_encoder_training_input,
767741
use_8bit_adam_input,
768742
xformers_input,
743+
save_model_as_dropdown
769744
],
770745
outputs=[
771746
config_file_name,
@@ -787,18 +762,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
787762
save_precision_input,
788763
seed_input,
789764
num_cpu_threads_per_process_input,
790-
convert_to_safetensors_input,
791-
convert_to_ckpt_input,
792765
cache_latent_input,
793766
caption_extention_input,
794-
use_safetensors_input,
795767
enable_bucket_input,
796768
gradient_checkpointing_input,
797769
full_fp16_input,
798770
no_token_padding_input,
799771
stop_text_encoder_training_input,
800772
use_8bit_adam_input,
801773
xformers_input,
774+
save_model_as_dropdown
802775
],
803776
)
804777

@@ -827,18 +800,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
827800
save_precision_input,
828801
seed_input,
829802
num_cpu_threads_per_process_input,
830-
convert_to_safetensors_input,
831-
convert_to_ckpt_input,
832803
cache_latent_input,
833804
caption_extention_input,
834-
use_safetensors_input,
835805
enable_bucket_input,
836806
gradient_checkpointing_input,
837807
full_fp16_input,
838808
no_token_padding_input,
839809
stop_text_encoder_training_input,
840810
use_8bit_adam_input,
841811
xformers_input,
812+
save_model_as_dropdown
842813
],
843814
outputs=[config_file_name],
844815
)
@@ -866,18 +837,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
866837
save_precision_input,
867838
seed_input,
868839
num_cpu_threads_per_process_input,
869-
convert_to_safetensors_input,
870-
convert_to_ckpt_input,
871840
cache_latent_input,
872841
caption_extention_input,
873-
use_safetensors_input,
874842
enable_bucket_input,
875843
gradient_checkpointing_input,
876844
full_fp16_input,
877845
no_token_padding_input,
878846
stop_text_encoder_training_input,
879847
use_8bit_adam_input,
880848
xformers_input,
849+
save_model_as_dropdown
881850
],
882851
outputs=[config_file_name],
883852
)
@@ -903,18 +872,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
903872
save_precision_input,
904873
seed_input,
905874
num_cpu_threads_per_process_input,
906-
convert_to_safetensors_input,
907-
convert_to_ckpt_input,
908875
cache_latent_input,
909876
caption_extention_input,
910-
use_safetensors_input,
911877
enable_bucket_input,
912878
gradient_checkpointing_input,
913879
full_fp16_input,
914880
no_token_padding_input,
915881
stop_text_encoder_training_input,
916882
use_8bit_adam_input,
917883
xformers_input,
884+
save_model_as_dropdown
918885
],
919886
)
920887

0 commit comments

Comments
 (0)