Skip to content

Commit 69d8b96

Browse files
Feat: Add logging for effective learning rates in LoRA GUI
This commit introduces a helper function, `get_effective_lr_messages`, into `kohya_gui/lora_gui.py` and integrates it into the `train_model` function. The purpose is to provide you with clearer information about how the learning rates set in the GUI (Main LR, Text Encoder LR, U-Net LR, T5XXL LR) will be interpreted and effectively applied by the underlying `sd-scripts` training engine. Before training commences, the GUI will now log: - The Main LR. - The effective LR for the primary Text Encoder (CLIP), indicating if it's a specific value or a fallback to the Main LR. - The effective LR for the T5XXL Text Encoder (if applicable), indicating its source (specific, inherited from primary TE, or fallback to Main LR). - The effective LR for the U-Net, indicating if it's a specific value or a fallback to the Main LR. This enhances transparency by helping you understand how your LR settings interact, without modifying the `sd-scripts` submodule.
1 parent d63a7fa commit 69d8b96

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

kohya_gui/lora_gui.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,64 @@ def open_configuration(
681681
return tuple(values)
682682

683683

684+
def get_effective_lr_messages(
685+
main_lr_val: float,
686+
text_encoder_lr_val: float, # Value from the 'Text Encoder learning rate' GUI field
687+
unet_lr_val: float, # Value from the 'Unet learning rate' GUI field
688+
t5xxl_lr_val: float # Value from the 'T5XXL learning rate' GUI field
689+
) -> list[str]:
690+
messages = []
691+
# Format LRs to scientific notation with 2 decimal places for readability
692+
f_main_lr = f"{main_lr_val:.2e}"
693+
f_te_lr = f"{text_encoder_lr_val:.2e}"
694+
f_unet_lr = f"{unet_lr_val:.2e}"
695+
f_t5_lr = f"{t5xxl_lr_val:.2e}"
696+
697+
messages.append("Effective Learning Rate Configuration (based on GUI settings):")
698+
messages.append(f" - Main LR (for optimizer & fallback): {f_main_lr}")
699+
700+
# --- Text Encoder (Primary/CLIP) LR ---
701+
# If text_encoder_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
702+
effective_clip_lr_str = f_main_lr
703+
clip_lr_source_msg = "(Fallback to Main LR)"
704+
if text_encoder_lr_val != 0.0:
705+
effective_clip_lr_str = f_te_lr
706+
clip_lr_source_msg = "(Specific Value)"
707+
messages.append(f" - Text Encoder (Primary/CLIP) Effective LR: {effective_clip_lr_str} {clip_lr_source_msg}")
708+
709+
# --- Text Encoder (T5XXL, if applicable) LR ---
710+
# Logic based on how text_encoder_lr_list is formed in train_model for sd-scripts:
711+
# 1. If t5xxl_lr_val is non-zero, it's used for T5.
712+
# 2. Else, if text_encoder_lr_val (primary TE LR) is non-zero, it's used for T5.
713+
# 3. Else (both primary TE LR and specific T5XXL LR are zero), T5 uses main_lr_val.
714+
effective_t5_lr_str = f_main_lr # Default fallback
715+
t5_lr_source_msg = "(Fallback to Main LR)"
716+
717+
if t5xxl_lr_val != 0.0:
718+
effective_t5_lr_str = f_t5_lr
719+
t5_lr_source_msg = "(Specific T5XXL Value)"
720+
elif text_encoder_lr_val != 0.0: # No specific T5 LR, but main TE LR is set
721+
effective_t5_lr_str = f_te_lr # T5 inherits from the primary TE LR setting
722+
t5_lr_source_msg = "(Inherited from Primary TE LR)"
723+
# If both t5xxl_lr_val and text_encoder_lr_val are 0.0, effective_t5_lr_str remains f_main_lr.
724+
725+
# The message for T5XXL LR is always added for completeness, indicating its potential value.
726+
# Users should understand it's relevant only if their model architecture uses a T5XXL text encoder.
727+
messages.append(f" - Text Encoder (T5XXL, if applicable) Effective LR: {effective_t5_lr_str} {t5_lr_source_msg}")
728+
729+
# --- U-Net LR ---
730+
# If unet_lr_val (from GUI) is non-zero, it's used. Otherwise, main_lr_val is the fallback.
731+
effective_unet_lr_str = f_main_lr
732+
unet_lr_source_msg = "(Fallback to Main LR)"
733+
if unet_lr_val != 0.0:
734+
effective_unet_lr_str = f_unet_lr
735+
unet_lr_source_msg = "(Specific Value)"
736+
messages.append(f" - U-Net Effective LR: {effective_unet_lr_str} {unet_lr_source_msg}")
737+
738+
messages.append("Note: These LRs reflect the GUI's direct settings. Advanced options in sd-scripts (e.g., block LRs, LoRA+) can further modify rates for specific layers.")
739+
return messages
740+
741+
684742
def train_model(
685743
headless,
686744
print_only,
@@ -1426,10 +1484,21 @@ def train_model(
14261484
float(text_encoder_lr) if text_encoder_lr is not None else 0.0
14271485
)
14281486
unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
1487+
t5xxl_lr_float = float(t5xxl_lr) if t5xxl_lr is not None else 0.0
1488+
1489+
# Log effective learning rate messages
1490+
lr_messages = get_effective_lr_messages(
1491+
learning_rate_float,
1492+
text_encoder_lr_float,
1493+
unet_lr_float,
1494+
t5xxl_lr_float
1495+
)
1496+
for message in lr_messages:
1497+
log.info(message)
14291498

14301499
# Determine the training configuration based on learning rate values
14311500
# Sets flags for training specific components based on the provided learning rates.
1432-
if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
1501+
if learning_rate_float == 0.0 and text_encoder_lr_float == 0.0 and unet_lr_float == 0.0:
14331502
output_message(msg="Please input learning rate values.", headless=headless)
14341503
return TRAIN_BUTTON_VISIBLE
14351504
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.

0 commit comments

Comments
 (0)