diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index c3bf3dc70..3a44be3b6 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -106,6 +106,26 @@ enable_single_replica_ckpt_restoring: False seed: 0 audio_format: "s16" +# LoRA parameters +enable_lora: False + +# Distilled LoRA +# lora_config: { +# lora_model_name_or_path: ["Lightricks/LTX-2"], +# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"], +# adapter_name: ["distilled-lora-384"], +# rank: [384] +# } + +# Standard LoRA +lora_config: { + lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"], + weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"], + adapter_name: ["camera-control-dolly-in"], + rank: [32] +} + + # LTX-2 Latent Upsampler run_latent_upsampler: False upsampler_model_path: "Lightricks/LTX-2" diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py index fa8c2c46d..516e6f2ea 100644 --- a/src/maxdiffusion/generate_ltx2.py +++ b/src/maxdiffusion/generate_ltx2.py @@ -25,6 +25,7 @@ from google.api_core.exceptions import GoogleAPIError import flax from maxdiffusion.utils.export_utils import export_to_video_with_audio +from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -120,6 +121,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): run_latent_upsampler = getattr(config, "run_latent_upsampler", False) pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler) + # If LoRA is specified, inject layers and load weights. 
# Attention sub-modules of an LTX2 transformer block. Their q/k/v projections
# keep the same name in LoRA keys; the output projection is addressed as
# `to_out.0` on the LoRA side.
_LTX2_ATTENTION_MODULES = (
    "attn1",  # Self attention
    "audio_attn1",  # Audio self attention
    "audio_attn2",  # Audio cross attention
    "attn2",  # Cross attention
    "audio_to_video_attn",  # Audio to video cross attention
    "video_to_audio_attn",  # Video to audio cross attention
)

# Feed-forward sub-modules of an LTX2 transformer block (video and audio).
_LTX2_FEED_FORWARD_MODULES = ("ff", "audio_ff")

# NNX name -> LoRA name for the conditioning modules. Each one contributes a
# `.linear` entry and two nested `.emb.timestep_embedder.linear_{1,2}` entries.
_LTX2_ADALN_MODULES = {
    "time_embed": "adaln_single",
    "audio_time_embed": "audio_adaln_single",
    "av_cross_attn_video_a2v_gate": "av_ca_a2v_gate_adaln_single",
    "av_cross_attn_audio_v2a_gate": "av_ca_v2a_gate_adaln_single",
    "av_cross_attn_audio_scale_shift": "av_ca_audio_scale_shift_adaln_single",
    "av_cross_attn_video_scale_shift": "av_ca_video_scale_shift_adaln_single",
}

# Caption projections share the same name on both sides.
_LTX2_CAPTION_PROJECTIONS = ("caption_projection", "audio_caption_projection")


def _build_ltx2_block_suffix_map():
  """Builds the map from per-block NNX module suffixes to LoRA key suffixes."""
  suffix_map = {}
  for attn in _LTX2_ATTENTION_MODULES:
    for proj in ("to_q", "to_k", "to_v"):
      suffix_map[f"{attn}.{proj}"] = f"{attn}.{proj}"
    suffix_map[f"{attn}.to_out"] = f"{attn}.to_out.0"
  for ff in _LTX2_FEED_FORWARD_MODULES:
    suffix_map[f"{ff}.net_0"] = f"{ff}.net.0.proj"
    suffix_map[f"{ff}.net_2"] = f"{ff}.net.2"
  return suffix_map


def _build_ltx2_global_map():
  """Builds the map for modules that live outside the transformer blocks."""
  global_map = {
      "proj_in": "diffusion_model.patchify_proj",
      "audio_proj_in": "diffusion_model.audio_patchify_proj",
      "proj_out": "diffusion_model.proj_out",
      "audio_proj_out": "diffusion_model.audio_proj_out",
      # Connectors
      "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
  }
  for nnx_name, lora_name in _LTX2_ADALN_MODULES.items():
    global_map[f"{nnx_name}.linear"] = f"diffusion_model.{lora_name}.linear"
    # Nested conditioning layers
    for idx in (1, 2):
      embedder = f"emb.timestep_embedder.linear_{idx}"
      global_map[f"{nnx_name}.{embedder}"] = f"diffusion_model.{lora_name}.{embedder}"
  for name in _LTX2_CAPTION_PROJECTIONS:
    for idx in (1, 2):
      global_map[f"{name}.linear_{idx}"] = f"diffusion_model.{name}.linear_{idx}"
  return global_map


# Built once at import time: the translator is invoked per module path, so
# rebuilding these tables on every call is wasted work.
_LTX2_BLOCK_SUFFIX_MAP = _build_ltx2_block_suffix_map()
_LTX2_GLOBAL_MAP = _build_ltx2_global_map()


def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """Translates an LTX2 NNX module path to its Diffusers/LoRA key prefix.

  Args:
    nnx_path_str: Dotted NNX module path, e.g. "proj_in",
      "transformer_blocks.3.attn1.to_q" (unscanned layers) or
      "transformer_blocks.attn1.to_q" (scanned layers, no block index).
    scan_layers: Whether transformer block paths are expected without a block
      index. Paths outside the transformer blocks translate the same way in
      both modes.

  Returns:
    The LoRA key prefix for the module, or None when the path has no known
    LoRA counterpart. With scan_layers=True, block paths yield a template
    containing a literal "{}" where the block index belongs, e.g.
    "diffusion_model.transformer_blocks.{}.attn1.to_q"; filling it in is left
    to the caller.
  """
  global_key = _LTX2_GLOBAL_MAP.get(nnx_path_str)
  if global_key is not None:
    return global_key

  if scan_layers:
    prefix = "transformer_blocks."
    if not nnx_path_str.startswith(prefix):
      return None
    lora_suffix = _LTX2_BLOCK_SUFFIX_MAP.get(nnx_path_str[len(prefix) :])
    if lora_suffix is None:
      return None
    # "{{}}" leaves a literal "{}" placeholder for the block index.
    return f"diffusion_model.transformer_blocks.{{}}.{lora_suffix}"

  m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
  if m is None:
    return None
  idx, inner_suffix = m.group(1), m.group(2)
  lora_suffix = _LTX2_BLOCK_SUFFIX_MAP.get(inner_suffix)
  if lora_suffix is None:
    return None
  return f"diffusion_model.transformer_blocks.{idx}.{lora_suffix}"
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NNX-based LoRA loader for LTX2 models."""

from flax import nnx
from .lora_base import LoRABaseMixin
from .lora_pipeline import StableDiffusionLoraLoaderMixin
from ..models import lora_nnx
from .. import max_logging
from . import lora_conversion_utils


class LTX2NNXLoraLoader(LoRABaseMixin):
  """
  Handles loading LoRA weights into NNX-based LTX2 model.
  Assumes LTX2 pipeline contains 'transformer'
  attributes that are NNX Modules.
  """

  # (pipeline attribute, state-dict key prefix) pairs, in merge order. Each
  # sub-model only receives the keys carrying its prefix so that the merge
  # routine does not report the other sub-model's keys as unmatched.
  _MERGE_TARGETS = (
      ("transformer", "diffusion_model"),
      ("connectors", "text_embedding_projection"),
  )

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """
    Merges LoRA weights into the pipeline from a checkpoint.

    Args:
      pipeline: Object whose `transformer` and/or `connectors` attributes
        receive the LoRA weights. Attributes that are missing or None are
        skipped.
      lora_model_path: Model name or path handed to `lora_state_dict`.
      transformer_weight_name: Checkpoint file name handed to
        `lora_state_dict` as `weight_name`. When empty, nothing is loaded.
      rank: LoRA rank, forwarded to the merge routine.
      scale: LoRA scale, forwarded to the merge routine.
      scan_layers: Selects `merge_lora_for_scanned` and index-free block
        paths instead of `merge_lora` and indexed block paths.
      dtype: Forwarded to the merge routine as `dtype`.
      **kwargs: Forwarded to `lora_state_dict`.

    Returns:
      The same `pipeline` object that was passed in.
    """
    if not transformer_weight_name:
      max_logging.log("No weight name provided for LoRA; nothing to merge.")
      return pipeline

    present_targets = []
    for attr_name, key_prefix in self._MERGE_TARGETS:
      # hasattr alone would let an attribute that is set to None through.
      if getattr(pipeline, attr_name, None) is None:
        max_logging.log(f"{attr_name} not found on pipeline; skipping LoRA merge for it.")
        continue
      present_targets.append((attr_name, key_prefix))

    if not present_targets:
      return pipeline

    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def translate_fn(nnx_path_str):
      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    # Load the checkpoint once and share it between all targets.
    # NOTE(review): the second return value is discarded, as before — confirm
    # the merge routine does not need it.
    lora_loader = StableDiffusionLoraLoaderMixin()
    h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)

    for attr_name, key_prefix in present_targets:
      target_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith(key_prefix)}
      if not target_state_dict:
        # e.g. a LoRA that carries no connector weights: nothing to merge.
        max_logging.log(f"No LoRA weights with prefix '{key_prefix}' found; skipping {attr_name}.")
        continue
      max_logging.log(f"Merging LoRA into {attr_name} with rank={rank}")
      merge_fn(getattr(pipeline, attr_name), target_state_dict, rank, scale, translate_fn, dtype=dtype)

    return pipeline