Skip to content

Commit 7895de6

Browse files
authored
Merge pull request #2055 from bmaltais/dev
v23.0.2
2 parents 7dae63e + ab188dd commit 7895de6

16 files changed

+243
-232
lines changed

.release

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v23.0.1
1+
v23.0.2

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ The GUI allows you to set the training parameters and generate and run the requi
3737
- [No module called tkinter](#no-module-called-tkinter)
3838
- [SDXL training](#sdxl-training)
3939
- [Change History](#change-history)
40-
- [2024/03/10 (v23.0.1)](#20240310-v2301)
40+
- [2024/03/10 (v23.0.2)](#20240310-v2302)
41+
- [2024/03/09 (v23.0.1)](#20240309-v2301)
4142
- [2024/03/02 (v23.0.0)](#20240302-v2300)
4243

4344
## 🦒 Colab
@@ -364,7 +365,11 @@ The documentation in this section will be moved to a separate document later.
364365

365366
## Change History
366367

367-
### 2024/03/10 (v23.0.1)
368+
### 2024/03/10 (v23.0.2)
369+
370+
- Improve validation of path provided by users before running training
371+
372+
### 2024/03/09 (v23.0.1)
368373

369374
- Update bitsandbytes module to 0.43.0 as it provide native windows support
370375
- Minor fixes to code

gui.bat

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
@echo off
22

3+
set PYTHON_VER=3.10.9
4+
35
:: Deactivate the virtual environment
46
call .\venv\Scripts\deactivate.bat
57

6-
:: Calling external python program to check for local modules
7-
:: python .\setup\check_local_modules.py --no_question
8+
:: Check if Python version meets the recommended version
9+
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
10+
if errorlevel 1 (
11+
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
12+
)
813

914
:: Activate the virtual environment
1015
call .\venv\Scripts\activate.bat

kohya_gui/class_source_model.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@
1515
save_style_symbol = '\U0001f4be' # 💾
1616
document_symbol = '\U0001F4C4' # 📄
1717

18+
default_models = [
19+
'stabilityai/stable-diffusion-xl-base-1.0',
20+
'stabilityai/stable-diffusion-xl-refiner-1.0',
21+
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
22+
'stabilityai/stable-diffusion-2-1-base',
23+
'stabilityai/stable-diffusion-2-base',
24+
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
25+
'stabilityai/stable-diffusion-2-1',
26+
'stabilityai/stable-diffusion-2',
27+
'runwayml/stable-diffusion-v1-5',
28+
'CompVis/stable-diffusion-v1-4',
29+
]
1830

1931
class SourceModel:
2032
def __init__(
@@ -39,19 +51,6 @@ def __init__(
3951
self.save_model_as_choices = save_model_as_choices
4052
self.finetuning = finetuning
4153

42-
default_models = [
43-
'stabilityai/stable-diffusion-xl-base-1.0',
44-
'stabilityai/stable-diffusion-xl-refiner-1.0',
45-
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
46-
'stabilityai/stable-diffusion-2-1-base',
47-
'stabilityai/stable-diffusion-2-base',
48-
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
49-
'stabilityai/stable-diffusion-2-1',
50-
'stabilityai/stable-diffusion-2',
51-
'runwayml/stable-diffusion-v1-5',
52-
'CompVis/stable-diffusion-v1-4',
53-
]
54-
5554
from .common_gui import create_refresh_button
5655

5756
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")

kohya_gui/common_gui.py

Lines changed: 138 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import json
1010

1111
from .custom_logging import setup_logging
12-
from datetime import datetime
1312

1413
# Set up logging
1514
log = setup_logging()
@@ -699,7 +698,7 @@ def run_cmd_advanced_training(**kwargs):
699698
if "additional_parameters" in kwargs:
700699
run_cmd += f' {kwargs["additional_parameters"]}'
701700

702-
if "block_lr" in kwargs:
701+
if "block_lr" in kwargs and kwargs["block_lr"] != "":
703702
run_cmd += f' --block_lr="{kwargs["block_lr"]}"'
704703

705704
if kwargs.get("bucket_no_upscale"):
@@ -1143,12 +1142,12 @@ def run_cmd_advanced_training(**kwargs):
11431142

11441143
def verify_image_folder_pattern(folder_path):
11451144
false_response = True # temporarily set to true to prevent stopping training in case of false positive
1146-
true_response = True
11471145

1146+
log.info(f"Verifying image folder pattern of {folder_path}...")
11481147
# Check if the folder exists
11491148
if not os.path.isdir(folder_path):
11501149
log.error(
1151-
f"The provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
1150+
f"...the provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
11521151
)
11531152
return false_response
11541153

@@ -1176,22 +1175,22 @@ def verify_image_folder_pattern(folder_path):
11761175
non_matching_subfolders = set(subfolders) - set(matching_subfolders)
11771176
if non_matching_subfolders:
11781177
log.error(
1179-
f"The following folders do not match the required pattern <number>_<text>: {', '.join(non_matching_subfolders)}"
1178+
f"...the following folders do not match the required pattern <number>_<text>: {', '.join(non_matching_subfolders)}"
11801179
)
11811180
log.error(
1182-
f"Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
1181+
f"...please follow the folder structure documentation found at docs\image_folder_structure.md ..."
11831182
)
11841183
return false_response
11851184

11861185
# Check if no sub-folders exist
11871186
if not matching_subfolders:
11881187
log.error(
1189-
f"No image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
1188+
f"...no image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
11901189
)
11911190
return false_response
11921191

1193-
log.info(f"Valid image folder names found in: {folder_path}")
1194-
return true_response
1192+
log.info(f"...valid")
1193+
return True
11951194

11961195

11971196
def SaveConfigFile(
@@ -1231,7 +1230,9 @@ def save_to_file(content):
12311230
def check_duplicate_filenames(
12321231
folder_path, image_extension=[".gif", ".png", ".jpg", ".jpeg", ".webp"]
12331232
):
1234-
log.info("Checking for duplicate image filenames in training data directory...")
1233+
duplicate = False
1234+
1235+
log.info(f"Checking for duplicate image filenames in training data directory {folder_path}...")
12351236
for root, dirs, files in os.walk(folder_path):
12361237
filenames = {}
12371238
for file in files:
@@ -1241,15 +1242,138 @@ def check_duplicate_filenames(
12411242
if filename in filenames:
12421243
existing_path = filenames[filename]
12431244
if existing_path != full_path:
1244-
print(
1245-
f"Warning: Same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file."
1245+
log.warning(
1246+
f"...same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file."
12461247
)
1247-
print(f"Existing file: {existing_path}")
1248-
print(f"Current file: {full_path}")
1248+
log.warning(f" Existing file: {existing_path}")
1249+
log.warning(f" Current file: {full_path}")
1250+
duplicate = True
12491251
else:
12501252
filenames[filename] = full_path
1253+
if not duplicate:
1254+
log.info("...valid")
12511255

12521256

1257+
def validate_paths(headless:bool = False, **kwargs):
1258+
from .class_source_model import default_models
1259+
1260+
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path")
1261+
train_data_dir = kwargs.get("train_data_dir")
1262+
reg_data_dir = kwargs.get("reg_data_dir")
1263+
output_dir = kwargs.get("output_dir")
1264+
logging_dir = kwargs.get("logging_dir")
1265+
lora_network_weights= kwargs.get("lora_network_weights")
1266+
finetune_image_folder = kwargs.get("finetune_image_folder")
1267+
resume = kwargs.get("resume")
1268+
vae = kwargs.get("vae")
1269+
1270+
if pretrained_model_name_or_path is not None:
1271+
log.info(f"Validating model file or folder path {pretrained_model_name_or_path} existence...")
1272+
1273+
# Check if it matches the Hugging Face model pattern
1274+
if re.match(r'^[\w-]+\/[\w-]+$', pretrained_model_name_or_path):
1275+
log.info("...huggingface.co model, skipping validation")
1276+
elif pretrained_model_name_or_path not in default_models:
1277+
# If not one of the default models, check if it's a valid local path
1278+
if not os.path.exists(pretrained_model_name_or_path):
1279+
log.error(f"...source model path '{pretrained_model_name_or_path}' is missing or does not exist")
1280+
return False
1281+
else:
1282+
log.info("...valid")
1283+
else:
1284+
log.info("...valid")
1285+
1286+
# Check if train_data_dir is valid
1287+
if train_data_dir != None:
1288+
log.info(f"Validating training data folder path {train_data_dir} existence...")
1289+
if not train_data_dir or not os.path.exists(train_data_dir):
1290+
log.error(f"Image folder path '{train_data_dir}' is missing or does not exist")
1291+
return False
1292+
else:
1293+
log.info("...valid")
1294+
1295+
# Check if there are files with the same filename but different image extension... warn the user if it is the case.
1296+
check_duplicate_filenames(train_data_dir)
1297+
1298+
if not verify_image_folder_pattern(folder_path=train_data_dir):
1299+
return False
1300+
1301+
if finetune_image_folder != None:
1302+
log.info(f"Validating finetuning image folder path {finetune_image_folder} existence...")
1303+
if not finetune_image_folder or not os.path.exists(finetune_image_folder):
1304+
log.error(f"Image folder path '{finetune_image_folder}' is missing or does not exist")
1305+
return False
1306+
else:
1307+
log.info("...valid")
1308+
1309+
if reg_data_dir != None:
1310+
if reg_data_dir != "":
1311+
log.info(f"Validating regularisation data folder path {reg_data_dir} existence...")
1312+
if not os.path.exists(reg_data_dir):
1313+
log.error("...regularisation folder does not exist")
1314+
return False
1315+
1316+
if not verify_image_folder_pattern(folder_path=reg_data_dir):
1317+
return False
1318+
log.info("...valid")
1319+
else:
1320+
log.info("Regularisation folder not specified, skipping validation")
1321+
1322+
if output_dir != None:
1323+
log.info(f"Validating output folder path {output_dir} existence...")
1324+
if output_dir == "" or not os.path.exists(output_dir):
1325+
log.error("...output folder path is missing or invalid")
1326+
return False
1327+
else:
1328+
log.info("...valid")
1329+
1330+
if logging_dir != None:
1331+
if logging_dir != "":
1332+
log.info(f"Validating logging folder path {logging_dir} existence...")
1333+
if not os.path.exists(logging_dir):
1334+
log.error("...logging folder path is missing or invalid")
1335+
return False
1336+
else:
1337+
log.info("...valid")
1338+
else:
1339+
log.info("Logging folder not specified, skipping validation")
1340+
1341+
if lora_network_weights != None:
1342+
if lora_network_weights != "":
1343+
log.info(f"Validating LoRA Network Weight file path {lora_network_weights} existence...")
1344+
if not os.path.exists(lora_network_weights):
1345+
log.error("...path is invalid")
1346+
return False
1347+
else:
1348+
log.info("...valid")
1349+
else:
1350+
log.info("LoRA Network Weight file not specified, skipping validation")
1351+
1352+
if resume != None:
1353+
if resume != "":
1354+
log.info(f"Validating model resume file path {resume} existence...")
1355+
if not os.path.exists(resume):
1356+
log.error("...path is invalid")
1357+
return False
1358+
else:
1359+
log.info("...valid")
1360+
else:
1361+
log.info("Model resume file not specified, skipping validation")
1362+
1363+
if vae != None:
1364+
if vae != "":
1365+
log.info(f"Validating VAE file path {vae} existence...")
1366+
if not os.path.exists(vae):
1367+
log.error("...vae path is invalid")
1368+
return False
1369+
else:
1370+
log.info("...valid")
1371+
else:
1372+
log.info("VAE file not specified, skipping validation")
1373+
1374+
1375+
return True
1376+
12531377
def is_file_writable(file_path):
12541378
if not os.path.exists(file_path):
12551379
# print(f"File '{file_path}' does not exist.")

kohya_gui/dreambooth_gui.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
# v1: initial release
2-
# v2: add open and save folder icons
3-
# v3: Add new Utilities tab for Dreambooth folder preparation
4-
# v3.1: Adding captionning of images to utilities
5-
61
import gradio as gr
72
import json
83
import math
94
import os
10-
import subprocess
115
import sys
126
import pathlib
137
from datetime import datetime
@@ -19,11 +13,10 @@
1913
run_cmd_advanced_training,
2014
update_my_data,
2115
check_if_model_exist,
22-
output_message,
23-
verify_image_folder_pattern,
2416
SaveConfigFile,
2517
save_to_file,
2618
scriptdir,
19+
validate_paths,
2720
)
2821
from .class_configuration_file import ConfigurationFile
2922
from .class_source_model import SourceModel
@@ -406,36 +399,16 @@ def train_model(
406399

407400
headless_bool = True if headless.get("label") == "True" else False
408401

409-
if pretrained_model_name_or_path == "":
410-
output_message(
411-
msg="Source model information is missing", headless=headless_bool
412-
)
413-
return
414-
415-
if train_data_dir == "":
416-
output_message(msg="Image folder path is missing", headless=headless_bool)
417-
return
418-
419-
if not os.path.exists(train_data_dir):
420-
output_message(msg="Image folder does not exist", headless=headless_bool)
421-
return
422-
423-
if not verify_image_folder_pattern(train_data_dir):
424-
return
425-
426-
if reg_data_dir != "":
427-
if not os.path.exists(reg_data_dir):
428-
output_message(
429-
msg="Regularisation folder does not exist",
430-
headless=headless_bool,
431-
)
432-
return
433-
434-
if not verify_image_folder_pattern(reg_data_dir):
435-
return
436-
437-
if output_dir == "":
438-
output_message(msg="Output folder path is missing", headless=headless_bool)
402+
if not validate_paths(
403+
output_dir=output_dir,
404+
pretrained_model_name_or_path=pretrained_model_name_or_path,
405+
train_data_dir=train_data_dir,
406+
reg_data_dir=reg_data_dir,
407+
headless=headless_bool,
408+
logging_dir=logging_dir,
409+
resume=resume,
410+
vae=vae,
411+
):
439412
return
440413

441414
if not print_only_bool and check_if_model_exist(

0 commit comments

Comments
 (0)