diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 33deda4..a9edd80 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -78,17 +78,16 @@ def inject_trainable_lora(
 
 
 def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
-
-    loras = []
+    no_injection = True
 
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
            for _child_module in _module.modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
-                    loras.append((_child_module.lora_up, _child_module.lora_down))
-    if len(loras) == 0:
+                    no_injection = False
+                    yield (_child_module.lora_up, _child_module.lora_down)
+    if no_injection:
        raise ValueError("No lora injected.")
-    return loras
 
 
 def save_lora_weight(
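
Note on the new semantics: extract_lora_ups_down is now a generator, so the "No lora injected." ValueError is raised lazily, only once the generator is iterated to exhaustion, rather than eagerly at call time. A minimal usage sketch under that reading (the `unet` model below is a hypothetical stand-in, not part of this diff):

    for lora_up, lora_down in extract_lora_ups_down(unet):
        # each pair is one LoraInjectedLinear's up/down projection layers
        print(lora_up.weight.shape, lora_down.weight.shape)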