From d13a554b1c2f9214c301b5c07e18c64fce13c6b6 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 16:14:40 -0700 Subject: [PATCH 01/31] fix(ltx2): use pure source latents as i2v denoise-mask target At `strength < 1.0` (the `--strength 0.75` LTX-2 i2v default), `run_real_distilled_stage` was cloning `video_latents` *after* `apply_stage_video_conditioning` had already soft-blended the first latent frame positions with noise: the "clean reference" tensor that the per-step denoise-mask blend pulls conditioned tokens toward became `noise*(1-s) + source*s` at replacement positions. Used as the clean target, that pre-blended tensor pinned the first latent to a noisy ghost of the image at every step, so i2v runs produced a first frame that was 25 % noise + 75 % image instead of the source image. Introduce a `clean_latents_for_conditioning` helper that re-applies the replacement-based conditioning with `strength = 1.0` on top of the post-apply tensor, overwriting replacement positions with pure source tokens while appended keyframe tokens and pure-noise regions pass through unchanged. `strength = 1.0` and pure-T2V paths remain bit-for-bit identical. Two regression tests cover the soft-blended case and the no-replacements passthrough. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- CHANGELOG.md | 1 + crates/mold-inference/src/ltx2/runtime.rs | 109 ++++++++++++++++++++-- 2 files changed, 101 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51e20f6e..af96fa71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **LTX-2 image-to-video no longer locks the first latent frame to a noisy ghost of the source image at `strength < 1.0`.** In `run_real_distilled_stage` (`crates/mold-inference/src/ltx2/runtime.rs`) the "clean reference" that the per-step denoise-mask blend pulls the conditioned tokens toward was sourced by cloning `video_latents` *after* `apply_stage_video_conditioning` had already soft-blended the first-latent-frame positions with the initial noise (`noise*(1-s) + source*s`). Used as the clean target, that pre-blended tensor pinned the first latent to a noisy copy of the image instead of the pure image at every step — so i2v runs with `--strength 0.75` (the CLI default) produced a first frame that was 25 % noise + 75 % image rather than the source image. A new helper `clean_latents_for_conditioning` re-applies the replacements with strength 1.0 on top of the post-apply tensor so replacement positions hold pure source image tokens while appended keyframe tokens and pure-noise regions pass through unchanged. `strength = 1.0` and pure-T2V paths are bit-for-bit identical to before. Covered by two new regression tests (`clean_latents_replace_soft_blended_positions_with_pure_source`, `clean_latents_passthrough_when_no_replacements`). - **city96-format FLUX fine-tune GGUFs now fail with an honest, actionable error when no dev reference is downloaded, and surface the dependency at pull time instead of inside `ensure_gguf_embeddings`.** Community fine-tune GGUFs (e.g. 
the `silveroxides/ultrareal-fine-tune-GGUF` tree that powers `ultrareal-v4:q{8,5,4}`) ship only the diffusion blocks and expect the base FLUX input embedding layers (`img_in`, `time_in`, `vector_in`, `guidance_in`) to be patched in from a separately-downloaded flux-dev reference. Two bugs made this fail confusingly: (1) `find_flux_reference_gguf` in `crates/mold-inference/src/flux/pipeline.rs` returned the first candidate with `img_in.weight`, which let `flux-schnell:q8` pass the probe even though schnell is distilled without `guidance_in` — the subsequent patch loop bailed with `reference GGUF (.../flux-schnell-q8/flux1-schnell-Q8_0.gguf) is also missing required tensor 'guidance_in.in_layer.weight'`, making it look like schnell itself was broken. (2) The manifest didn't express the dependency at all, so the first indication a user had that `mold pull ultrareal-v4:q8` wasn't self-sufficient was an HTTP 500 on their first generation. Fixed by (a) adding a `needs_guidance: bool` parameter to `find_flux_reference_gguf` that skips schnell candidates for dev-family targets and verifies candidates contain `guidance_in.in_layer.weight` before accepting them, (b) rewriting both error messages so the source model is named and the reference path is shown as a filename rather than a full path, and (c) adding a pull-time probe in `crates/mold-core/src/download.rs` (`warn_if_flux_gguf_needs_reference`) that scans the first 4 MiB of any downloaded `.gguf` transformer for `img_in.weight`, and prints a one-line warning via the download callback when the GGUF is incomplete and no suitable dev reference is already on disk. Works for both the CLI (`pull_model`) and server (`pull_model_with_callback`) paths. New regression test `find_flux_reference_skips_schnell_when_dev_needed` covers the reference-picker behaviour. 
- **Prompt expansion can no longer OOM on a multi-GPU box with a tight main GPU.** `LocalExpander` previously hardcoded `gpu_ordinal: 0` and gated placement with a static 2 GB VRAM threshold — on a dual-GPU system with a busy main card it fell back to CPU unnecessarily, and on a q8/bf16 expand model (4+ GB weights) the 2 GB threshold under-budgeted activations so the GPU placement check could pass and the load then OOM. The expander now sizes its budget dynamically (`model_size + 2 GB activations`, matching the T5/Qwen3 pattern) and cascades through devices: main GPU → remaining GPUs in ordinal order → CPU, with `preflight_memory_check()` as the final hard-fail guard when system RAM can't hold the model either. Unified-memory Metal placements also run the RAM preflight (Metal allocations draw from the same pool). Device selection logic is factored into a pure `select_expand_device(gpus, threshold, is_metal) -> ExpandPlacement` helper with unit tests for every branch. diff --git a/crates/mold-inference/src/ltx2/runtime.rs b/crates/mold-inference/src/ltx2/runtime.rs index 781ab3c4..72135d92 100644 --- a/crates/mold-inference/src/ltx2/runtime.rs +++ b/crates/mold-inference/src/ltx2/runtime.rs @@ -1379,6 +1379,36 @@ fn apply_video_token_replacements( Ok(patched) } +/// Build the "clean reference" tensor used by the denoise mask blend at every +/// step. For replacement-based conditioning (e.g. i2v source image) with +/// `strength < 1.0`, `video_latents` already holds `noise*(1-s) + source*s` at +/// the replacement positions. If we reuse that as the clean target, the +/// denoise-mask blend pulls those tokens toward a noisy ghost of the image at +/// every step — the first latent frame never converges to the pure source. 
+/// +/// Re-applying the replacements with strength 1.0 overwrites those positions +/// with the pure source tokens, leaving appended keyframe tokens (already +/// full-strength in `apply_appended_video_conditioning`) and pure-noise +/// regions untouched. +fn clean_latents_for_conditioning( + video_latents: &Tensor, + conditioning: &StageVideoConditioning, +) -> Result { + if conditioning.replacements.is_empty() { + return Ok(video_latents.clone()); + } + let hard_replacements: Vec = conditioning + .replacements + .iter() + .map(|replacement| VideoTokenReplacement { + start_token: replacement.start_token, + tokens: replacement.tokens.clone(), + strength: 1.0, + }) + .collect(); + apply_video_token_replacements(video_latents, &hard_replacements) +} + fn apply_appended_video_conditioning( video_latents: &Tensor, video_positions: &Tensor, @@ -2699,7 +2729,7 @@ fn run_real_distilled_stage( )?; let clean_video_latents = match video_clean_latents { Some(latents) => video_patchifier.patchify(latents)?, - None => video_latents.clone(), + None => clean_latents_for_conditioning(&video_latents, video_conditioning)?, }; let video_denoise_mask = match video_denoise_mask { Some(mask) => mask.to_device(&device)?.to_dtype(DType::F32)?, @@ -4622,14 +4652,14 @@ mod tests { use super::{ apply_stage_video_conditioning, apply_video_token_replacements, - build_video_conditioning_self_attention_mask, convert_velocity_to_x0, - convert_x0_to_velocity, decoded_video_to_frames, effective_native_guidance_scale, - emit_denoise_progress, guided_velocity_from_cfg, keyframe_only_conditioning, - ltx2_video_transformer_config, reapply_stage_video_conditioning, - should_inspect_step_velocity, source_image_only_conditioning, - strip_appended_video_conditioning, Ltx2RuntimeSession, StageVideoConditioning, - VideoTokenAppendCondition, VideoTokenReplacement, LTX2_AUDIO_LATENT_CHANNELS, - LTX2_VIDEO_LATENT_CHANNELS, + build_video_conditioning_self_attention_mask, clean_latents_for_conditioning, + 
convert_velocity_to_x0, convert_x0_to_velocity, decoded_video_to_frames, + effective_native_guidance_scale, emit_denoise_progress, guided_velocity_from_cfg, + keyframe_only_conditioning, ltx2_video_transformer_config, + reapply_stage_video_conditioning, should_inspect_step_velocity, + source_image_only_conditioning, strip_appended_video_conditioning, Ltx2RuntimeSession, + StageVideoConditioning, VideoTokenAppendCondition, VideoTokenReplacement, + LTX2_AUDIO_LATENT_CHANNELS, LTX2_VIDEO_LATENT_CHANNELS, }; use crate::ltx2::conditioning::{self, StagedConditioning}; use crate::ltx2::model::VideoPixelShape; @@ -5693,6 +5723,67 @@ mod tests { ); } + #[test] + fn clean_latents_replace_soft_blended_positions_with_pure_source() { + // Simulate the state after `apply_stage_video_conditioning` with + // strength 0.75: at the replacement positions, `video_latents` already + // holds `noise*0.25 + source*0.75`. The denoise-mask blend uses + // `clean_latents` as the target it pulls those positions toward at + // every step — so the clean target must be pure source, not the + // pre-blended mix. 
+ let noise = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0]; + let source = [10.0f32, 10.0]; + let strength = 0.75f32; + let blended_first = [ + noise[0] * (1.0 - strength) + source[0] * strength, + noise[1] * (1.0 - strength) + source[1] * strength, + ]; + let soft_blended = Tensor::from_vec( + vec![ + blended_first[0], + blended_first[1], + noise[2], + noise[3], + noise[4], + noise[5], + ], + (1, 3, 2), + &Device::Cpu, + ) + .unwrap(); + let conditioning = StageVideoConditioning { + replacements: vec![VideoTokenReplacement { + start_token: 0, + tokens: Tensor::from_vec(source.to_vec(), (1, 1, 2), &Device::Cpu).unwrap(), + strength: strength as f64, + }], + appended: vec![], + }; + + let clean = clean_latents_for_conditioning(&soft_blended, &conditioning).unwrap(); + let values = clean.flatten_all().unwrap().to_vec1::().unwrap(); + + assert_eq!( + values, + vec![source[0], source[1], noise[2], noise[3], noise[4], noise[5]], + "soft-blended replacement positions must be overwritten with the pure \ + source tokens; other positions must be preserved unchanged" + ); + } + + #[test] + fn clean_latents_passthrough_when_no_replacements() { + let latents = + Tensor::from_vec(vec![0.0f32, 1.0, 2.0, 3.0], (1, 2, 2), &Device::Cpu).unwrap(); + let conditioning = StageVideoConditioning::default(); + + let clean = clean_latents_for_conditioning(&latents, &conditioning).unwrap(); + assert_eq!( + clean.flatten_all().unwrap().to_vec1::().unwrap(), + vec![0.0, 1.0, 2.0, 3.0] + ); + } + #[test] fn video_conditioning_self_attention_mask_blocks_cross_keyframe_attention() { let conditioning = StageVideoConditioning { From b4ed487578c49a202eaeead3eea2526186bb6817 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 16:26:21 -0700 Subject: [PATCH 02/31] feat(chain): add core wire types and request normalisation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce `mold_core::chain` with the `ChainStage` / `ChainRequest` / `ChainResponse` 
types that will carry server-side chained LTX-2 video generation. The wire format is stages-based from day one so the v2 movie-maker UI can author multi-prompt / multi-keyframe chains without breaking callers: v1 only exposes a single-prompt auto-expand form (`prompt` + `total_frames` + `clip_frames`), and `normalise()` collapses it into a canonical `Vec` before any engine work runs. Normalisation matches the stitch math that Phase 1.4 of the plan will use: delivered_frames = clip_frames + (N - 1) * (clip_frames - motion_tail) so auto-expand picks `N` large enough to cover `total_frames` with tail-overlap trimming in mind; the over-production is discarded from the final clip's tail per the 2026-04-20 sign-off. Guardrails cap chains at 16 stages (≈1552 frames at 97-frame clips, ~64 s at 24 fps), require `8k+1` frame counts for LTX-2, and forbid `motion_tail_frames >= clip_frames` so every continuation emits at least one new frame. Also lifts the existing `base64_opt` serde helper in `types.rs` from private to `pub(crate)` so chain types can share the single source of truth for base64 wire encoding. Unit tests cover: split-into-stages, first-stage-image preservation, empty-request rejection, non-8k+1 rejection, canonical-form passthrough, single-stage short chains, >16-stage guardrails, motion-tail >= clip rejection, missing auto-expand fields, and a property test confirming the auto-expand stage count delivers the requested total frames under every representative (total, clip, tail) combo from the design. tasks/render-chain-v1-plan.md adds the signed-off decisions block at the top so the rationale travels with the code. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-core/src/chain.rs | 596 ++++++++++++++++++++++++++++++++++ crates/mold-core/src/lib.rs | 2 + crates/mold-core/src/types.rs | 2 +- tasks/render-chain-v1-plan.md | 410 +++++++++++++++++++++++ 4 files changed, 1009 insertions(+), 1 deletion(-) create mode 100644 crates/mold-core/src/chain.rs create mode 100644 tasks/render-chain-v1-plan.md diff --git a/crates/mold-core/src/chain.rs b/crates/mold-core/src/chain.rs new file mode 100644 index 00000000..101e479e --- /dev/null +++ b/crates/mold-core/src/chain.rs @@ -0,0 +1,596 @@ +//! Wire types for server-side chained video generation. +//! +//! A *chain* is a sequence of per-clip render stages stitched into a single +//! output video. The v1 CLI UX is single-prompt + arbitrary length, but the +//! wire format is stages-based from day one so the eventual movie-maker +//! (multi-prompt, keyframes, selective regen) can author stages by hand +//! without a breaking change. +//! +//! The server only ever sees the canonical [`ChainRequest`] shape — a +//! `Vec`. Callers can either build that directly or use the +//! auto-expand form (`prompt` + `total_frames` + `clip_frames`), which +//! [`ChainRequest::normalise`] collapses into stages. +//! +//! See `tasks/render-chain-v1-plan.md` for the full design rationale. + +use serde::{Deserialize, Serialize}; + +use crate::error::{MoldError, Result}; +use crate::types::{DevicePlacement, OutputFormat, VideoData}; + +/// A single rendered clip in a chain. Concatenated in order with motion-tail +/// trimming on continuations (stages with `idx >= 1` drop the leading +/// `motion_tail_frames` pixel frames of their output because those duplicate +/// the tail of the previous stage that the engine carried across as +/// latent-space conditioning). +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ChainStage { + /// Prompt used for this stage. 
In v1 all stages receive the same prompt + /// (auto-expand form replicates it); the movie-maker UI in v2 will let + /// users author per-stage prompts. + #[schema(example = "a cat walking through autumn leaves")] + pub prompt: String, + + /// Frame count for this stage. Must be `8k+1` (LTX-2 pipeline constraint: + /// 9, 17, 25, …, 97). + #[schema(example = 97)] + pub frames: u32, + + /// Optional starting image (raw PNG/JPEG bytes, base64 in JSON). In v1 + /// this is only meaningful on `stages[0]`; later stages draw their + /// conditioning from the prior stage's motion-tail latents instead. + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "crate::types::base64_opt" + )] + pub source_image: Option>, + + /// Optional negative prompt for CFG-based stages. v1 LTX-2 ignores this + /// (the distilled family doesn't use CFG); the field is reserved so the + /// movie-maker can round-trip it without re-migrating the wire format. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub negative_prompt: Option, + + /// Optional per-stage seed offset. `None` in v1 — the orchestrator + /// derives each stage's seed from the chain's base seed. Reserved as the + /// v2 movie-maker override hook for "regenerate just this stage with a + /// different seed". + #[serde(default, skip_serializing_if = "Option::is_none")] + pub seed_offset: Option, +} + +/// Chained generation request. Server accepts either the canonical form +/// (`stages` non-empty) or the auto-expand form (`prompt` + `total_frames` + +/// `clip_frames`); [`ChainRequest::normalise`] collapses the latter into the +/// former so downstream code only deals with `stages`. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ChainRequest { + #[schema(example = "ltx-2-19b-distilled:fp8")] + pub model: String, + + /// Canonical stages list. Empty triggers auto-expand from + /// `prompt`/`total_frames`/`clip_frames`. 
+ #[serde(default)] + pub stages: Vec, + + /// Pixel frames of motion-tail overlap between consecutive stages. + /// `0` = no overlap (simple concat). `>0` = the final K pixel frames of + /// stage N's latents are threaded into stage N+1's conditioning, and + /// stage N+1's leading K output frames are dropped at stitch time. + /// + /// Defaults to `4` for v1 (matches the CLI default). Must be strictly + /// less than each stage's `frames`. + #[serde(default = "default_motion_tail_frames")] + #[schema(example = 4)] + pub motion_tail_frames: u32, + + #[schema(example = 1216)] + pub width: u32, + #[schema(example = 704)] + pub height: u32, + #[serde(default = "default_fps")] + #[schema(example = 24)] + pub fps: u32, + + /// Chain base seed. Per-stage seeds are derived as + /// `base_seed ^ ((stage_idx as u64) << 32)` by the orchestrator so the + /// whole chain is reproducible from a single seed value. + #[serde(default, skip_serializing_if = "Option::is_none")] + #[schema(example = 42)] + pub seed: Option, + + #[schema(example = 8)] + pub steps: u32, + + #[schema(example = 3.0)] + pub guidance: f64, + + /// Denoising strength for `stages[0].source_image`. Ignored when the + /// first stage has no source image. Continuation stages are always + /// full-strength conditioned via motion-tail latents. + #[serde(default = "default_strength")] + #[schema(example = 1.0)] + pub strength: f64, + + #[serde(default = "default_output_format")] + pub output_format: OutputFormat, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub placement: Option, + + // ── Auto-expand form ──────────────────────────────────────────────── + // These are only read when `stages` is empty; `normalise` clears them + // after expansion so the canonical form only ever carries `stages`. + /// Auto-expand: single prompt replicated across all stages. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// Auto-expand: total pixel frames the stitched output should cover. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub total_frames: Option, + + /// Auto-expand: per-clip frame count. Defaults to `97` (LTX-2 19B/22B + /// distilled cap). Must be `8k+1`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub clip_frames: Option, + + /// Auto-expand: starting image for `stages[0]`. + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "crate::types::base64_opt" + )] + pub source_image: Option>, +} + +/// Response from a chained generation request. The `video` is the stitched +/// output; individual per-stage clips are not returned. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ChainResponse { + pub video: VideoData, + /// Number of stages that actually ran (matches `request.stages.len()` + /// after normalisation). + #[schema(example = 5)] + pub stage_count: u32, + /// GPU ordinal that handled the chain (multi-GPU servers only). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub gpu: Option, +} + +fn default_motion_tail_frames() -> u32 { + 4 +} + +fn default_fps() -> u32 { + 24 +} + +fn default_strength() -> f64 { + 1.0 +} + +fn default_output_format() -> OutputFormat { + OutputFormat::Mp4 +} + +/// Maximum number of stages the v1 orchestrator will accept in a single +/// chain. 16 × 97-frame clips ≈ 1552 frames ≈ 64 s at 24 fps — comfortably +/// past the 400-frame target without risking runaway jobs. +pub const MAX_CHAIN_STAGES: usize = 16; + +impl ChainRequest { + /// Collapse the auto-expand form into a canonical `Vec` and + /// validate the result. Called once on the server side immediately after + /// JSON parsing, before any engine work kicks off. + /// + /// Post-conditions on a successful return: + /// - `self.stages` is non-empty. 
+ /// - Each stage's `frames` is `8k+1` and `> 0`. + /// - `self.stages.len() <= MAX_CHAIN_STAGES`. + /// - All auto-expand fields are `None` (caller must use `self.stages`). + pub fn normalise(mut self) -> Result { + if self.stages.is_empty() { + let prompt = self.prompt.take().ok_or_else(|| { + MoldError::Validation( + "chain request needs either stages[] or prompt + total_frames".into(), + ) + })?; + let total_frames = self.total_frames.ok_or_else(|| { + MoldError::Validation("chain auto-expand requires total_frames".into()) + })?; + if total_frames == 0 { + return Err(MoldError::Validation( + "chain total_frames must be > 0".into(), + )); + } + let clip_frames = self.clip_frames.unwrap_or(97); + if clip_frames == 0 { + return Err(MoldError::Validation( + "chain clip_frames must be > 0".into(), + )); + } + if !is_ltx2_frame_count(clip_frames) { + return Err(MoldError::Validation(format!( + "chain clip_frames ({clip_frames}) must be 8k+1 (9, 17, 25, …, 97)", + ))); + } + let motion_tail = self.motion_tail_frames; + if motion_tail >= clip_frames { + return Err(MoldError::Validation(format!( + "motion_tail_frames ({motion_tail}) must be strictly less than clip_frames ({clip_frames})", + ))); + } + + let source_image = self.source_image.take(); + self.stages = build_auto_expand_stages( + &prompt, + total_frames, + clip_frames, + motion_tail, + source_image, + )?; + } + + if self.stages.is_empty() { + return Err(MoldError::Validation("chain request has no stages".into())); + } + if self.stages.len() > MAX_CHAIN_STAGES { + return Err(MoldError::Validation(format!( + "chain request has {} stages; maximum is {}", + self.stages.len(), + MAX_CHAIN_STAGES, + ))); + } + for (idx, stage) in self.stages.iter().enumerate() { + if stage.frames == 0 { + return Err(MoldError::Validation(format!("stage {idx} has 0 frames",))); + } + if !is_ltx2_frame_count(stage.frames) { + return Err(MoldError::Validation(format!( + "stage {idx} has {} frames; LTX-2 requires 8k+1 (9, 17, 25, …, 
97)", + stage.frames, + ))); + } + if self.motion_tail_frames >= stage.frames { + return Err(MoldError::Validation(format!( + "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({})", + self.motion_tail_frames, stage.frames, + ))); + } + } + + // Canonicalise: clear auto-expand fields so downstream code only + // ever reads from `stages`. + self.prompt = None; + self.total_frames = None; + self.clip_frames = None; + self.source_image = None; + + Ok(self) + } +} + +/// Returns `true` iff `n` has the form `8k + 1` for some non-negative integer +/// `k` (1, 9, 17, 25, …). The LTX-2 pipeline has this constraint on pixel +/// frame counts due to the VAE's 8× temporal compression with a causal first +/// frame. +fn is_ltx2_frame_count(n: u32) -> bool { + n % 8 == 1 +} + +/// Compute the stage count and per-stage frame allocation for the auto- +/// expand form, matching Phase 1.4's stitch math: +/// +/// - Stage 0 contributes `clip_frames` pixel frames. +/// - Each continuation contributes `clip_frames - motion_tail_frames` new +/// frames (the leading `motion_tail_frames` are dropped at stitch time +/// because they duplicate the prior stage's latent tail). +/// +/// Returns enough stages so the stitched total reaches at least +/// `total_frames`; over-production is trimmed from the tail at stitch time +/// per the signed-off decision 2026-04-20. +fn build_auto_expand_stages( + prompt: &str, + total_frames: u32, + clip_frames: u32, + motion_tail_frames: u32, + source_image: Option>, +) -> Result> { + let (stage_count, per_stage_frames) = if total_frames <= clip_frames { + // Single stage: match the user's requested length exactly so we + // don't render 97 frames and throw most of them away. The frame + // count will still be validated as 8k+1 by the caller. + (1u32, total_frames) + } else { + let effective = clip_frames - motion_tail_frames; + // effective > 0 because the caller has already ensured + // motion_tail_frames < clip_frames. 
+ let remainder = total_frames - clip_frames; + let count = 1 + remainder.div_ceil(effective); + (count, clip_frames) + }; + + let count_usize = stage_count as usize; + if count_usize > MAX_CHAIN_STAGES { + return Err(MoldError::Validation(format!( + "auto-expand would produce {stage_count} stages; maximum is {MAX_CHAIN_STAGES} \ + (try reducing total_frames or increasing clip_frames)", + ))); + } + + let mut stages = Vec::with_capacity(count_usize); + for idx in 0..stage_count { + stages.push(ChainStage { + prompt: prompt.to_string(), + frames: per_stage_frames, + source_image: if idx == 0 { source_image.clone() } else { None }, + negative_prompt: None, + seed_offset: None, + }); + } + Ok(stages) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a minimal auto-expand request with the given knobs. All other + /// fields use their v1 defaults so tests can focus on the logic under + /// exercise. + fn auto_expand_request( + prompt: &str, + total_frames: u32, + clip_frames: u32, + motion_tail_frames: u32, + source_image: Option>, + ) -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages: Vec::new(), + motion_tail_frames, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: Some(prompt.into()), + total_frames: Some(total_frames), + clip_frames: Some(clip_frames), + source_image, + } + } + + fn canonical_request(stages: Vec, motion_tail_frames: u32) -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages, + motion_tail_frames, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } + } + + fn make_stage(frames: u32) -> ChainStage { + ChainStage { + prompt: "test".into(), + frames, + 
source_image: None, + negative_prompt: None, + seed_offset: None, + } + } + + #[test] + fn normalise_splits_single_prompt_into_stages() { + // total=400, clip=97, tail=4 → effective=93, remainder=303, + // N = 1 + ceil(303/93) = 1 + 4 = 5 stages of 97 frames each. + // Stitched = 97 + 4*93 = 469, which will be trimmed to 400 at + // stitch time (per the signed-off "trim from tail" decision). + let normalised = auto_expand_request("a cat walking", 400, 97, 4, None) + .normalise() + .expect("normalise should succeed"); + + assert_eq!( + normalised.stages.len(), + 5, + "400/97 with a 4-frame motion tail should expand to 5 stages", + ); + for stage in &normalised.stages { + assert_eq!(stage.frames, 97); + assert_eq!(stage.prompt, "a cat walking"); + assert!(stage.seed_offset.is_none()); + } + // Auto-expand fields are cleared post-normalisation. + assert!(normalised.prompt.is_none()); + assert!(normalised.total_frames.is_none()); + assert!(normalised.clip_frames.is_none()); + assert!(normalised.source_image.is_none()); + } + + #[test] + fn normalise_preserves_first_stage_image() { + let png = vec![0x89, 0x50, 0x4e, 0x47, 0xde, 0xad, 0xbe, 0xef]; + let normalised = auto_expand_request("test", 200, 97, 4, Some(png.clone())) + .normalise() + .expect("normalise should succeed"); + + assert!(normalised.stages.len() >= 2); + assert_eq!( + normalised.stages[0].source_image.as_deref(), + Some(png.as_slice()), + "stage 0 must carry the starting image", + ); + for stage in &normalised.stages[1..] { + assert!( + stage.source_image.is_none(), + "continuation stages must not carry a source image; conditioning flows \ + through motion-tail latents instead", + ); + } + } + + #[test] + fn normalise_rejects_empty() { + let mut req = canonical_request(Vec::new(), 4); + // No auto-expand fields either. 
+ req.prompt = None; + req.total_frames = None; + + let err = req.normalise().expect_err("empty chain should fail"); + assert!( + matches!(err, MoldError::Validation(_)), + "empty chain should be a validation error, got {err:?}", + ); + } + + #[test] + fn normalise_rejects_non_8k1_frames() { + // Canonical form with a stage whose frames violates the 8k+1 + // constraint. + let req = canonical_request(vec![make_stage(50)], 4); + let err = req.normalise().expect_err("non-8k+1 frames should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")), + "error must mention the 8k+1 constraint", + ); + } + + #[test] + fn normalise_accepts_canonical_form_unchanged() { + // Caller already built stages; normalise should validate and clear + // the (already-empty) auto-expand fields without touching stages. + let stages = vec![make_stage(97), make_stage(97), make_stage(97)]; + let normalised = canonical_request(stages.clone(), 4) + .normalise() + .expect("valid canonical form should pass"); + assert_eq!(normalised.stages.len(), 3); + for (left, right) in normalised.stages.iter().zip(stages.iter()) { + assert_eq!(left.frames, right.frames); + assert_eq!(left.prompt, right.prompt); + } + } + + #[test] + fn normalise_single_stage_when_total_leq_clip() { + // total=9 fits in one clip; don't render a full 97-frame stage and + // throw most of it away. + let normalised = auto_expand_request("short", 9, 97, 4, None) + .normalise() + .expect("short single-clip chain should pass"); + assert_eq!(normalised.stages.len(), 1); + assert_eq!(normalised.stages[0].frames, 9); + } + + #[test] + fn normalise_rejects_too_many_stages() { + // 17 canonical stages exceeds MAX_CHAIN_STAGES (16). 
+ let stages = (0..17).map(|_| make_stage(97)).collect(); + let err = canonical_request(stages, 4) + .normalise() + .expect_err("17-stage chain should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("maximum")), + "error must mention the max-stages cap", + ); + } + + #[test] + fn normalise_rejects_auto_expand_too_long() { + // 16 × 97 = 1552 max stitched frames before trim; asking for + // 4000 frames should blow the guardrail. + let err = auto_expand_request("too long", 4000, 97, 4, None) + .normalise() + .expect_err("runaway auto-expand should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("stages")), + "error must name the stage count guardrail", + ); + } + + #[test] + fn normalise_rejects_motion_tail_ge_clip() { + // motion_tail must leave at least one new frame per continuation. + let err = auto_expand_request("bad tail", 200, 97, 97, None) + .normalise() + .expect_err("motion_tail >= clip should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("motion_tail_frames")), + "error must name motion_tail_frames", + ); + } + + #[test] + fn normalise_rejects_missing_total_frames_in_auto_expand() { + let mut req = canonical_request(Vec::new(), 4); + req.prompt = Some("missing total".into()); + // total_frames omitted. 
+ let err = req + .normalise() + .expect_err("missing total_frames should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("total_frames")), + "error must name total_frames", + ); + } + + #[test] + fn is_ltx2_frame_count_matches_8k_plus_1() { + for valid in [1u32, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97] { + assert!( + is_ltx2_frame_count(valid), + "{valid} should be a valid LTX-2 frame count", + ); + } + for invalid in [0u32, 2, 8, 10, 16, 50, 96, 98, 100] { + assert!( + !is_ltx2_frame_count(invalid), + "{invalid} must not pass the 8k+1 check", + ); + } + } + + #[test] + fn build_stages_math_matches_stitch_budget() { + // Auto-expand must produce enough stages that the stitch delivers + // at least `total_frames` pixel frames. Stitch math: + // delivered = clip_frames + (N - 1) * (clip_frames - motion_tail) + let cases = [ + (400u32, 97u32, 4u32, 5u32), // 97 + 4*93 = 469 ≥ 400 + (200, 97, 4, 3), // 97 + 2*93 = 283 ≥ 200 + (97, 97, 4, 1), // single clip hits 97 exactly + (300, 97, 0, 4), // zero tail, 4*97 = 388 ≥ 300 + ]; + for (total, clip, tail, expected_n) in cases { + let req = auto_expand_request("m", total, clip, tail, None) + .normalise() + .expect("valid auto-expand should normalise"); + assert_eq!( + req.stages.len() as u32, + expected_n, + "expected {expected_n} stages for total={total}, clip={clip}, tail={tail}", + ); + let delivered = clip + (expected_n - 1) * (clip - tail); + assert!( + delivered >= total, + "{expected_n} stages deliver {delivered} frames but {total} were requested", + ); + } + } +} diff --git a/crates/mold-core/src/lib.rs b/crates/mold-core/src/lib.rs index a16bb81f..e7b2f3c1 100644 --- a/crates/mold-core/src/lib.rs +++ b/crates/mold-core/src/lib.rs @@ -1,5 +1,6 @@ pub mod build_info; pub mod catalog; +pub mod chain; pub mod client; pub mod config; pub mod control; @@ -18,6 +19,7 @@ mod config_test; mod test_support; pub use catalog::build_model_catalog; +pub use chain::{ChainRequest, 
ChainResponse, ChainStage, MAX_CHAIN_STAGES}; pub use client::MoldClient; pub use config::{ parse_device_ref_str, Config, DefaultModelResolution, DefaultModelSource, LoggingConfig, diff --git a/crates/mold-core/src/types.rs b/crates/mold-core/src/types.rs index ade380e1..72e95b90 100644 --- a/crates/mold-core/src/types.rs +++ b/crates/mold-core/src/types.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; /// Serde helpers for `Option>` as base64 in JSON. -mod base64_opt { +pub(crate) mod base64_opt { use base64::Engine as _; use serde::{Deserialize, Deserializer, Serializer}; diff --git a/tasks/render-chain-v1-plan.md b/tasks/render-chain-v1-plan.md new file mode 100644 index 00000000..d0324863 --- /dev/null +++ b/tasks/render-chain-v1-plan.md @@ -0,0 +1,410 @@ +# Render Chain v1 — Implementation Plan + +> Server-side chained video generation for LTX-2: generate videos of arbitrary length by stringing together multiple per-clip renders and stitching the results. v1 exposes a single-prompt/arbitrary-length UX; the request shape is **stages-based from day one** so the eventual movie-maker (multi-prompt, multi-keyframe) extends without a breaking change. + +## Confirmed design decisions (signed off 2026-04-20) + +1. **Trim over-production from the tail** of the final clip, not the head. The head carries the user's starting image anchor and is perceptually load-bearing; tail frames are the freshest continuation but cheapest to lose. +2. **Per-stage seed derivation: `stage_seed = base_seed ^ ((stage_idx as u64) << 32)`.** Deterministic, reproducible, avoids identical-noise artefacts when prompts match across stages. `ChainStage::seed_offset` stays reserved as the v2 movie-maker override hook. +3. **Fail closed on mid-chain failure.** If any stage errors, return 502 and discard all prior stages. No partial stitch is ever written to the gallery. Partial-resume is a v2 movie-maker feature. +4. 
**1 GB RAM ceiling for the accumulation buffer.** Hold decoded `RgbImage`s in memory through the stitch — acceptable for the 400-frame 1216×704 target. Revisit with streaming encode when someone pushes 1000+ frames.
+5. **Single-GPU per chain.** The orchestrator runs every stage on the GPU the engine was loaded onto. Multi-GPU stage fan-out is a v2 perf win; docs mention it, code doesn't build it.
+
+**Goal:** `mold run ltx-2-19b-distilled:fp8 "a cat walking" --image cat.png --frames 400` produces a single 400-frame MP4, stitched from 5 coherent sub-clips (97 + 4×93 = 469 raw frames, trimmed to 400), each seeded by a motion tail of latents from the prior clip.
+
+**Scope (v1):**
+
+- LTX-2 only (other video engines intentionally out of scope).
+- Single prompt replicated across all stages. Optional starting image on stage 0.
+- Motion-tail carryover **using cached latents in-process** (no VAE re-encode between clips).
+- Single stitched output to the gallery. No per-clip gallery rows, no `chain_id` grouping.
+- Sequential execution (clip N+1 waits for N). Multi-GPU fan-out is v2.
+- Server-side orchestration under a new `/api/generate/chain[/stream]` route. CLI auto-routes when `--frames > max_per_clip`.
+
+**Explicitly NOT in v1:**
+
+- Movie maker UI (that's v2, built on the same server API).
+- Per-stage prompts/keyframes (the request shape supports them; the CLI doesn't expose them yet).
+- Crossfade / colour-matching at clip boundaries.
+- Pause/resume/retry of a partial chain.
+
+**Base branch:** `main` · **Feature branch:** `feat/render-chain-v1` · **PR target:** `main`
+
+---
+
+## The compatibility contract
+
+The key architectural decision: **the wire format is already multi-stage.** v1 auto-synthesises the stages list from a single prompt + total length, but the server only ever sees the stages form. That means v2 (movie maker) is additive — the SPA just lets the user author the stages list by hand, no server breaking changes.
+ +```json +POST /api/generate/chain +{ + "model": "ltx-2-19b-distilled:fp8", + "stages": [ + { "prompt": "a cat walking", "frames": 97, "source_image": "" }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 } + ], + "motion_tail_frames": 4, + "width": 1216, "height": 704, "fps": 24, + "seed": 42, "steps": 8, "guidance": 3.0, "strength": 1.0, + "output_format": "mp4" +} +``` + +Or the auto-expand form (what v1 CLI sends): + +```json +POST /api/generate/chain +{ + "model": "ltx-2-19b-distilled:fp8", + "prompt": "a cat walking", + "total_frames": 400, + "clip_frames": 97, + "source_image": "", + "motion_tail_frames": 4, + "width": 1216, "height": 704, "fps": 24, + "seed": 42, "steps": 8, "guidance": 3.0, "strength": 1.0, + "output_format": "mp4" +} +``` + +Server-side, a canonicalising function collapses the auto-expand form into stages. From the engine's POV there's only ever a `Vec`. + +--- + +## File map + +### New + +``` +crates/mold-core/src/chain.rs -- ChainStage, ChainRequest, ChainResponse types +crates/mold-inference/src/ltx2/chain.rs -- LTX-2 chain orchestrator + latent-tail carry +crates/mold-server/src/routes_chain.rs -- POST /api/generate/chain[/stream] +``` + +### Modified + +``` +crates/mold-core/src/lib.rs -- re-export chain types +crates/mold-core/src/client.rs -- MoldClient::generate_chain[_stream]() +crates/mold-inference/src/ltx2/mod.rs -- pub use chain::{Ltx2ChainOrchestrator, ChainTail} +crates/mold-inference/src/ltx2/pipeline.rs -- expose internal render path that returns (VideoData, ChainTail) +crates/mold-inference/src/ltx2/runtime.rs -- thread ChainTail through run_real_distilled_stage +crates/mold-server/src/lib.rs -- route registration +crates/mold-server/src/queue.rs -- chain handler uses ModelCache but does NOT enqueue via the existing video job queue (reason in §3) +crates/mold-cli/src/main.rs -- auto-route --frames > clip_max to /api/generate/chain 
+crates/mold-cli/src/commands/generate.rs -- chain client + progress rendering +CHANGELOG.md +website/guide/video.md -- document --frames N and the chain endpoint +``` + +--- + +## Conventions + +- All new Rust code gets unit tests where the logic is pure (stage expansion, tail shape math, concat-drop math). The orchestrator's end-to-end path is covered by an integration test that swaps in a fake engine. +- `mold-inference` crate has `test = false` on the `lib` target — new tests in `ltx2/chain.rs` must either run under `#[cfg(test)] mod tests` with logic that doesn't touch candle weights, or use the fake-engine pattern. Keep tests weight-free. +- CLI manual UAT runs against BEAST (`MOLD_HOST=http://beast:7680`) with `ltx-2-19b-distilled:fp8`. +- Commit scopes: `feat(chain): …`, `fix(chain): …`, `test(chain): …`, `docs(chain): …`. +- Every task ends with a commit. No mid-plan push. + +--- + +## Phases + +### Phase 0 — core types (no-op at runtime) + +**0.1. Add `mold-core::chain` module with wire types.** + +```rust +// crates/mold-core/src/chain.rs +pub struct ChainStage { + pub prompt: String, + pub frames: u32, + pub source_image: Option>, // PNG bytes + pub negative_prompt: Option, // future-proof; v1 ignores if Some + pub seed_offset: Option, // v2 hook; v1 derives from base seed +} + +pub struct ChainRequest { + pub model: String, + pub stages: Vec, // canonical form + #[serde(default)] + pub motion_tail_frames: u32, // 0 = single-frame handoff; >0 = multi-frame tail + pub width: u32, pub height: u32, pub fps: u32, + pub seed: Option, pub steps: u32, pub guidance: f64, + pub strength: f64, // applied to stage[0].source_image only + pub output_format: OutputFormat, + pub placement: Option, + // auto-expand form (server normalises): + pub prompt: Option, + pub total_frames: Option, + pub clip_frames: Option, + pub source_image: Option>, +} + +pub struct ChainResponse { pub video: VideoData, pub stage_count: u32, pub gpu: Option } +``` + +- Add a 
`normalise(self) -> Result` that collapses the auto-expand fields into stages when `stages.is_empty()`.
+- Validation: at least one stage, each stage has `frames` satisfying 8k+1 and > 0, total stage count ≤ 16 (`MAX_CHAIN_STAGES` — early guardrail; users aren't generating feature films with this yet).
+- Tests: `normalise_splits_single_prompt_into_stages`, `normalise_preserves_first_stage_image`, `normalise_rejects_empty`, `normalise_rejects_non_8k1_frames`.
+
+Commit: `feat(chain): add core wire types and request normalisation`.
+
+**0.2. Re-export from `mold_core`, add `MoldClient::generate_chain`/`generate_chain_stream`.**
+
+Mirror the existing `generate` / `generate_stream` shape. No server changes yet — client just has the surface area.
+
+Commit: `feat(core): MoldClient chain methods`.
+
+---
+
+### Phase 1 — LTX-2 chain orchestrator (single GPU, in-process)
+
+**1.1. Define `ChainTail` as the latent-carryover payload.**
+
+```rust
+// crates/mold-inference/src/ltx2/chain.rs
+pub struct ChainTail {
+    pub frames: u32,              // number of pixel frames this tail represents
+    pub latents: Tensor,          // [1, C, tail_latent_frames, H/32, W/32] on the engine device
+    pub last_rgb_frame: RgbImage, // for fallback + debugging
+}
+```
+
+The VAE temporal ratio is 8 with causal first frame, so `tail_latent_frames = ((tail_pixel_frames - 1) / 8 + 1).max(1)`. For `motion_tail_frames=4` this is 1 latent frame. For `motion_tail_frames=9` it's 2 latent frames. Tests cover the arithmetic.
+
+**1.2. Extend `Ltx2Engine` with a chain-aware generate path.**
+
+Add a method that `generate` proper delegates to:
+
+```rust
+impl Ltx2Engine {
+    pub fn generate_with_carryover(
+        &mut self,
+        req: &GenerateRequest,
+        carry: Option<&ChainTail>,
+    ) -> Result<(GenerateResponse, ChainTail)>;
+}
+```
+
+When `carry = None`, behaviour is identical to `self.generate(req)` (use the source_image path as today). When `carry = Some(tail)`, the engine:
+
+1.
Skips VAE encode on `stage_conditioning` for the keyframe at frame 0. +2. Instead, threads `tail.latents` straight into `maybe_load_stage_video_conditioning` via a new optional parameter. The patchified tail tokens go into `StageVideoConditioning::replacements` with `strength = 1.0` and `start_token = 0..tail_token_count`. +3. Extracts the last `K = motion_tail_frames` pixel frames' worth of latents from the completed denoise (before VAE decode) and returns them as the new `ChainTail`. + +The new latent extraction hook needs to run **after the last denoise step, before `vae.decode`** in the distilled and two-stage paths. Surface it as a single helper `extract_tail_latents(&final_latents, motion_tail_frames) -> Tensor` that narrows along the time axis. + +- Tests for the helper: `extract_tail_computes_correct_latent_slice`, `extract_tail_preserves_device_and_dtype`, `extract_tail_handles_single_frame_edge_case`. + +**1.3. Stage conditioning: accept pre-encoded latents instead of a staged image.** + +Currently `maybe_load_stage_video_conditioning` (`runtime.rs:1215`) reads an image path, decodes, VAE-encodes. Add a sibling path that accepts `Option<&Tensor>` as pre-patchified tokens (or raw latents to be patchified in place). Route through it when the orchestrator passes carryover. + +Concretely: a new variant on `StagedImage` or a parallel `StagedLatent` struct carried through `StagedConditioning`. Prefer the latter — keeps the existing image path pristine. + +```rust +pub struct StagedLatent { + pub latents: Tensor, // [1, C, T, H/32, W/32] + pub frame: u32, // start frame (0 for chain carryover) + pub strength: f32, // 1.0 for chain +} + +pub struct StagedConditioning { + pub images: Vec, + pub latents: Vec, // NEW, empty for today's callers + pub audio_path: Option, + pub video_path: Option, +} +``` + +`maybe_load_stage_video_conditioning` iterates `images` then `latents`, patchifying the latter directly without calling `vae.encode`. 
All existing call sites pass an empty `latents` Vec. + +- Test: `staged_latent_produces_same_replacement_token_shape_as_image_for_single_latent_frame`. + +**1.4. Build `Ltx2ChainOrchestrator`.** + +```rust +// crates/mold-inference/src/ltx2/chain.rs +pub struct Ltx2ChainOrchestrator<'a> { + engine: &'a mut Ltx2Engine, +} + +impl<'a> Ltx2ChainOrchestrator<'a> { + pub fn run( + &mut self, + req: &ChainRequest, + progress: Option, + ) -> Result; +} +``` + +Internal loop: + +``` +let mut tail: Option = None; +let mut accumulated_frames: Vec = Vec::new(); +let tail_drop = req.motion_tail_frames as usize; + +for (idx, stage) in req.stages.iter().enumerate() { + let per_clip = build_clip_request(stage, &req, tail.is_some())?; + let (resp, new_tail) = self.engine.generate_with_carryover(&per_clip, tail.as_ref())?; + let frames = decode_video_frames_from_response(&resp)?; + if idx == 0 { + accumulated_frames.extend(frames); + } else { + // drop the leading `tail_drop` pixel frames; they duplicate the prior clip's tail + accumulated_frames.extend(frames.into_iter().skip(tail_drop)); + } + tail = Some(new_tail); + emit_progress(progress.as_ref(), ChainStageDone { idx, total: req.stages.len() }); +} + +let stitched = encode_mp4(&accumulated_frames, req.fps)?; +Ok(ChainResponse { video: stitched, ... }) +``` + +- Stage-1 request has `source_image = stage.source_image`, `keyframes = None`. +- Stage-N request (N ≥ 2) has `source_image = None`, `keyframes = None`; the carryover is passed via the `tail` parameter to `generate_with_carryover`, not through the request DTO. +- Progress events: forward engine events with an added `stage_idx`, plus emit `ChainStageStart` / `ChainStageDone` / `ChainStitching` / `ChainComplete`. + +- Tests (fake engine): `chain_runs_all_stages_and_drops_tail_prefix_from_continuations`, `chain_with_zero_tail_concats_full_clips_without_drop`, `chain_progress_forwards_engine_events_with_stage_idx`, `chain_empty_stages_errors`. 
+ +Commit: `feat(ltx2): chain orchestrator with latent-tail carryover`. + +--- + +### Phase 2 — server route + +**2.1. `POST /api/generate/chain` (non-streaming).** + +Handler flow: + +1. Parse & normalise the `ChainRequest`. +2. Validate model is an LTX-2 family (`anyhow::bail!` with a clear error otherwise). +3. Grab the model's engine from `ModelCache` (load if needed, same as the existing video path). +4. Construct `Ltx2ChainOrchestrator` against it and call `run()`. +5. Save the stitched MP4 via the same save path as single-clip videos (`save_video_to_dir`), populating `OutputMetadata` with a synthetic prompt (`stages[0].prompt` for v1) and a note in a new optional metadata field `chain_stage_count: Option`. +6. Return `ChainResponse` as JSON. + +Do **not** go through the existing single-job queue — a chain is a long-running compound job and would block the queue for 10+ minutes. Instead, the handler holds the `ModelCache` mutex the same way the multi-GPU worker does, for the full chain duration. This is OK because the multi-GPU pool already has per-GPU thread isolation. + +**2.2. `POST /api/generate/chain/stream` (SSE).** + +Same flow but progress events stream as `data:` frames. Event types: + +- `chain_start { stage_count, estimated_total_frames }` +- `stage_start { stage_idx }` +- `denoise_step { stage_idx, step, total }` (forwarded from engine with `stage_idx` wrapped in) +- `stage_done { stage_idx, frames_emitted }` +- `stitching { total_frames }` +- `complete { video_frames, video_fps, video_base64, filename, seed, ... }` (same shape as `/api/generate/stream` complete event) +- `error { message }` + +The existing SSE completion-event helper (`build_sse_complete_event` in `queue.rs`) is not reusable as-is because it takes a single `GenerateResponse`; write a sibling `build_chain_sse_complete_event(&ChainResponse)` that produces the same JSON structure plus `chain_stage_count`. 
+ +- Tests: route-level tests with a fake engine that exercise both non-streaming and SSE shapes; verify SSE emits events in the expected order. + +Commit: `feat(server): chain render endpoint and SSE stream`. + +--- + +### Phase 3 — CLI + +**3.1. Auto-route `mold run` to `/api/generate/chain` when `--frames > max_per_clip`.** + +Add a constant in `mold-cli` for LTX-2 clip caps (97 for 19B distilled, 97 for 22B — same as today's single-clip validation). When `frames > cap`: + +- Build a `ChainRequest` with `prompt=…`, `total_frames=…`, `clip_frames=cap`, `source_image=…`, `motion_tail_frames=4` (default). +- Call `MoldClient::generate_chain_stream`. +- Render a progress bar per stage stacked with a parent "chain" bar. + +When `frames ≤ cap`, path is unchanged (`/api/generate/stream`, single clip, today's behaviour). + +- New flag: `--clip-frames N` to let advanced users override the per-clip length (default = model cap). +- New flag: `--motion-tail N` to override the tail (default 4, 0 to disable). +- Help text for `--frames` updates to mention chained output when > cap. + +- Tests: `run_frames_above_cap_selects_chain_endpoint` (argparse-level; doesn't invoke the network). + +**3.2. `--local` chain mode.** + +For parity with `mold run --local`, the CLI should run the orchestrator in-process when `--local` is passed. Factor the orchestrator invocation into a helper so both the server handler and the CLI local path share it. + +Commit: `feat(cli): chain rendering for --frames above clip cap`. + +--- + +### Phase 4 — docs & changelog + +**4.1. Website.** Add a new section in `website/guide/video.md` explaining chained video output, how motion tail works, and the CLI flags. Link it from the LTX-2 model page. + +**4.2. CHANGELOG.** Unreleased / Added entry describing the `/api/generate/chain` route, the CLI auto-routing behaviour, and the motion-tail carryover. + +**4.3. Skill file.** Update `.claude/skills/mold/SKILL.md` with the new CLI flags and endpoint. 
+ +Commit: `docs(chain): guide, changelog, and skill updates`. + +--- + +## Integration test: a realistic end-to-end + +One integration test lives in `crates/mold-server/tests/chain_integration.rs` (or inline in `tests/` if an integration dir exists). It: + +1. Stands up an in-process server with a **fake LTX-2 engine** (not real weights) whose `generate_with_carryover` returns a deterministic gradient pattern + a synthetic `ChainTail` whose latents are zeros but whose RGB tail frame is the last frame of the emitted clip. +2. POSTs an auto-expand chain request with `total_frames=200`, `clip_frames=97`, `motion_tail_frames=4`. +3. Asserts: + - Three stages fired. + - The stitched MP4 has `ceil((200 - 97) / 93) * 93 + 97 = 97 + 93*2 = 283 ≥ 200` frames before trim; after trim it's 200 frames. + - SSE stream emitted events in the expected order. + - The gallery DB got one row with `chain_stage_count = 3`. + +The fake-engine pattern keeps this test out of the GPU path and makes it safe to run in CI. + +--- + +## Open design decisions I'm flagging for your sign-off + +1. **Trim policy.** If `total_frames = 400` and chain math produces 469 frames, should we trim from the tail (final clip's final frames get cut — but those are the freshest continuation) or from the head (stage-0 frames get cut — but those are the user-anchored ones)? I recommend **trim from tail** because the head is where the user's starting image landed and matters more perceptually. + +2. **Seed handling across stages.** Should each stage get the same seed (reproducible but with artifacts from identical noise when prompts match), or derive per-stage seeds (`base_seed ^ (stage_idx << 32)`)? I recommend **derive per-stage**. `seed_offset` on `ChainStage` lets the movie maker override. + +3. **Failure mode mid-chain.** If stage 3 of 4 fails, do we return a 502 and discard everything, or return the partial stitch of stages 1–3? I recommend **fail closed for v1** — no partial output. 
Partial resume is a v2 movie-maker feature where individual stage regen is first-class.
+
+4. **Memory.** 400 frames × 1216×704×3 ≈ 1 GB of RgbImages held in RAM before MP4 encode. Acceptable for v1. If users push to 1000+ frames we revisit with streaming encode.
+
+5. **Placement.** Chain always runs on a single GPU for v1 (the one the engine was loaded onto). Multi-GPU fan-out (stage N and N+1 on different cards) is a v2 perf win; mention in docs but don't build.
+
+---
+
+## What `mold run` looks like after this ships
+
+```console
+$ mold run ltx-2-19b-distilled:fp8 "a cat walking through autumn leaves" \
+    --image cat.png --frames 400
+
+⏳ Chain render: 5 stages × 97 frames (motion tail: 4) → 469 stitched frames, trimmed to 400
+▸ Stage 1/5 · denoise step 8/8 · 47s
+▸ Stage 2/5 · denoise step 8/8 · 44s (tail carried from stage 1)
+▸ Stage 3/5 · denoise step 8/8 · 44s
+▸ Stages 4–5/5 · denoise step 8/8 · 44s each
+▸ Stitching 469 frames @ 24fps …
+✔ Saved mold-ltx-2-19b-distilled-{ts}.mp4 (400 frames, 16.7s, 16MB)
+```
+
+---
+
+## Out-of-scope for v1 but in-scope for v2 (movie maker)
+
+- SPA route `/movie` with a timeline authoring UI.
+- Per-stage prompts and keyframes exposed in the request body (the server already supports this — only the UI needs to change).
+- Per-clip gallery rows with `chain_id` grouping so users can iterate on individual stages.
+- Selective stage regeneration (replace stage 2 without redoing 1/3/4).
+- Crossfade blending at clip boundaries.
+- Multi-GPU stage fan-out.
+
+The whole point of v1 is to ship a stable foundation these land on top of without breaking changes.
From 0328e765f7866315100ecd8091b82295aa2cb4fd Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 16:31:38 -0700 Subject: [PATCH 03/31] feat(core): MoldClient chain methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `MoldClient::generate_chain` (POST /api/generate/chain, non- streaming JSON request/response) and `MoldClient::generate_chain_stream` (POST /api/generate/chain/stream, SSE) mirroring the existing `generate` / `generate_stream` shape. The server routes land in Phase 2; this commit ships the client surface so Phase 1's fake-engine tests and Phase 2's route wiring have a settled wire contract to implement against. Chain-specific wire types (all new, under `mold_core::chain`): - `ChainProgressEvent` — tagged enum streamed under `event: progress`. Variants: `chain_start { stage_count, estimated_total_frames }`, `stage_start { stage_idx }`, `denoise_step { stage_idx, step, total }`, `stage_done { stage_idx, frames_emitted }`, `stitching { total_frames }`. snake_case tagged JSON matches the existing `SseProgressEvent` style. - `SseChainCompleteEvent` — kept as a sibling to `crate::types::SseCompleteEvent` rather than an extension, so chain completion shape can evolve independently (stage_count, stitched- video payload, optional thumb/GIF, audio metadata, elapsed time). 
Error translation matches the single-clip methods: | Status | generate_chain | generate_chain_stream | |------------------------|-------------------------------------------------|-------------------------------------------------| | 200 | parse ChainResponse JSON | parse SSE until `complete` event | | 404, empty body | hard error "chain endpoint not found" | `Ok(None)` — caller may fall back | | 404, non-empty body | `MoldError::ModelNotFound` | `MoldError::ModelNotFound` | | 422 | `MoldError::Validation` | `MoldError::Validation` | | 4xx/5xx else | generic anyhow | generic anyhow | The non-streaming empty-404 behaviour deliberately differs from SSE: streaming clients can fall back to non-streaming, but non-streaming callers have nowhere to go and should fail loudly. Integration coverage: - `crates/mold-core/tests/chain_client.rs` (wiremock): endpoint/body shape assertion on non-streaming; 422 → Validation; 404-with-body → ModelNotFound; non-streaming empty 404 → hard error; SSE empty 404 → Ok(None); SSE progress + complete roundtrip reconstructs `ChainResponse` with thumb + gpu. - Pure serde roundtrip test for every `ChainProgressEvent` variant asserting snake_case tag format. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-core/src/chain.rs | 124 +++++++++++++ crates/mold-core/src/client.rs | 132 ++++++++++++++ crates/mold-core/src/lib.rs | 5 +- crates/mold-core/tests/chain_client.rs | 234 +++++++++++++++++++++++++ 4 files changed, 494 insertions(+), 1 deletion(-) create mode 100644 crates/mold-core/tests/chain_client.rs diff --git a/crates/mold-core/src/chain.rs b/crates/mold-core/src/chain.rs index 101e479e..e8b68c7e 100644 --- a/crates/mold-core/src/chain.rs +++ b/crates/mold-core/src/chain.rs @@ -158,6 +158,85 @@ pub struct ChainResponse { pub gpu: Option, } +/// SSE completion event for a successful chain run. Streamed as the final +/// `data:` frame under the `event: complete` SSE type. 
The payload is +/// base64-encoded to stay JSON-safe; clients decode it into `VideoData`. +/// +/// This is a sibling to [`crate::types::SseCompleteEvent`] rather than an +/// extension so image/video vs. chain completion shapes stay independent +/// and can evolve separately. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct SseChainCompleteEvent { + /// Base64-encoded stitched video bytes (format per `format` field). + pub video: String, + pub format: OutputFormat, + #[schema(example = 1216)] + pub width: u32, + #[schema(example = 704)] + pub height: u32, + #[schema(example = 400)] + pub frames: u32, + #[schema(example = 24)] + pub fps: u32, + /// Base64-encoded first-frame PNG thumbnail. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thumbnail: Option, + /// Base64-encoded animated GIF preview (always emitted for gallery UI). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub gif_preview: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub has_audio: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub audio_sample_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub audio_channels: Option, + /// Number of stages that ran end-to-end. + #[schema(example = 5)] + pub stage_count: u32, + /// GPU ordinal that handled the chain (multi-GPU only). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub gpu: Option, + /// Wall-clock elapsed time across all stages + stitching. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub generation_time_ms: Option, +} + +/// Chain-specific SSE progress event. Streamed as `data:` JSON frames from +/// `POST /api/generate/chain/stream` under the `event: progress` SSE type. 
+/// +/// Per-stage denoise steps are wrapped with `stage_idx` so consumers can +/// render stacked progress bars (overall chain + per-stage) without a +/// separate subscription. Non-denoise engine events (weight load, cache +/// hits, etc.) are intentionally not forwarded through this enum in v1 — +/// they're scoped to individual stages and the UX goal for v1 is per-stage +/// progress, not per-component telemetry. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChainProgressEvent { + /// Emitted once at the start of the chain, after normalisation. Gives + /// consumers the final stage count and the target pre-trim frame total + /// so they can size progress bars up front. + ChainStart { + stage_count: u32, + estimated_total_frames: u32, + }, + /// Stage `stage_idx` (0-indexed) has started its denoise loop. + StageStart { stage_idx: u32 }, + /// Per-step denoise progress for the active stage. + DenoiseStep { + stage_idx: u32, + step: u32, + total: u32, + }, + /// Stage finished generating; `frames_emitted` is the raw clip frame + /// count before motion-tail trim at stitch time. + StageDone { stage_idx: u32, frames_emitted: u32 }, + /// All stages complete; stitching/encoding the final MP4. 
+ Stitching { total_frames: u32 }, +} + fn default_motion_tail_frames() -> u32 { 4 } @@ -566,6 +645,51 @@ mod tests { } } + #[test] + fn chain_progress_event_roundtrips_json_with_snake_case_tags() { + let cases = [ + ( + ChainProgressEvent::ChainStart { + stage_count: 5, + estimated_total_frames: 469, + }, + r#""type":"chain_start""#, + ), + ( + ChainProgressEvent::StageStart { stage_idx: 0 }, + r#""type":"stage_start""#, + ), + ( + ChainProgressEvent::DenoiseStep { + stage_idx: 2, + step: 4, + total: 8, + }, + r#""type":"denoise_step""#, + ), + ( + ChainProgressEvent::StageDone { + stage_idx: 3, + frames_emitted: 97, + }, + r#""type":"stage_done""#, + ), + ( + ChainProgressEvent::Stitching { total_frames: 400 }, + r#""type":"stitching""#, + ), + ]; + for (event, expected_tag) in cases { + let json = serde_json::to_string(&event).expect("serialize"); + assert!( + json.contains(expected_tag), + "missing snake_case tag {expected_tag} in {json}", + ); + let roundtrip: ChainProgressEvent = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(roundtrip, event, "roundtrip must preserve payload"); + } + } + #[test] fn build_stages_math_matches_stitch_budget() { // Auto-expand must produce enough stages that the stitch delivers diff --git a/crates/mold-core/src/client.rs b/crates/mold-core/src/client.rs index e900739f..ffea387e 100644 --- a/crates/mold-core/src/client.rs +++ b/crates/mold-core/src/client.rs @@ -1,3 +1,4 @@ +use crate::chain::{ChainProgressEvent, ChainRequest, ChainResponse, SseChainCompleteEvent}; use crate::error::MoldError; use crate::types::{ ExpandRequest, ExpandResponse, GalleryImage, GenerateRequest, GenerateResponse, ImageData, @@ -313,6 +314,137 @@ impl MoldClient { anyhow::bail!("SSE stream ended without complete event") } + /// Submit a chained video generation request (non-streaming). 
+ /// + /// The server normalises the auto-expand form into stages, runs each + /// stage sequentially with motion-tail latent carryover, stitches the + /// result into a single video, and returns a [`ChainResponse`]. Large + /// chains take minutes — prefer [`Self::generate_chain_stream`] for + /// interactive clients that want progress updates. + pub async fn generate_chain(&self, req: &ChainRequest) -> Result { + let resp = self + .client + .post(format!("{}/api/generate/chain", self.base_url)) + .json(req) + .send() + .await?; + + if resp.status() == reqwest::StatusCode::NOT_FOUND { + let body = resp.text().await.unwrap_or_default(); + if body.is_empty() { + anyhow::bail!("chain endpoint not found — server predates render-chain v1"); + } + return Err(MoldError::ModelNotFound(body).into()); + } + if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY { + let body = resp.text().await.unwrap_or_default(); + return Err(MoldError::Validation(format!("validation error: {body}")).into()); + } + if resp.status().is_client_error() || resp.status().is_server_error() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("server error {status}: {body}"); + } + + let chain: ChainResponse = resp.json().await?; + Ok(chain) + } + + /// Submit a chained video generation request with SSE progress streaming. + /// + /// Returns: + /// - `Ok(Some(response))` — streaming succeeded and the `complete` event + /// carried the stitched video. + /// - `Ok(None)` — server doesn't have the chain endpoint (empty 404). + /// Callers can fall back to [`Self::generate_chain`] or error. + /// - `Err(_)` — validation, model-not-found, or mid-stream server error. 
+ pub async fn generate_chain_stream( + &self, + req: &ChainRequest, + progress_tx: tokio::sync::mpsc::UnboundedSender, + ) -> Result> { + let mut resp = self + .client + .post(format!("{}/api/generate/chain/stream", self.base_url)) + .json(req) + .send() + .await?; + + if resp.status() == reqwest::StatusCode::NOT_FOUND { + let body = resp.text().await.unwrap_or_default(); + if body.is_empty() { + return Ok(None); + } + return Err(MoldError::ModelNotFound(body).into()); + } + if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY { + let body = resp.text().await.unwrap_or_default(); + return Err(MoldError::Validation(format!("validation error: {body}")).into()); + } + if resp.status().is_client_error() || resp.status().is_server_error() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("server error {status}: {body}"); + } + + let b64 = base64::engine::general_purpose::STANDARD; + let mut buffer = String::new(); + while let Some(chunk) = resp.chunk().await? 
{ + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(event_text) = next_sse_event(&mut buffer) { + let (event_type, data) = parse_sse_event(&event_text); + match event_type.as_str() { + "progress" => { + if let Ok(p) = serde_json::from_str::(&data) { + let _ = progress_tx.send(p); + } + } + "complete" => { + let complete: SseChainCompleteEvent = serde_json::from_str(&data)?; + let payload = b64.decode(&complete.video)?; + let thumbnail = complete + .thumbnail + .as_deref() + .and_then(|s| b64.decode(s).ok()) + .unwrap_or_default(); + let gif_preview = complete + .gif_preview + .as_deref() + .and_then(|s| b64.decode(s).ok()) + .unwrap_or_default(); + let video = VideoData { + data: payload, + format: complete.format, + width: complete.width, + height: complete.height, + frames: complete.frames, + fps: complete.fps, + thumbnail, + gif_preview, + has_audio: complete.has_audio, + duration_ms: complete.duration_ms, + audio_sample_rate: complete.audio_sample_rate, + audio_channels: complete.audio_channels, + }; + return Ok(Some(ChainResponse { + video, + stage_count: complete.stage_count, + gpu: complete.gpu, + })); + } + "error" => { + let error: SseErrorEvent = serde_json::from_str(&data)?; + anyhow::bail!("server error: {}", error.message); + } + _ => {} + } + } + } + + anyhow::bail!("chain SSE stream ended without complete event") + } + /// Ask the server to pull (download) a model. Blocks until the download /// completes on the server side. The server updates its in-memory config /// so subsequent generate/load requests can find the model. 
diff --git a/crates/mold-core/src/lib.rs b/crates/mold-core/src/lib.rs index e7b2f3c1..9da6a5e2 100644 --- a/crates/mold-core/src/lib.rs +++ b/crates/mold-core/src/lib.rs @@ -19,7 +19,10 @@ mod config_test; mod test_support; pub use catalog::build_model_catalog; -pub use chain::{ChainRequest, ChainResponse, ChainStage, MAX_CHAIN_STAGES}; +pub use chain::{ + ChainProgressEvent, ChainRequest, ChainResponse, ChainStage, SseChainCompleteEvent, + MAX_CHAIN_STAGES, +}; pub use client::MoldClient; pub use config::{ parse_device_ref_str, Config, DefaultModelResolution, DefaultModelSource, LoggingConfig, diff --git a/crates/mold-core/tests/chain_client.rs b/crates/mold-core/tests/chain_client.rs new file mode 100644 index 00000000..dc06d1bb --- /dev/null +++ b/crates/mold-core/tests/chain_client.rs @@ -0,0 +1,234 @@ +//! Integration tests for `MoldClient::generate_chain{,_stream}` using +//! `wiremock` to simulate the `/api/generate/chain` server endpoints. +//! +//! These tests pin the HTTP surface (method, path, JSON request body) and +//! verify error translation (422 → Validation, 404 empty → None on stream, +//! 404 with body → ModelNotFound). They do NOT exercise real LTX-2 work — +//! the server side lands in Phase 2. 
+ +use base64::Engine as _; +use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage, SseChainCompleteEvent}; +use mold_core::error::MoldError; +use mold_core::types::OutputFormat; +use mold_core::MoldClient; +use wiremock::matchers::{body_json_schema, method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +fn mold_error(err: &anyhow::Error) -> &MoldError { + err.downcast_ref::() + .unwrap_or_else(|| panic!("not a MoldError: {err}")) +} + +fn sample_request() -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages: vec![ChainStage { + prompt: "a cat walking".into(), + frames: 97, + source_image: None, + negative_prompt: None, + seed_offset: None, + }], + motion_tail_frames: 4, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } +} + +fn minimal_chain_response_json() -> serde_json::Value { + serde_json::json!({ + "video": { + "data": [], + "format": "mp4", + "width": 1216, + "height": 704, + "frames": 97, + "fps": 24, + "thumbnail": [] + }, + "stage_count": 1 + }) +} + +// ── /api/generate/chain (non-streaming) ──────────────────────────────── + +#[tokio::test] +async fn generate_chain_posts_to_correct_endpoint_and_parses_response() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .and(body_json_schema::) + .respond_with(ResponseTemplate::new(200).set_body_json(minimal_chain_response_json())) + .expect(1) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let resp = client + .generate_chain(&sample_request()) + .await + .expect("non-streaming chain should succeed on 200"); + assert_eq!(resp.stage_count, 1); + assert_eq!(resp.video.frames, 97); + assert_eq!(resp.video.format, OutputFormat::Mp4); +} + +#[tokio::test] +async fn 
generate_chain_surfaces_422_as_validation_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .respond_with(ResponseTemplate::new(422).set_body_string("frames must be 8k+1")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let err = client + .generate_chain(&sample_request()) + .await + .expect_err("422 must error"); + assert!( + matches!(mold_error(&err), MoldError::Validation(msg) if msg.contains("8k+1")), + "422 must translate to MoldError::Validation carrying the body", + ); +} + +#[tokio::test] +async fn generate_chain_translates_404_with_body_to_model_not_found() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .respond_with(ResponseTemplate::new(404).set_body_string("model 'ltx-2-foo' not found")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let err = client + .generate_chain(&sample_request()) + .await + .expect_err("404 with body must error"); + assert!( + matches!(mold_error(&err), MoldError::ModelNotFound(msg) if msg.contains("ltx-2-foo")), + "404-with-body must translate to MoldError::ModelNotFound", + ); +} + +#[tokio::test] +async fn generate_chain_empty_404_fails_loudly_instead_of_silently() { + // Non-streaming callers have no fallback path — an empty 404 means the + // server predates render-chain v1, which is a hard error (unlike the + // streaming case where Ok(None) signals "try the non-streaming path"). 
+ let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .respond_with(ResponseTemplate::new(404).set_body_string("")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let err = client + .generate_chain(&sample_request()) + .await + .expect_err("empty 404 must error on non-streaming path"); + let msg = format!("{err}"); + assert!( + msg.contains("chain endpoint not found"), + "error must name the missing endpoint, got: {msg}", + ); +} + +// ── /api/generate/chain/stream (SSE) ─────────────────────────────────── + +#[tokio::test] +async fn generate_chain_stream_returns_none_on_empty_404() { + // An empty 404 on the streaming endpoint means the server doesn't + // support chain SSE yet — callers are expected to fall back to the + // non-streaming path. + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain/stream")) + .respond_with(ResponseTemplate::new(404).set_body_string("")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel::(); + let out = client + .generate_chain_stream(&sample_request(), tx) + .await + .expect("empty 404 should resolve to Ok(None)"); + assert!(out.is_none(), "empty 404 must signal unsupported endpoint"); +} + +#[tokio::test] +async fn generate_chain_stream_parses_progress_and_complete_events() { + let b64 = base64::engine::general_purpose::STANDARD; + let video_bytes = b"FAKE_MP4_BYTES"; + let thumb_bytes = b"THUMB"; + let complete = SseChainCompleteEvent { + video: b64.encode(video_bytes), + format: OutputFormat::Mp4, + width: 1216, + height: 704, + frames: 97, + fps: 24, + thumbnail: Some(b64.encode(thumb_bytes)), + gif_preview: None, + has_audio: false, + duration_ms: Some(4040), + audio_sample_rate: None, + audio_channels: None, + stage_count: 1, + gpu: Some(0), + generation_time_ms: Some(45_000), + }; + let progress = 
ChainProgressEvent::DenoiseStep { + stage_idx: 0, + step: 4, + total: 8, + }; + // Build a chunk-encoded SSE body carrying one progress event then + // complete. `\n\n` terminates each SSE event. + let body = format!( + "event: progress\ndata: {}\n\nevent: complete\ndata: {}\n\n", + serde_json::to_string(&progress).unwrap(), + serde_json::to_string(&complete).unwrap(), + ); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain/stream")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_string(body), + ) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let resp = client + .generate_chain_stream(&sample_request(), tx) + .await + .expect("SSE stream should succeed") + .expect("complete event should yield a response"); + + assert_eq!(resp.stage_count, 1); + assert_eq!(resp.video.data, video_bytes); + assert_eq!(resp.video.thumbnail, thumb_bytes); + assert_eq!(resp.gpu, Some(0)); + let ev = rx.recv().await.expect("progress event should be forwarded"); + assert_eq!(ev, progress); +} From e89826f0ec5422c799a9861d88caba57117d26ae Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 17:15:00 -0700 Subject: [PATCH 04/31] feat(ltx2): ChainTail type and latent-tail extraction helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce the carryover primitive that render-chain stages hand to each other. `ChainTail { frames, latents, last_rgb_frame }` bundles the final VAE latents of a stage's motion tail so the next stage can patch those tokens straight into its conditioning without a VAE decode → RGB → VAE encode round-trip. No engine wiring yet — the orchestrator and the `generate_with_carryover` entry point land in sibling commits. 
Helpers in the new `ltx2::chain` module: - `tail_latent_frame_count(pixel_frames: u32) -> usize` — exposes the LTX-2 VAE's 8× causal-first-frame temporal ratio as the formula `((n - 1) / 8) + 1`. Matches `VideoLatentShape::from_pixel_shape`. Panics on `0`; callers must validate upstream. - `extract_tail_latents(final_latents: &Tensor, pixel_frames: u32) -> Result` — narrows the time axis of a rank-5 `[B, C, T, H, W]` latents tensor down to the last K latent frames corresponding to the requested pixel-frame tail. Errors (not panics) on rank mismatch or oversize tail request so orchestrator bugs surface as operational errors, not process aborts. Unit tests cover: the VAE formula across representative tail sizes (4→1, 9→2, 16→2, 17→3, 97→13), rejection of a zero pixel-frame tail, correct narrowing on a synthetic [1, 2, 3, 1, 1] tensor with sentinel values proving the last latent frame is returned across all channels, narrowing on a larger rank-5 tensor, rank-4 rejection, and oversize-tail rejection. All tests are weight-free and run under `cargo test -p mold-ai-inference --lib`. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-inference/src/ltx2/chain.rs | 197 ++++++++++++++++++++++++ crates/mold-inference/src/ltx2/mod.rs | 2 + 2 files changed, 199 insertions(+) create mode 100644 crates/mold-inference/src/ltx2/chain.rs diff --git a/crates/mold-inference/src/ltx2/chain.rs b/crates/mold-inference/src/ltx2/chain.rs new file mode 100644 index 00000000..b1c4c91e --- /dev/null +++ b/crates/mold-inference/src/ltx2/chain.rs @@ -0,0 +1,197 @@ +//! LTX-2 chain carryover primitives. +//! +//! Server-side chained video generation stitches multiple per-clip renders +//! into a single output. To avoid a VAE decode → RGB → VAE encode round-trip +//! between clips (which loses information and doubles VAE cost), the tail of +//! each clip is carried across as latent-space tokens and threaded into the +//! next clip's conditioning directly. +//! +//! 
This module owns the data types and shape math for that handoff. The +//! orchestrator and the `Ltx2Engine::generate_with_carryover` entry point +//! land in sibling commits. +//! +//! See `tasks/render-chain-v1-plan.md` Phase 1.1 for context. + +use anyhow::{anyhow, Context, Result}; +use candle_core::Tensor; +use image::RgbImage; + +use crate::ltx2::model::shapes::SpatioTemporalScaleFactors; + +/// Opaque carryover payload handed from one chain stage to the next. +/// +/// Holds the final VAE latents of the emitting stage's motion tail, not the +/// decoded pixels — so the receiving stage can patchify the tokens directly +/// into its conditioning without a VAE re-encode. +#[derive(Debug, Clone)] +pub struct ChainTail { + /// Number of *pixel* frames this tail represents (not latent frames). + /// Clients of [`ChainTail`] work in pixel-frame units because that's + /// what users think in; the latent-frame count is derived from this + /// plus the LTX-2 VAE's 8× causal temporal ratio. + pub frames: u32, + + /// Latent tokens for the tail. + /// + /// Shape: `[batch=1, channels=128, tail_latent_frames, H/32, W/32]` + /// where `tail_latent_frames = tail_latent_frame_count(self.frames)`. + /// + /// Dtype is whatever the denoise loop produced — typically `F32`. + /// Device is the engine's active device (GPU or CPU); the orchestrator + /// is responsible for ensuring the next stage runs on the same device. + pub latents: Tensor, + + /// The last decoded pixel frame of the emitting stage. Kept for + /// debugging, progress UIs that want a thumbnail of the handoff point, + /// and as a fallback rendering target if latent carryover ever needs + /// to be disabled at runtime. + pub last_rgb_frame: RgbImage, +} + +/// Number of latent frames corresponding to `pixel_frames` pixel frames +/// under the LTX-2 VAE's 8× causal temporal compression. `1` for +/// `1..=8` pixel frames, `2` for `9..=16`, etc. Matches +/// `VideoLatentShape::from_pixel_shape`. 
+/// +/// Panics if `pixel_frames == 0` — a zero-frame tail is nonsensical and +/// would under-flow the formula. Callers must validate upstream. +pub fn tail_latent_frame_count(pixel_frames: u32) -> usize { + assert!( + pixel_frames > 0, + "tail_latent_frame_count: pixel_frames must be > 0", + ); + let scale = SpatioTemporalScaleFactors::default().time; + ((pixel_frames as usize - 1) / scale) + 1 +} + +/// Slice the last `tail_latent_frame_count(pixel_frames)` frames off the +/// time axis of a rank-5 video-latents tensor shaped +/// `[B, C, T, H, W]`. +/// +/// The returned tensor is a view/narrow on the input (no copy on candle's +/// current backends) so callers who intend to hand it to a separate engine +/// invocation — which may drop this engine's state and rebuild it — should +/// `.contiguous()` or `.copy()` the result before the original owner goes +/// out of scope. +/// +/// Errors if the tensor is not rank-5 or the requested tail exceeds the +/// available time axis — the latter would mean the orchestrator asked for +/// more tail than the stage produced, which indicates a caller bug. 
+pub fn extract_tail_latents(final_latents: &Tensor, pixel_frames: u32) -> Result { + let dims = final_latents.dims(); + if dims.len() != 5 { + return Err(anyhow!( + "extract_tail_latents: expected rank-5 tensor [B, C, T, H, W], got shape {:?}", + dims, + )); + } + let time = dims[2]; + let tail = tail_latent_frame_count(pixel_frames); + if tail > time { + return Err(anyhow!( + "extract_tail_latents: tail requests {} latent frames but the stage emitted only {} \ + (pixel_frames={}, tensor shape={:?})", + tail, + time, + pixel_frames, + dims, + )); + } + let start = time - tail; + final_latents + .narrow(2, start, tail) + .with_context(|| format!("narrow last {tail} latent frames off time axis")) +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::{DType, Device}; + + #[test] + fn tail_latent_frame_count_matches_vae_formula() { + // Single-frame tail and up to 8 pixel frames fit in 1 latent frame + // (LTX-2 VAE uses causal first frame + 8× temporal compression). + for px in [1u32, 2, 4, 8] { + assert_eq!(tail_latent_frame_count(px), 1, "{px} pixel frames"); + } + // 9..=16 span 2 latent frames, 17..=24 span 3, etc. + assert_eq!(tail_latent_frame_count(9), 2); + assert_eq!(tail_latent_frame_count(16), 2); + assert_eq!(tail_latent_frame_count(17), 3); + assert_eq!(tail_latent_frame_count(24), 3); + // Full-clip tail (97 frames) → 13 latent frames, matching + // VideoLatentShape::from_pixel_shape under the same VAE ratio. + assert_eq!(tail_latent_frame_count(97), 13); + } + + #[test] + #[should_panic(expected = "pixel_frames must be > 0")] + fn tail_latent_frame_count_rejects_zero() { + tail_latent_frame_count(0); + } + + #[test] + fn extract_tail_narrows_last_latent_frame_for_4_pixel_frame_tail() { + // Build a synthetic [1, 2, 3, 1, 1] where channel 0 is the latent- + // frame index and channel 1 is a sentinel (42, 43, 44) so we can + // see which frames the narrow returns. 
+ let data = vec![ + // frame 0 + 0.0f32, 42.0, // frame 1 + 1.0, 43.0, // frame 2 + 2.0, 44.0, + ]; + // Arrange [B=1, C=2, T=3, H=1, W=1]. `Tensor::from_vec` fills in + // row-major order — the permute below puts channels on axis 1. + let raw = Tensor::from_vec(data, (1, 3, 2, 1, 1), &Device::Cpu).expect("build raw tensor"); + // Reshape [1, T, C, H, W] → [1, C, T, H, W] + let latents = raw + .permute([0, 2, 1, 3, 4]) + .expect("permute to [B, C, T, H, W]"); + assert_eq!(latents.dims(), &[1, 2, 3, 1, 1]); + + // tail_latent_frame_count(4) = 1 → take the last latent frame only. + let tail = extract_tail_latents(&latents, 4).expect("extract"); + assert_eq!(tail.dims(), &[1, 2, 1, 1, 1]); + let values = tail.flatten_all().unwrap().to_vec1::().unwrap(); + assert_eq!( + values, + vec![2.0, 44.0], + "tail must be the last latent frame (index 2) across all channels", + ); + } + + #[test] + fn extract_tail_narrows_two_frames_for_9_pixel_frame_tail() { + // Simple rank-5 zero tensor with T=3; narrowing the last 2 frames + // out of 3 is enough to verify the shape without wrestling with + // permutations again. + let latents = Tensor::zeros((1, 1, 3, 2, 2), DType::F32, &Device::Cpu).unwrap(); + let tail = extract_tail_latents(&latents, 9).expect("extract"); + assert_eq!(tail.dims(), &[1, 1, 2, 2, 2]); + } + + #[test] + fn extract_tail_rejects_rank_4_tensor() { + let bad = Tensor::zeros((1, 128, 3, 4), DType::F32, &Device::Cpu).unwrap(); + let err = extract_tail_latents(&bad, 4).expect_err("rank 4 must fail"); + let msg = format!("{err}"); + assert!( + msg.contains("rank-5") && msg.contains("T, H, W"), + "error must identify the rank mismatch, got: {msg}", + ); + } + + #[test] + fn extract_tail_rejects_oversize_request() { + // Tensor has 1 latent frame; asking for a 9-pixel-frame tail needs 2. 
+ let latents = Tensor::zeros((1, 128, 1, 4, 4), DType::F32, &Device::Cpu).unwrap(); + let err = extract_tail_latents(&latents, 9).expect_err("oversize tail must fail"); + let msg = format!("{err}"); + assert!( + msg.contains("requests 2") && msg.contains("only 1"), + "error must name the latent-frame mismatch, got: {msg}", + ); + } +} diff --git a/crates/mold-inference/src/ltx2/mod.rs b/crates/mold-inference/src/ltx2/mod.rs index d2101d33..9858e1d5 100644 --- a/crates/mold-inference/src/ltx2/mod.rs +++ b/crates/mold-inference/src/ltx2/mod.rs @@ -1,5 +1,6 @@ mod assets; mod backend; +pub mod chain; mod conditioning; mod execution; mod guidance; @@ -13,4 +14,5 @@ mod runtime; mod sampler; mod text; +pub use chain::{extract_tail_latents, tail_latent_frame_count, ChainTail}; pub use pipeline::Ltx2Engine; From e91721090cdf3839897854214bdb7747ada10dd4 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 17:21:04 -0700 Subject: [PATCH 05/31] feat(ltx2): staged latent conditioning bypasses VAE encode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `StagedConditioning` now carries both disk-backed images (existing single-clip path) and in-memory latent blocks (new, empty for every non-chain caller). The render-chain orchestrator will populate the new `latents: Vec` field with a prior stage's motion-tail latents so the receiving stage can patchify those tokens straight into its `StageVideoConditioning::replacements` without the VAE decode → RGB → VAE encode round-trip — that's the point of latent carryover. Changes: - `StagedLatent { latents, frame, strength }` in `ltx2::conditioning` — mirrors `StagedImage`'s semantics but with a pre-encoded `candle_core::Tensor` instead of a disk path. `frame = 0` routes tokens through `replacements` (chain v1 motion tail); non-zero `frame` builds a `VideoTokenAppendCondition` so the movie- maker in v2 can thread latents into arbitrary positions. 
- `StagedConditioning` drops `PartialEq` since `Tensor` doesn't implement structural equality. Grepped for comparison usages — none. Existing callers of `stage_conditioning()` get `latents: Vec::new()`. - `maybe_load_stage_video_conditioning` in `runtime.rs`: - Early-return gate now also considers `plan.conditioning.latents`. - VAE is loaded conditionally: only when images or reference video need encoding. Pure-latent chain handoffs skip VAE load entirely. - New loop iterates staged latents, patchifies each block, routes frame-0 tokens to `replacements` (keyframe pipelines aside) and other frames to `appended` — symmetrical with the image path. Tests (weight-free): - `stage_conditioning_leaves_latents_empty_for_non_chain_callers` — pins the back-compat invariant: every non-chain generate path continues to receive an empty latents vec. - `staged_latent_patchifies_to_same_token_shape_as_image_at_single_latent_frame` — verifies a `[1, 128, 1, 22, 38]` chain-tail latent block patchifies to `[1, 836, 128]` tokens, the same shape the image-conditioning path produces after VAE encode + patchify for the equivalent latent geometry. Chain orchestrator + `Ltx2Engine::generate_with_carryover` land in the sibling Phase 1c commit. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../mold-inference/src/ltx2/conditioning.rs | 56 ++++++++++++++- crates/mold-inference/src/ltx2/runtime.rs | 72 +++++++++++++++++-- 2 files changed, 123 insertions(+), 5 deletions(-) diff --git a/crates/mold-inference/src/ltx2/conditioning.rs b/crates/mold-inference/src/ltx2/conditioning.rs index b0b8d0e7..ce7c9bb1 100644 --- a/crates/mold-inference/src/ltx2/conditioning.rs +++ b/crates/mold-inference/src/ltx2/conditioning.rs @@ -1,4 +1,5 @@ use anyhow::{bail, Result}; +use candle_core::Tensor; use mold_core::{GenerateRequest, TimeRange}; use std::fs; use std::ops::RangeInclusive; @@ -11,9 +12,38 @@ pub(crate) struct StagedImage { pub(crate) strength: f32, } -#[derive(Debug, Clone, PartialEq)] +/// Pre-encoded latent block used as conditioning input, bypassing the +/// staged-image path's VAE encode. Populated by the render-chain +/// orchestrator when handing a motion-tail off between stages; empty for +/// every non-chain caller today. +/// +/// Tensor shape must be `[batch=1, channels=128, T_latent, H/32, W/32]` +/// to match the LTX-2 video VAE output. The runtime patchifies it directly +/// into conditioning tokens. +#[derive(Debug, Clone)] +pub(crate) struct StagedLatent { + pub(crate) latents: Tensor, + /// Starting pixel frame for this latent block. `0` routes the tokens + /// through `StageVideoConditioning::replacements`; non-zero values + /// build a `VideoTokenAppendCondition` like the keyframe image path. + pub(crate) frame: u32, + /// Replacement/append strength. `1.0` for chain motion-tail carryover + /// (hard-overwrite), matching the keyframe image strength convention. + pub(crate) strength: f32, +} + +/// Conditioning inputs staged for a single run. Carries both disk-backed +/// files (images, audio, reference video — existing single-clip flow) and +/// in-memory latent blocks (chain carryover — new, empty for non-chain +/// callers). 
+/// +/// Not `PartialEq` because `StagedLatent` wraps a `candle_core::Tensor` +/// which doesn't implement meaningful structural equality. Existing tests +/// only compare individual fields so this is safe to drop. +#[derive(Debug, Clone)] pub(crate) struct StagedConditioning { pub(crate) images: Vec, + pub(crate) latents: Vec, pub(crate) audio_path: Option, pub(crate) video_path: Option, } @@ -99,6 +129,7 @@ pub(crate) fn stage_conditioning( Ok(StagedConditioning { images, + latents: Vec::new(), audio_path, video_path, }) @@ -224,6 +255,29 @@ mod tests { assert!(mask[18..].iter().all(|value| *value == 0.0)); } + #[test] + fn stage_conditioning_leaves_latents_empty_for_non_chain_callers() { + // Single-clip callers build `StagedConditioning` via this function; + // the `latents` field (used by the render-chain orchestrator to inject + // pre-encoded motion-tail tokens) must stay empty so existing runs + // keep routing conditioning through the image path with VAE encode. + let work_dir = tempfile::tempdir().unwrap(); + let mut req = req(); + req.source_image = Some(fake_png_bytes()); + req.keyframes = Some(vec![KeyframeCondition { + frame: 8, + image: fake_png_bytes(), + }]); + req.source_video = Some(fake_mp4_bytes()); + req.audio_file = Some(fake_wav_bytes()); + + let staged = stage_conditioning(&req, work_dir.path()).unwrap(); + assert!( + staged.latents.is_empty(), + "non-chain callers must leave latents empty", + ); + } + #[test] fn stage_conditioning_stages_source_image_as_frame_zero_replacement() { let work_dir = tempfile::tempdir().unwrap(); diff --git a/crates/mold-inference/src/ltx2/runtime.rs b/crates/mold-inference/src/ltx2/runtime.rs index 72135d92..8df1f26c 100644 --- a/crates/mold-inference/src/ltx2/runtime.rs +++ b/crates/mold-inference/src/ltx2/runtime.rs @@ -1219,17 +1219,32 @@ fn maybe_load_stage_video_conditioning( dtype: DType, include_reference_video: bool, ) -> Result { - if plan.conditioning.images.is_empty() && !include_reference_video { + 
if plan.conditioning.images.is_empty() + && plan.conditioning.latents.is_empty() + && !include_reference_video + { return Ok(StageVideoConditioning::default()); } - let mut vae = load_ltx2_video_vae(plan, device, dtype)?; - vae.use_tiling = false; - vae.use_framewise_decoding = false; + // The VAE is only needed when we have images to encode or a reference + // video to ingest. Pre-encoded staged latents (chain carryover) skip + // VAE load entirely — that's the whole point of latent carryover. + let need_vae = !plan.conditioning.images.is_empty() || include_reference_video; + let mut vae = if need_vae { + let mut loaded = load_ltx2_video_vae(plan, device, dtype)?; + loaded.use_tiling = false; + loaded.use_framewise_decoding = false; + Some(loaded) + } else { + None + }; let patchifier = VideoLatentPatchifier::new(1); let mut conditioning = StageVideoConditioning::default(); for image in &plan.conditioning.images { + let vae = vae.as_mut().expect( + "need_vae guarantees the VAE is loaded whenever plan.conditioning.images is non-empty", + ); let bytes = std::fs::read(&image.path).with_context(|| { format!( "failed to read staged LTX-2 conditioning image '{}'", @@ -1271,7 +1286,36 @@ fn maybe_load_stage_video_conditioning( )?); } } + // Pre-encoded latents (chain carryover). No VAE needed — tokens come + // straight from the caller. For v1 chain this only ever holds a frame-0 + // replacement (motion-tail latents from the prior stage); appended + // (non-frame-0) is kept as a forward-compat branch for the movie-maker. 
+ for staged in &plan.conditioning.latents { + let latents = staged.latents.to_device(device)?.to_dtype(DType::F32)?; + let use_guiding_latent = matches!(plan.pipeline, PipelineKind::Keyframe); + if staged.frame == 0 && !use_guiding_latent { + let tokens = patchifier.patchify(&latents)?; + conditioning.replacements.push(VideoTokenReplacement { + start_token: 0, + tokens, + strength: staged.strength as f64, + }); + } else { + conditioning + .appended + .push(append_condition_from_video_latents( + &latents, + pixel_shape, + staged.frame, + 1, + staged.strength as f64, + )?); + } + } if include_reference_video { + let vae = vae.as_mut().expect( + "need_vae guarantees the VAE is loaded whenever include_reference_video is true", + ); let video_path = plan.conditioning.video_path.as_ref().with_context(|| { format!( "native {:?} stage requested reference video conditioning without a staged source_video", @@ -5784,6 +5828,26 @@ mod tests { ); } + #[test] + fn staged_latent_patchifies_to_same_token_shape_as_image_at_single_latent_frame() { + // A 4-pixel-frame motion tail at 1216×704 output lands on a latent + // block of shape [1, 128, 1, 22, 38]. The render-chain orchestrator + // produces this block from the prior stage's denoise result; the + // image-conditioning path produces the same shape after VAE encode. + // Both must patchify to [1, T*H*W, C] = [1, 1*22*38, 128] tokens so + // the downstream replacement pass sees them identically regardless + // of which path produced them. 
+ let latents = Tensor::zeros( + (1, LTX2_VIDEO_LATENT_CHANNELS, 1, 22, 38), + DType::F32, + &Device::Cpu, + ) + .unwrap(); + let patchifier = super::VideoLatentPatchifier::new(1); + let tokens = patchifier.patchify(&latents).expect("patchify"); + assert_eq!(tokens.dims(), &[1, 22 * 38, LTX2_VIDEO_LATENT_CHANNELS]); + } + #[test] fn video_conditioning_self_attention_mask_blocks_cross_keyframe_attention() { let conditioning = StageVideoConditioning { From 14801c78574cf78b0c0bc01df862f5d7bd3147fc Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 17:28:37 -0700 Subject: [PATCH 06/31] feat(ltx2): chain orchestrator with motion-tail carryover loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `Ltx2ChainOrchestrator` that drives the per-stage render loop for chained video generation: builds each stage's `GenerateRequest`, threads the prior stage's `ChainTail` through the renderer, drops the leading motion-tail frames on every continuation, accumulates frames, and returns a `ChainRunOutput`. The `ChainStageRenderer` trait is the seam between the orchestrator (pure control flow) and the engine (tensor work). The LTX-2 engine implementation lands in Phase 1d — this commit ships the orchestrator fully tested against a fake renderer so the engine plumbing can be reviewed in isolation. Behaviour nailed down (from the 2026-04-20 sign-off): - **Per-stage seeds**: `base_seed ^ ((stage_idx as u64) << 32)`. A stage's `seed_offset` overrides the default when set — reserved for the v2 movie-maker's "regen just this stage" affordance. - **Motion-tail trim**: stage 0 emits all its frames; continuations drop the leading `req.motion_tail_frames` pixel frames because those duplicate the previous clip's tail that was threaded back as latent conditioning. `motion_tail_frames = 0` is a legitimate configuration (simple concat). - **Fail closed**: a mid-chain renderer error bubbles up immediately. 
All frames accumulated so far are discarded — no partial stitch is ever written to the gallery. Partial resume is a v2 feature. - **No audio or target-total-frame trim in v1**: the orchestrator delivers whatever frame count the stages produce (with tail drops applied). Target-total trimming is the caller's responsibility (server / CLI). Audio-video chains are out of scope for v1. Progress events forwarded through `Option<&mut dyn FnMut(ChainProgressEvent)>`: `ChainStart` → `StageStart` → `DenoiseStep` (wrapping the renderer's `StageProgressEvent`s with `stage_idx`) → `StageDone` → (next stage) → `Stitching`. Chain-level subscribers can render a stacked overall+per-stage progress bar without coordinating with the engine. Per-stage `GenerateRequest` is constructed to ensure only stage 0 carries the optional starting image — even if the caller forgot to clear it on later stages, the orchestrator suppresses it because continuations must condition on motion-tail latents only. `strength` becomes `1.0` on continuations regardless of the chain default since the tail carryover is always a hard replacement. Tests (weight-free, injecting a `FakeRenderer`): - `chain_runs_all_stages_and_drops_tail_prefix_from_continuations` — 3×97-frame clips with 4-frame tail produce exactly 97 + 2×93 = 283 accumulated frames. - `chain_with_zero_tail_concats_full_clips_without_drop` — `tail=0` keeps every frame on continuations. - `chain_empty_stages_errors_without_calling_renderer` — zero-stage requests fail before touching the renderer. - `chain_fails_closed_mid_chain_discarding_accumulated_frames` — simulated stage-1 failure bubbles up; stage 2 never runs. - `chain_derives_per_stage_seed_from_base_seed` — three stages from base seed 42 land on 42, 42^(1<<32), 42^(2<<32). - `chain_only_stage0_carries_source_image` — a source image set on stages[1] is suppressed, so continuations can't accidentally condition on a still image instead of the motion tail. 
- `chain_forwards_engine_events_with_stage_idx_wrapping` — checks the full expected event order for a 2-stage chain with per-stage progress emission. - `chain_rejects_motion_tail_ge_stage_frames_before_running` — up-front validation catches `motion_tail >= frames` so the renderer is never invoked with a degenerate configuration. - `chain_respects_seed_offset_override_when_stage_provides_one` — pins `ChainStage::seed_offset` semantics for the v2 movie-maker hook. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-inference/src/ltx2/chain.rs | 593 +++++++++++++++++++++++- crates/mold-inference/src/ltx2/mod.rs | 5 +- 2 files changed, 596 insertions(+), 2 deletions(-) diff --git a/crates/mold-inference/src/ltx2/chain.rs b/crates/mold-inference/src/ltx2/chain.rs index b1c4c91e..dfcb08df 100644 --- a/crates/mold-inference/src/ltx2/chain.rs +++ b/crates/mold-inference/src/ltx2/chain.rs @@ -12,9 +12,11 @@ //! //! See `tasks/render-chain-v1-plan.md` Phase 1.1 for context. -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use candle_core::Tensor; use image::RgbImage; +use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage}; +use mold_core::{GenerateRequest, OutputFormat}; use crate::ltx2::model::shapes::SpatioTemporalScaleFactors; @@ -103,6 +105,266 @@ pub fn extract_tail_latents(final_latents: &Tensor, pixel_frames: u32) -> Result .with_context(|| format!("narrow last {tail} latent frames off time axis")) } +// ── Orchestrator: loops stages, drops motion-tail prefix, accumulates frames + +/// Per-stage progress events the orchestrator observes from the renderer. +/// The renderer emits these synchronously while a stage is denoising; the +/// orchestrator wraps them with `stage_idx` before forwarding as +/// [`ChainProgressEvent`]s to the chain-level subscriber. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StageProgressEvent { + /// Denoise step `step` of `total` completed for the active stage. 
+ DenoiseStep { step: u32, total: u32 }, +} + +/// Output of a single stage render: the decoded pixel frames (full clip, +/// before motion-tail trim), the pre-VAE-decode latent tail the next stage +/// needs, and the wall-clock elapsed time for the render. +#[derive(Debug)] +pub struct StageOutcome { + pub frames: Vec, + pub tail: ChainTail, + pub generation_time_ms: u64, +} + +/// Abstraction over "render one chain stage". Production uses the LTX-2 +/// engine impl (lands in Phase 1d); tests inject a fake implementation +/// that fabricates deterministic frames and a synthetic [`ChainTail`] +/// without loading candle weights. +pub trait ChainStageRenderer { + fn render_stage( + &mut self, + stage_req: &GenerateRequest, + carry: Option<&ChainTail>, + stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, + ) -> Result; +} + +/// Output of an end-to-end chain run: accumulated RGB frames with motion- +/// tail prefix already trimmed on continuations, the number of stages +/// that ran, and the total elapsed render time. +/// +/// The orchestrator does *not* trim to a target total frame count or +/// encode the frames into an output video — those are the caller's job +/// (server / CLI). Keeps the orchestrator single-purpose: produce a +/// coherent frame stream from a stages list. +#[derive(Debug)] +pub struct ChainRunOutput { + pub frames: Vec, + pub stage_count: u32, + pub generation_time_ms: u64, +} + +/// Drives the per-stage render loop for a chained generation. Borrows its +/// renderer mutably so the loop can re-enter the engine on the same GPU +/// context across stages. +pub struct Ltx2ChainOrchestrator<'a, R: ChainStageRenderer> { + renderer: &'a mut R, +} + +impl<'a, R: ChainStageRenderer> Ltx2ChainOrchestrator<'a, R> { + pub fn new(renderer: &'a mut R) -> Self { + Self { renderer } + } + + /// Run every stage in `req.stages` and return the accumulated frames. 
+ /// + /// Behaviour invariants (from the 2026-04-20 sign-off): + /// - Per-stage seeds are derived as `base_seed ^ ((stage_idx as u64) << 32)`. + /// - Stage 0's output is kept whole; continuations drop their leading + /// `req.motion_tail_frames` pixel frames because those duplicate the + /// prior stage's tail that was threaded back as latent conditioning. + /// - Mid-chain failure returns the error immediately; partial frames are + /// discarded (no partial stitch is ever produced in v1). + pub fn run( + &mut self, + req: &ChainRequest, + mut chain_progress: Option<&mut dyn FnMut(ChainProgressEvent)>, + ) -> Result { + if req.stages.is_empty() { + bail!("Ltx2ChainOrchestrator::run: chain request has no stages"); + } + validate_motion_tail(req)?; + + let stage_count = req.stages.len() as u32; + let estimated_total_frames = estimate_stitched_frames(req); + if let Some(cb) = chain_progress.as_deref_mut() { + cb(ChainProgressEvent::ChainStart { + stage_count, + estimated_total_frames, + }); + } + + let base_seed = req.seed.unwrap_or(0); + let motion_tail_drop = req.motion_tail_frames as usize; + let mut accumulated_frames: Vec = Vec::new(); + let mut total_generation_ms: u64 = 0; + let mut carry: Option = None; + + for (idx, stage) in req.stages.iter().enumerate() { + let stage_idx = idx as u32; + if let Some(cb) = chain_progress.as_deref_mut() { + cb(ChainProgressEvent::StageStart { stage_idx }); + } + + let stage_seed = derive_stage_seed(base_seed, idx, stage); + let stage_req = build_stage_generate_request(stage, req, stage_seed, idx); + + // Wrap the chain progress subscriber so per-stage denoise + // events land on it with `stage_idx` tagged in. The wrapping + // closure holds a mutable reborrow of the outer callback for + // just the duration of this call — `render_stage` is + // synchronous so the reborrow ends before the next iteration. 
+ let outcome = match chain_progress.as_deref_mut() { + Some(chain_cb) => { + let mut wrapping = |event: StageProgressEvent| match event { + StageProgressEvent::DenoiseStep { step, total } => { + chain_cb(ChainProgressEvent::DenoiseStep { + stage_idx, + step, + total, + }); + } + }; + self.renderer + .render_stage(&stage_req, carry.as_ref(), Some(&mut wrapping))? + } + None => self + .renderer + .render_stage(&stage_req, carry.as_ref(), None)?, + }; + + let mut frames = outcome.frames; + if idx > 0 && motion_tail_drop > 0 { + if motion_tail_drop >= frames.len() { + bail!( + "stage {stage_idx}: emitted {} frames but motion_tail_drop={motion_tail_drop} — tail would consume the whole clip", + frames.len(), + ); + } + frames.drain(..motion_tail_drop); + } + let frames_emitted = frames.len() as u32; + accumulated_frames.extend(frames); + total_generation_ms = total_generation_ms.saturating_add(outcome.generation_time_ms); + carry = Some(outcome.tail); + + if let Some(cb) = chain_progress.as_deref_mut() { + cb(ChainProgressEvent::StageDone { + stage_idx, + frames_emitted, + }); + } + } + + if let Some(cb) = chain_progress.as_deref_mut() { + cb(ChainProgressEvent::Stitching { + total_frames: accumulated_frames.len() as u32, + }); + } + + Ok(ChainRunOutput { + frames: accumulated_frames, + stage_count, + generation_time_ms: total_generation_ms, + }) + } +} + +fn validate_motion_tail(req: &ChainRequest) -> Result<()> { + for (idx, stage) in req.stages.iter().enumerate() { + if req.motion_tail_frames >= stage.frames { + bail!( + "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({}) \ + so every continuation emits at least one new frame", + req.motion_tail_frames, + stage.frames, + ); + } + } + Ok(()) +} + +fn estimate_stitched_frames(req: &ChainRequest) -> u32 { + // delivered = stages[0].frames + Σ (stages[i].frames - motion_tail) for i >= 1 + let tail = req.motion_tail_frames; + req.stages + .iter() + .enumerate() + .map(|(idx, stage)| { + if idx == 
0 { + stage.frames + } else { + stage.frames.saturating_sub(tail) + } + }) + .sum() +} + +fn derive_stage_seed(base_seed: u64, idx: usize, stage: &ChainStage) -> u64 { + if let Some(offset) = stage.seed_offset { + base_seed ^ offset + } else { + base_seed ^ ((idx as u64) << 32) + } +} + +fn build_stage_generate_request( + stage: &ChainStage, + chain: &ChainRequest, + stage_seed: u64, + idx: usize, +) -> GenerateRequest { + GenerateRequest { + prompt: stage.prompt.clone(), + negative_prompt: stage.negative_prompt.clone(), + model: chain.model.clone(), + width: chain.width, + height: chain.height, + steps: chain.steps, + guidance: chain.guidance, + seed: Some(stage_seed), + batch_size: 1, + // Continuation stages never use the per-chain output_format + // downstream — the orchestrator decodes to frames regardless — + // but MP4 is the canonical intermediate for LTX-2. + output_format: OutputFormat::Mp4, + embed_metadata: None, + scheduler: None, + // Stage 0 carries the optional starting image; continuations + // get their conditioning from motion-tail latents via the + // `carry` argument to `render_stage`. 
+ source_image: if idx == 0 { + stage.source_image.clone() + } else { + None + }, + edit_images: None, + strength: if idx == 0 { chain.strength } else { 1.0 }, + mask_image: None, + control_image: None, + control_model: None, + control_scale: 1.0, + expand: None, + original_prompt: None, + lora: None, + frames: Some(stage.frames), + fps: Some(chain.fps), + upscale_model: None, + gif_preview: false, + enable_audio: Some(false), // v1 chain: no audio plumbing yet + audio_file: None, + source_video: None, + keyframes: None, + pipeline: None, + loras: None, + retake_range: None, + spatial_upscale: None, + temporal_upscale: None, + placement: chain.placement.clone(), + } +} + #[cfg(test)] mod tests { use super::*; @@ -194,4 +456,333 @@ mod tests { "error must name the latent-frame mismatch, got: {msg}", ); } + + // ── Orchestrator tests (fake renderer, weight-free) ─────────────── + + use image::Rgb; + use mold_core::chain::ChainStage; + + /// Deterministic fake renderer for orchestrator tests. Records every + /// call so assertions can inspect the per-stage request shape, emits + /// a solid-color frame block plus a zero-valued latent tail, and + /// optionally returns errors on pre-configured stage indices. + struct FakeRenderer { + calls: Vec, + /// If set, fail on the listed stage indices with the given message. + fail_on: Vec<(usize, String)>, + /// Per-call override of frame count (default: use stage_req.frames). + frame_count_override: Option, + /// If true, emit one DenoiseStep event per stage so tests can + /// verify progress forwarding. 
+ emit_progress: bool, + } + + #[derive(Debug, Clone)] + struct CallRecord { + seed: Option, + frames: Option, + has_source_image: bool, + has_carry: bool, + } + + impl FakeRenderer { + fn new() -> Self { + Self { + calls: Vec::new(), + fail_on: Vec::new(), + frame_count_override: None, + emit_progress: false, + } + } + } + + impl ChainStageRenderer for FakeRenderer { + fn render_stage( + &mut self, + stage_req: &GenerateRequest, + carry: Option<&ChainTail>, + mut stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, + ) -> Result { + let idx = self.calls.len(); + self.calls.push(CallRecord { + seed: stage_req.seed, + frames: stage_req.frames, + has_source_image: stage_req.source_image.is_some(), + has_carry: carry.is_some(), + }); + if let Some((_, msg)) = self.fail_on.iter().find(|(stage_idx, _)| *stage_idx == idx) { + bail!("{msg}"); + } + if self.emit_progress { + if let Some(cb) = stage_progress.as_deref_mut() { + cb(StageProgressEvent::DenoiseStep { step: 1, total: 1 }); + } + } + + let frame_count = self + .frame_count_override + .unwrap_or_else(|| stage_req.frames.expect("fake renderer: stage_req.frames")); + let width = stage_req.width; + let height = stage_req.height; + // Colour the frames with the stage index so assertions can + // verify which stage a frame came from. + let mut frames = Vec::with_capacity(frame_count as usize); + for frame_num in 0..frame_count { + let channel = (idx as u8).wrapping_mul(37).wrapping_add(frame_num as u8); + frames.push(RgbImage::from_pixel(width, height, Rgb([channel, 0, 0]))); + } + let last_frame = frames.last().cloned().unwrap(); + + // Build a synthetic tail latent at the "right" shape for the + // requested motion tail. Shape isn't validated by the + // orchestrator itself — the engine impl in Phase 1d will check. 
+ let latent = Tensor::zeros( + (1, 128, 1, height as usize / 32, width as usize / 32), + DType::F32, + &Device::Cpu, + ) + .unwrap(); + + Ok(StageOutcome { + frames, + tail: ChainTail { + frames: 4, + latents: latent, + last_rgb_frame: last_frame, + }, + generation_time_ms: 100, + }) + } + } + + fn stage(prompt: &str, frames: u32) -> ChainStage { + ChainStage { + prompt: prompt.into(), + frames, + source_image: None, + negative_prompt: None, + seed_offset: None, + } + } + + fn chain_req(stages: Vec, motion_tail_frames: u32) -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages, + motion_tail_frames, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } + } + + #[test] + fn chain_runs_all_stages_and_drops_tail_prefix_from_continuations() { + let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)]; + let req = chain_req(stages, 4); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let out = orch.run(&req, None).expect("chain runs"); + // Stage 0 keeps all 97 frames; each continuation drops the + // leading 4 frames, so delivered = 97 + 2 * (97 - 4) = 97 + 186 = 283. + assert_eq!(out.frames.len(), 97 + 93 * 2); + assert_eq!(out.stage_count, 3); + assert_eq!(renderer.calls.len(), 3); + // Stage 0 has no carry; later stages do. 
+ assert!(!renderer.calls[0].has_carry); + assert!(renderer.calls[1].has_carry); + assert!(renderer.calls[2].has_carry); + } + + #[test] + fn chain_with_zero_tail_concats_full_clips_without_drop() { + let stages = vec![stage("a", 97), stage("a", 97)]; + let req = chain_req(stages, 0); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let out = orch.run(&req, None).expect("chain runs"); + assert_eq!( + out.frames.len(), + 97 * 2, + "zero motion tail must keep every frame on continuations", + ); + } + + #[test] + fn chain_empty_stages_errors_without_calling_renderer() { + let req = chain_req(vec![], 4); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch.run(&req, None).expect_err("empty stages must fail"); + assert!( + format!("{err}").contains("has no stages"), + "error must name the missing stages, got: {err}", + ); + assert!(renderer.calls.is_empty()); + } + + #[test] + fn chain_fails_closed_mid_chain_discarding_accumulated_frames() { + // Signed-off decision 2026-04-20: mid-chain failure returns the + // error immediately and throws away any frames already produced. + // No partial stitch is ever written to the gallery. + let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)]; + let req = chain_req(stages, 4); + let mut renderer = FakeRenderer::new(); + renderer.fail_on = vec![(1, "simulated GPU OOM on stage 1".into())]; + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch + .run(&req, None) + .expect_err("mid-chain failure must bubble up"); + assert!( + format!("{err}").contains("simulated GPU OOM"), + "error must carry the renderer's message, got: {err}", + ); + // Stage 0 ran (recorded), stage 1 failed (recorded before bail), + // stage 2 never ran. 
+ assert_eq!(renderer.calls.len(), 2); + } + + #[test] + fn chain_derives_per_stage_seed_from_base_seed() { + let stages = vec![stage("a", 9), stage("a", 9), stage("a", 9)]; + let mut req = chain_req(stages, 0); + req.seed = Some(42); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + orch.run(&req, None).expect("chain runs"); + // Per the sign-off: stage_seed = base ^ ((idx as u64) << 32). + assert_eq!(renderer.calls[0].seed, Some(42)); + assert_eq!(renderer.calls[1].seed, Some(42 ^ (1u64 << 32))); + assert_eq!(renderer.calls[2].seed, Some(42 ^ (2u64 << 32))); + } + + #[test] + fn chain_only_stage0_carries_source_image() { + let mut stages = vec![stage("a", 9), stage("a", 9)]; + stages[0].source_image = Some(vec![0x89, 0x50, 0x4e, 0x47]); // PNG magic + // If a caller forgets to clear later stages' source_image, the + // orchestrator still suppresses it — continuations must always + // condition on motion-tail latents, never on a staged image. 
+ stages[1].source_image = Some(vec![0x89, 0x50, 0x4e, 0x47]); + let req = chain_req(stages, 0); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + orch.run(&req, None).expect("chain runs"); + assert!(renderer.calls[0].has_source_image); + assert!(!renderer.calls[1].has_source_image); + } + + #[test] + fn chain_forwards_engine_events_with_stage_idx_wrapping() { + let stages = vec![stage("a", 9), stage("a", 9)]; + let req = chain_req(stages, 0); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + renderer.emit_progress = true; + + let mut events: Vec = Vec::new(); + { + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let mut cb = |e: ChainProgressEvent| events.push(e); + orch.run(&req, Some(&mut cb)).expect("chain runs"); + } + + // Expected order: + // ChainStart, StageStart(0), DenoiseStep(0), StageDone(0), + // StageStart(1), DenoiseStep(1), StageDone(1), Stitching + assert!(matches!( + events[0], + ChainProgressEvent::ChainStart { stage_count: 2, .. 
} + )); + assert!(matches!( + events[1], + ChainProgressEvent::StageStart { stage_idx: 0 } + )); + assert!(matches!( + events[2], + ChainProgressEvent::DenoiseStep { + stage_idx: 0, + step: 1, + total: 1 + } + )); + assert!(matches!( + events[3], + ChainProgressEvent::StageDone { + stage_idx: 0, + frames_emitted: 9 + } + )); + assert!(matches!( + events[4], + ChainProgressEvent::StageStart { stage_idx: 1 } + )); + assert!(matches!( + events[5], + ChainProgressEvent::DenoiseStep { + stage_idx: 1, + step: 1, + total: 1 + } + )); + assert!(matches!( + events[6], + ChainProgressEvent::StageDone { + stage_idx: 1, + frames_emitted: 9 + } + )); + assert!(matches!( + events[7], + ChainProgressEvent::Stitching { total_frames: 18 } + )); + assert_eq!(events.len(), 8); + } + + #[test] + fn chain_rejects_motion_tail_ge_stage_frames_before_running() { + let stages = vec![stage("a", 9), stage("a", 9)]; + // tail=9 equals stage frames — no net-new content on continuation. + let req = chain_req(stages, 9); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch.run(&req, None).expect_err("must fail"); + assert!( + format!("{err}").contains("motion_tail_frames"), + "error must name motion_tail_frames, got: {err}", + ); + // Renderer never gets called because validation runs up-front. 
+ assert!(renderer.calls.is_empty()); + } + + #[test] + fn chain_respects_seed_offset_override_when_stage_provides_one() { + let mut stages = vec![stage("a", 9), stage("a", 9)]; + stages[1].seed_offset = Some(0xDEADBEEF); + let mut req = chain_req(stages, 0); + req.seed = Some(100); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + orch.run(&req, None).expect("runs"); + assert_eq!(renderer.calls[0].seed, Some(100)); + assert_eq!( + renderer.calls[1].seed, + Some(100 ^ 0xDEADBEEFu64), + "seed_offset must take precedence over the default index-derived seed", + ); + } } diff --git a/crates/mold-inference/src/ltx2/mod.rs b/crates/mold-inference/src/ltx2/mod.rs index 9858e1d5..ac0c5b6e 100644 --- a/crates/mold-inference/src/ltx2/mod.rs +++ b/crates/mold-inference/src/ltx2/mod.rs @@ -14,5 +14,8 @@ mod runtime; mod sampler; mod text; -pub use chain::{extract_tail_latents, tail_latent_frame_count, ChainTail}; +pub use chain::{ + extract_tail_latents, tail_latent_frame_count, ChainRunOutput, ChainStageRenderer, ChainTail, + Ltx2ChainOrchestrator, StageOutcome, StageProgressEvent, +}; pub use pipeline::Ltx2Engine; From 350d27ec813a14efc1dc9475595319bfc8b193b9 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 17:48:42 -0700 Subject: [PATCH 07/31] docs(chain): render-chain-v1 context handoff for resuming work MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Capture the state of the branch (6 commits on local main, not pushed), the five signed-off design decisions, the Phase 1d → 2 → 3 → 4 remaining work with specific file:line surgery points, and a ready-to-paste prompt for a fresh Claude Code session. 
Gotchas documented: stale `test = false` claim in CLAUDE.md, pre-existing clippy warnings unrelated to this branch, VAE 8× causal temporal ratio already encoded by `extract_tail_latents`, and the existing-parameter-reuse opportunity on `run_real_distilled_stage` (no new params needed). Co-Authored-By: Claude Opus 4.7 (1M context) --- tasks/render-chain-v1-handoff.md | 311 +++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 tasks/render-chain-v1-handoff.md diff --git a/tasks/render-chain-v1-handoff.md b/tasks/render-chain-v1-handoff.md new file mode 100644 index 00000000..540e04cc --- /dev/null +++ b/tasks/render-chain-v1-handoff.md @@ -0,0 +1,311 @@ +# render-chain-v1 — context handoff + +> Paste the prompt at the bottom of this file into a fresh Claude Code session +> to resume work on render-chain v1. Everything above it is reference material +> that the prompt points at. + +## Status + +Branch: `main` (local). **6 commits stacked ahead of `origin/main`, not pushed** +per plan convention (no mid-plan push): + +| # | Commit | Scope | Phase | +|---|-----------|----------|-------| +| 1 | `d13a554` | `fix(ltx2): use pure source latents as i2v denoise-mask target` | Fix A (prereq) | +| 2 | `b4ed487` | `feat(chain): add core wire types and request normalisation` | 0.1 | +| 3 | `0328e76` | `feat(core): MoldClient chain methods` | 0.2 | +| 4 | `e89826f` | `feat(ltx2): ChainTail type and latent-tail extraction helper` | 1a | +| 5 | `e917210` | `feat(ltx2): staged latent conditioning bypasses VAE encode` | 1b | +| 6 | `14801c7` | `feat(ltx2): chain orchestrator with motion-tail carryover loop` | 1c | + +Test status on commit 6: `mold-core` 617 pass, `mold-inference` 586 pass, +`cargo fmt --check` clean, no candle weights loaded by any test. 
+ +Pre-existing clippy warnings on main (NOT introduced by this branch): +- `crates/mold-core/src/download.rs:1451` — `manual_repeat_n` +- `crates/mold-core/src/placement_test.rs:167` — `field_reassign_with_default` + +These only fire on newer clippy versions than CI pins and are unrelated to +the chain work. Don't "fix" them as part of render-chain. + +## Signed-off design decisions (do NOT re-litigate) + +User confirmed these 2026-04-20 and they're recorded at the top of +`tasks/render-chain-v1-plan.md`: + +1. **Trim over-production from the tail** of the final clip, not the head. +2. **Per-stage seed derivation: `stage_seed = base_seed ^ ((stage_idx as u64) << 32)`.** + `ChainStage::seed_offset` overrides this; reserved for the v2 movie-maker. +3. **Fail closed on mid-chain failure.** 502 + discard all prior stages. No + partial stitch. +4. **Accept ~1 GB RAM ceiling** for accumulated `RgbImage` buffer. Streaming + encode revisited at 1000+ frames. +5. **Single-GPU per chain.** Multi-GPU stage fan-out is v2. + +The orchestrator already encodes 1, 2, 3 and Phase 2 server route handles 3. + +## What's done + +- **`mold_core::chain`** — wire types (`ChainRequest`, `ChainResponse`, + `ChainStage`, `ChainProgressEvent`, `SseChainCompleteEvent`) and + `ChainRequest::normalise()`. Re-exports from `mold_core`. +- **`MoldClient::generate_chain{,_stream}`** with 422 → Validation, 404-with- + body → ModelNotFound, empty-404 → hard error (non-streaming) / `Ok(None)` + (streaming). Wiremock integration tests pin all four paths. +- **`ltx2::chain::ChainTail` + `extract_tail_latents`** — pure tensor math, + VAE formula `((pixel - 1) / 8) + 1`. Errors (not panics) on rank + mismatch / oversize tail. +- **`StagedLatent` + `StagedConditioning.latents`** — threaded through + `maybe_load_stage_video_conditioning` in `runtime.rs`. 
When the latents + vec is non-empty, the function builds `VideoTokenReplacement`s straight + from pre-encoded tokens and **skips VAE load entirely** (conditional + `Option` — confirmed only loaded when images or reference video are + present). +- **`Ltx2ChainOrchestrator`** — fully tested against + a fake renderer. Handles seed derivation, motion-tail trim on + continuations (stage 0 keeps all frames, continuations drop leading K), + progress forwarding with `stage_idx` wrapping, fail-closed error handling. + Orchestrator does NOT trim to a target total or encode MP4 — those are + caller responsibilities. + +## What's remaining + +### Phase 1d — `impl ChainStageRenderer for Ltx2Engine` (engine integration) + +The one-sentence contract: given `stage_req`, optional `carry: &ChainTail`, +and an optional stage-progress callback, return +`StageOutcome { frames, tail, generation_time_ms }`. + +Three sub-tasks: + +1. **Tail capture slot.** Add a mechanism for `render_real_distilled_av` + (`crates/mold-inference/src/ltx2/runtime.rs:1722`) to clone the + pre-VAE-decode `latents` tensor into a caller-provided slot. The exact + capture point is immediately before `vae.decode(&latents...)` at + `runtime.rs:2010` — shape is `[1, 128, T_latent, H/32, W/32]` F32. + Preferred mechanism: a field on `Ltx2RuntimeSession` (or a method + argument threaded down) holding `Option>>>`. + Production non-chain callers leave it `None` and pay no overhead. + +2. **`Ltx2Engine::generate_with_carryover(&mut self, req, carry)`**: + - Validate the request is a supported family (v1 scope: distilled LTX-2 + only — see `select_pipeline` at `crates/mold-inference/src/ltx2/pipeline.rs:108`). + - Build a `Ltx2GeneratePlan` via the existing `materialize_request` flow. + When `carry.is_some()`, wipe `source_image` and append a + `StagedLatent { latents: carry.latents.clone(), frame: 0, strength: 1.0 }` + to `plan.conditioning.latents`. 
The runtime already handles the rest + (`maybe_load_stage_video_conditioning` skips VAE, builds a frame-0 + replacement from patchified tokens). + - Enable the tail-capture slot. + - Run the existing render → decode → encode pipeline. + - Pull the captured latents out of the slot. + - Call `ltx2::chain::extract_tail_latents(&captured, motion_tail_frames)` + to get the tail slice. + - Decode the stitched MP4 once to extract `last_rgb_frame` (or capture + it alongside the `frames` Vec from `decoded_video_to_frames`). + - Return `(GenerateResponse, ChainTail)`. + +3. **`impl ChainStageRenderer for Ltx2Engine`** that delegates to + `generate_with_carryover`. The orchestrator's fake-renderer tests + define the exact contract; no new test harness needed for the impl — + real-engine coverage is Phase 2's integration test. + +**Gotchas:** +- `CLAUDE.md` claims `[lib] test = false` on `mold-inference` and + `mold-server` — **this is stale.** Both have normal test configs. Verified + in Phase 1a/b/c by running 586 tests. +- `run_real_distilled_stage` already takes + `video_clean_latents: Option<&Tensor>` and `video_denoise_mask: Option<&Tensor>` — + don't add new parameters unnecessarily. The tail carryover rides on + `conditioning.replacements` via `StagedLatent`, not on `video_clean_latents`. +- VAE temporal ratio is **8× with causal first frame** (`model/shapes.rs:20`). + `extract_tail_latents` already encodes this; just call it. +- `motion_tail_frames` defaults to 4 per plan; orchestrator validates + `motion_tail < stage.frames` up front, but the engine should still + tolerate `motion_tail = 0` (simple concat, no latent carryover — `carry` + will be `None` for every stage in that configuration). + +### Phase 2 — `POST /api/generate/chain[/stream]` server route + +Plan §2. Handler flow: + +1. Parse + `ChainRequest::normalise()`. +2. Reject non-LTX-2 models with a clear error. +3. 
Grab the engine from `ModelCache` (`crates/mold-server/src/lib.rs` — + holds `AppState.model_cache: Arc>`). +4. Construct `Ltx2ChainOrchestrator` against it, call `run()`. +5. Trim accumulated frames to target total (the ChainRequest no longer + carries `total_frames` after normalise — if you want tail-trim support, + add a `target_total_frames: Option` field that normalise + populates). Per the sign-off: trim from the tail. +6. Encode stitched MP4. Reuse `ltx2::media::encode_frames_to_mp4` or the + existing `encode_native_video` path — scout during Phase 2. +7. Save via `save_video_to_dir` with an `OutputMetadata` synthesised from + `stages[0].prompt`; optionally add `chain_stage_count: Option` to + `OutputMetadata`. +8. Return `ChainResponse` JSON. + +**Do NOT go through the existing single-job queue.** A 10+ minute chain +would block the queue. Instead hold the `ModelCache` mutex directly for +the chain duration, same pattern as the multi-GPU pool. Reason in plan §2.1. + +SSE variant: same flow, stream `ChainProgressEvent` as `event: progress` +JSON frames and a final `SseChainCompleteEvent` as `event: complete`. + +Tests: route-level with a fake engine (same trait seam as Phase 1c). No +real weights. + +### Phase 3 — CLI auto-routing + flags + +When `--frames > clip_cap` (97 for LTX-2 19B/22B distilled), build a +`ChainRequest` from the CLI args and route to +`MoldClient::generate_chain_stream`. New flags: `--clip-frames N`, +`--motion-tail N` (default 4). + +Stacked progress bar: one parent bar per chain (estimated total frames), +one per-stage bar wiping between stages. + +`--local` parity: factor the orchestrator invocation so both server +handler and CLI local path use the same code. + +### Phase 4 — docs + +- `website/guide/video.md`: new "Chained video output" section explaining + `--frames N`, motion tail, and the server endpoint. +- `CHANGELOG.md`: Unreleased/Added entry. +- `.claude/skills/mold/SKILL.md`: new CLI flags + endpoint. 
+ +## Verification commands + +Run these in order after any Phase 1d change to verify nothing regressed: + +```bash +cargo fmt -p mold-ai-inference -- --check +cargo check -p mold-ai-inference +cargo test -p mold-ai-inference --lib ltx2::chain:: # orchestrator + tail helpers +cargo test -p mold-ai-inference --lib # full 586-test sweep (~35 s) +cargo test -p mold-ai-core # sanity +``` + +Phase 1d's own tests should live alongside existing `pipeline.rs::tests` +patterns (using `with_runtime_session` injection at +`crates/mold-inference/src/ltx2/pipeline.rs:1062` — the existing test +exercises the runtime without real weights). + +## File map — where everything lives now + +``` +NEW (this branch): + crates/mold-core/src/chain.rs # wire types + normalise + crates/mold-core/tests/chain_client.rs # wiremock integration + crates/mold-inference/src/ltx2/chain.rs # ChainTail + orchestrator + +MODIFIED (this branch): + crates/mold-core/src/lib.rs # re-exports + crates/mold-core/src/types.rs # pub(crate) base64_opt + crates/mold-core/src/client.rs # generate_chain{,_stream} + crates/mold-inference/src/ltx2/mod.rs # pub use chain::* + crates/mold-inference/src/ltx2/conditioning.rs # StagedLatent + crates/mold-inference/src/ltx2/runtime.rs # latents loop + Fix A + +TARGETS (Phase 1d): + crates/mold-inference/src/ltx2/pipeline.rs # Ltx2Engine::generate_with_carryover + crates/mold-inference/src/ltx2/runtime.rs # tail-capture slot on session + +TARGETS (Phase 2+): + crates/mold-server/src/routes_chain.rs # NEW + crates/mold-server/src/lib.rs # route registration + crates/mold-cli/src/main.rs # auto-route + crates/mold-cli/src/commands/generate.rs # chain path + local parity + website/guide/video.md # docs + CHANGELOG.md + .claude/skills/mold/SKILL.md +``` + +## Convention reminders + +- Feature branch: `feat/render-chain-v1` (currently committing directly to + local `main` since pre-push). PR target: `main`. 
+- Commit scopes: `feat(chain)`, `fix(chain)`, `test(chain)`, `docs(chain)` + (core), or `feat(ltx2)`, `feat(server)`, `feat(cli)` depending on crate. +- **No mid-plan push.** All work accumulates locally until Phase 4 ends. +- Every phase step ends with a commit; verification (`fmt`, `test`) + between every step. +- Tests must be weight-free. Use the trait-seam pattern (Phase 1c) or the + `with_runtime_session` injection pattern (`pipeline.rs:1062`). + +--- + +## The prompt + +Paste from here into a fresh Claude Code session: + +--- + +I'm continuing work on **render-chain v1** — server-side chained LTX-2 video +generation for the mold repo. + +## Read first, in this order + +1. `CLAUDE.md` (both global at `~/.claude-personal/CLAUDE.md` and + `/Users/jeffreydilley/github/mold/CLAUDE.md`). +2. `tasks/render-chain-v1-plan.md` — full design, signed-off decisions. +3. `tasks/render-chain-v1-handoff.md` — status, remaining work, gotchas. + **This is your primary briefing.** Read it end-to-end before writing code. + +## Status on entry + +- 6 commits stacked locally on `main`, not pushed (per plan convention). + Last commit: `14801c7 feat(ltx2): chain orchestrator with motion-tail carryover loop`. +- Phase 0 (core wire types + client) and Phase 1a/b/c (ltx2 chain types, + StagedLatent plumbing, orchestrator + fake-renderer tests) are done. +- `mold-inference` has 586 tests passing, `mold-core` 617. Nothing loads + candle weights. Fmt clean. +- `CLAUDE.md`'s claim that `mold-inference` has `[lib] test = false` is + **stale** — the previous session verified tests run normally. + +## What you're doing + +**Phase 1d** — the engine integration that makes the orchestrator actually +render. Spec in `render-chain-v1-handoff.md` under "Phase 1d". 
In one +sentence: implement `impl ChainStageRenderer for Ltx2Engine` by adding a +tail-capture slot to `Ltx2RuntimeSession` and a +`Ltx2Engine::generate_with_carryover` method that populates +`plan.conditioning.latents` from the `ChainTail` input and returns the +captured tail alongside the response. + +Key surgery points already scouted: +- Tail capture immediately before `vae.decode` at + `crates/mold-inference/src/ltx2/runtime.rs:2010` +- Plan's staged-latents plumbing already works — + `maybe_load_stage_video_conditioning` accepts pre-encoded latents when + you populate `plan.conditioning.latents` (Phase 1b). + +After Phase 1d, Phases 2 (server route), 3 (CLI), and 4 (docs) per the plan. + +## How to work + +- Use `superpowers:subagent-driven-development` — the plan is sized for it. +- Use `superpowers:verification-before-completion` before claiming any + phase done. The handoff doc has the exact verification commands. +- Every step ends with a commit. Commit scope `feat(ltx2)` for Phase 1d. +- Do NOT push anything — plan convention is no mid-plan push. +- Do NOT re-litigate the signed-off design decisions in the handoff doc. +- Tests must be weight-free (use the `with_runtime_session` injection + pattern from `pipeline.rs:1062` or the trait seam shipped in Phase 1c). + +## Start here + +1. Run `git status && git log --oneline -7` to confirm the 6 commits are + on the tree. +2. Read `tasks/render-chain-v1-handoff.md` end-to-end. +3. Delegate an Explore subagent to map `Ltx2RuntimeSession` and the full + `Ltx2Engine::generate` → `generate_inner` → `render_native_video` call + chain end-to-end before writing code. Cite file:line throughout. Keep + the report under 2000 words. +4. Then plan the tail-capture mechanism (decide: field on + `Ltx2RuntimeSession` vs. threaded parameter, ergonomics tradeoffs). +5. Implement. Commit. Then Phase 2. 
+ +If you hit a surprise that invalidates an assumption in the plan or +handoff doc, stop and re-plan rather than papering over it. From 1c142e300fd248cfef68562eb380b0d051db7a17 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 18:13:42 -0700 Subject: [PATCH 08/31] feat(ltx2): Ltx2Engine chain stage renderer with pre-VAE latent tail capture Add a pre-VAE-decode tail-capture slot on Ltx2RuntimeSession threaded into render_real_distilled_av, implement Ltx2Engine::render_chain_stage that injects a carryover ChainTail as a StagedLatent and extracts the post-denoise tail, and wire it through impl ChainStageRenderer for Ltx2Engine. Distilled-only in v1; other pipeline families error up-front. Amend ChainStageRenderer::render_stage to carry motion_tail_pixel_frames so the engine knows how many frames to narrow off the emitted latents. Part of render-chain v1 (Phase 1d). Weight-free tests added; full mold-inference and mold-core lib test suites stay green. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-inference/src/ltx2/chain.rs | 21 ++- crates/mold-inference/src/ltx2/pipeline.rs | 191 ++++++++++++++++++++- crates/mold-inference/src/ltx2/runtime.rs | 36 +++- 3 files changed, 239 insertions(+), 9 deletions(-) diff --git a/crates/mold-inference/src/ltx2/chain.rs b/crates/mold-inference/src/ltx2/chain.rs index dfcb08df..d40c95de 100644 --- a/crates/mold-inference/src/ltx2/chain.rs +++ b/crates/mold-inference/src/ltx2/chain.rs @@ -136,6 +136,7 @@ pub trait ChainStageRenderer { &mut self, stage_req: &GenerateRequest, carry: Option<&ChainTail>, + motion_tail_pixel_frames: u32, stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, ) -> Result; } @@ -226,12 +227,19 @@ impl<'a, R: ChainStageRenderer> Ltx2ChainOrchestrator<'a, R> { }); } }; - self.renderer - .render_stage(&stage_req, carry.as_ref(), Some(&mut wrapping))? + self.renderer.render_stage( + &stage_req, + carry.as_ref(), + req.motion_tail_frames, + Some(&mut wrapping), + )? 
} - None => self - .renderer - .render_stage(&stage_req, carry.as_ref(), None)?, + None => self.renderer.render_stage( + &stage_req, + carry.as_ref(), + req.motion_tail_frames, + None, + )?, }; let mut frames = outcome.frames; @@ -257,7 +265,7 @@ impl<'a, R: ChainStageRenderer> Ltx2ChainOrchestrator<'a, R> { } } - if let Some(cb) = chain_progress.as_deref_mut() { + if let Some(cb) = chain_progress.as_mut() { cb(ChainProgressEvent::Stitching { total_frames: accumulated_frames.len() as u32, }); @@ -501,6 +509,7 @@ mod tests { &mut self, stage_req: &GenerateRequest, carry: Option<&ChainTail>, + _motion_tail_pixel_frames: u32, mut stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, ) -> Result { let idx = self.calls.len(); diff --git a/crates/mold-inference/src/ltx2/pipeline.rs b/crates/mold-inference/src/ltx2/pipeline.rs index 1f584d14..7f419d4d 100644 --- a/crates/mold-inference/src/ltx2/pipeline.rs +++ b/crates/mold-inference/src/ltx2/pipeline.rs @@ -1,6 +1,6 @@ #![allow(clippy::type_complexity)] -use anyhow::{bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use candle_core::Device; use mold_core::{ GenerateRequest, GenerateResponse, Ltx2PipelineMode, ModelPaths, OutputFormat, VideoData, @@ -11,7 +11,10 @@ use std::time::Instant; use super::assets; use super::backend::Ltx2Backend; -use super::conditioning; +use super::chain::{ + extract_tail_latents, ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent, +}; +use super::conditioning::{self, StagedLatent}; use super::execution; use super::lora; use super::media::{self, ProbeMetadata}; @@ -522,6 +525,147 @@ impl Ltx2Engine { gpu: None, }) } + + /// Render a single chain stage, optionally conditioning on a carryover + /// tail from the prior stage. + /// + /// `motion_tail_pixel_frames` is the number of pixel frames to narrow + /// off the emitted latents for the *next* stage's carryover. 
`0` + /// returns an error (nonsensical — use the regular single-clip path + /// if no tail is wanted). + /// + /// Scope: distilled LTX-2 pipeline only. Other pipeline families + /// return an error up-front so the chain orchestrator fails fast. + pub(crate) fn render_chain_stage( + &mut self, + req: &GenerateRequest, + carry: Option<&ChainTail>, + motion_tail_pixel_frames: u32, + ) -> Result { + if motion_tail_pixel_frames == 0 { + bail!("render_chain_stage: motion_tail_pixel_frames must be > 0"); + } + if !self.loaded { + self.load()?; + } + let start = Instant::now(); + self.emit("Preparing native LTX-2 chain stage"); + + let pipeline = self.select_pipeline(req)?; + if !matches!(pipeline, PipelineKind::Distilled) { + bail!( + "render-chain v1 only supports the distilled LTX-2 pipeline, got {:?}", + pipeline, + ); + } + + let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?; + let native_output = work_dir.path().join("ltx2-native-output.mp4"); + let mut plan = self.materialize_request(req, work_dir.path(), &native_output)?; + + // Inject carryover tail latents as StagedLatent on frame 0. The + // runtime detects a non-empty `conditioning.latents` and bypasses + // the VAE load entirely, patchifying the pre-encoded tokens into + // conditioning replacements directly (see conditioning.rs + // StagedLatent docstring + runtime.rs + // maybe_load_stage_video_conditioning). + if let Some(tail) = carry { + // The caller (orchestrator) is responsible for blanking + // source_image on continuation stages, but defence-in-depth: + // clear staged images so they can't compete with the latent + // carryover. + plan.conditioning.images.clear(); + plan.conditioning.latents.push(StagedLatent { + latents: tail.latents.clone(), + frame: 0, + strength: 1.0, + }); + } + + // Reuse an existing runtime session if we have one; otherwise + // build one. Arm the tail-capture slot on the session before + // render. 
+ let mut runtime = match self.native_runtime.take() { + Some(runtime) => runtime, + None => self.create_runtime_session(&plan)?, + }; + let slot = runtime.arm_tail_capture(); + + self.emit("Executing native LTX-2 chain stage runtime"); + let prepared = match runtime.prepare(&plan) { + Ok(prepared) => prepared, + Err(err) => { + runtime.clear_tail_capture(); + self.native_runtime = Some(runtime); + return Err(err); + } + }; + let render_result = + runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref()); + runtime.clear_tail_capture(); + self.native_runtime = Some(runtime); + let rendered = render_result?; + + // Drain captured latents. The slot must have been populated by + // the distilled render path — if it's empty, that's a wiring bug, + // not a user error. + let captured = slot + .lock() + .map_err(|_| anyhow!("chain tail-capture mutex was poisoned mid-render"))? + .take() + .ok_or_else(|| { + anyhow!( + "distilled render completed without populating the chain tail-capture slot; \ + this is a pipeline wiring bug" + ) + })?; + + // `extract_tail_latents` returns a narrow view; make it + // contiguous so it survives independently of the runtime's + // working tensors. + let tail_slice = extract_tail_latents(&captured, motion_tail_pixel_frames)?; + let tail_latents = tail_slice + .contiguous() + .context("materializing chain tail latents into an owned tensor")?; + + let frames = rendered.frames; + let last_rgb_frame = frames + .last() + .ok_or_else(|| anyhow!("distilled render returned zero frames"))? 
+ .clone(); + + let generation_time_ms = start.elapsed().as_millis() as u64; + Self::log_timing("pipeline.render_chain_stage", start); + + Ok(StageOutcome { + frames, + tail: ChainTail { + frames: motion_tail_pixel_frames, + latents: tail_latents, + last_rgb_frame, + }, + generation_time_ms, + }) + } +} + +impl ChainStageRenderer for Ltx2Engine { + fn render_stage( + &mut self, + stage_req: &GenerateRequest, + carry: Option<&ChainTail>, + motion_tail_pixel_frames: u32, + _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, + ) -> Result { + // `_stage_progress` is intentionally unused in v1: per-stage + // denoise events flow through `self.on_progress` already. Phase 2's + // server route will install an on_progress callback that forwards + // those events onto the chain SSE stream with `stage_idx` tagged + // in. If the orchestrator later needs denoise-step events routed + // through its own channel, we can plumb `stage_progress` into a + // temporary ProgressCallback wrapper here. + self.render_chain_stage(stage_req, carry, motion_tail_pixel_frames) + } } impl InferenceEngine for Ltx2Engine { @@ -1087,4 +1231,47 @@ mod tests { assert!(!video.has_audio); assert!(engine.native_runtime.is_none()); } + + #[test] + fn render_chain_stage_rejects_non_distilled_pipeline() { + // A model name without "distilled" in it selects `PipelineKind::TwoStage` + // via `select_pipeline`, which must be rejected up-front by the chain + // entry point before any runtime work happens. 
+ let mut engine = Ltx2Engine::with_runtime_session( + "ltx-2-19b:fp8".to_string(), + dummy_paths(), + runtime_session(), + ); + engine.loaded = true; + let req = request(OutputFormat::Mp4, Some(false)); + let err = engine + .render_chain_stage(&req, None, 4) + .expect_err("must fail on non-distilled pipeline"); + let msg = format!("{err}"); + assert!( + msg.contains("distilled"), + "error must name the pipeline constraint, got: {msg}", + ); + } + + #[test] + fn render_chain_stage_rejects_zero_motion_tail() { + // Zero-frame motion tail is nonsensical — it would narrow nothing off + // for the next stage. Fast-fail before any allocation. + let mut engine = Ltx2Engine::with_runtime_session( + "ltx-2-19b-distilled:fp8".to_string(), + dummy_paths(), + runtime_session(), + ); + engine.loaded = true; + let req = request(OutputFormat::Mp4, Some(false)); + let err = engine + .render_chain_stage(&req, None, 0) + .expect_err("must fail on zero motion tail"); + let msg = format!("{err}"); + assert!( + msg.contains("motion_tail_pixel_frames"), + "error must name the motion_tail constraint, got: {msg}", + ); + } } diff --git a/crates/mold-inference/src/ltx2/runtime.rs b/crates/mold-inference/src/ltx2/runtime.rs index 8df1f26c..5fc0a62b 100644 --- a/crates/mold-inference/src/ltx2/runtime.rs +++ b/crates/mold-inference/src/ltx2/runtime.rs @@ -291,6 +291,11 @@ impl Ltx2VaeLatentStats { pub struct Ltx2RuntimeSession { device: Option, prompt_encoder: Option, + /// Optional slot wired into `render_real_distilled_av` so + /// `Ltx2Engine::render_chain_stage` can snapshot the pre-VAE-decode + /// final latents and forward them to the next chain stage as a + /// [`super::chain::ChainTail`]. `None` outside chain flow. 
+ pub(crate) tail_capture: Option>>>, } impl Ltx2RuntimeSession { @@ -298,6 +303,7 @@ impl Ltx2RuntimeSession { Self { device: Some(device), prompt_encoder: Some(prompt_encoder), + tail_capture: None, } } @@ -305,9 +311,20 @@ impl Ltx2RuntimeSession { Self { device: None, prompt_encoder: Some(prompt_encoder), + tail_capture: None, } } + pub(crate) fn arm_tail_capture(&mut self) -> std::sync::Arc>> { + let slot = std::sync::Arc::new(std::sync::Mutex::new(None)); + self.tail_capture = Some(std::sync::Arc::clone(&slot)); + slot + } + + pub(crate) fn clear_tail_capture(&mut self) { + self.tail_capture = None; + } + pub fn prepare(&mut self, plan: &Ltx2GeneratePlan) -> Result { let prepare_total_start = Instant::now(); let mut stage1_shape = derive_stage1_render_shape( @@ -597,7 +614,13 @@ impl Ltx2RuntimeSession { return Ok(None); } let render = match plan.pipeline { - PipelineKind::Distilled => render_real_distilled_av(plan, prepared, device, progress), + PipelineKind::Distilled => render_real_distilled_av( + plan, + prepared, + device, + progress, + self.tail_capture.as_ref(), + ), PipelineKind::OneStage => render_real_one_stage_av(plan, prepared, device, progress), PipelineKind::TwoStage | PipelineKind::TwoStageHq @@ -1724,6 +1747,7 @@ fn render_real_distilled_av( prepared: &NativePreparedRun, device: &candle_core::Device, progress: Option<&ProgressCallback>, + tail_capture: Option<&std::sync::Arc>>>, ) -> Result { let debug_enabled = ltx_debug_enabled(); let prompt_inputs = prepare_render_prompt_inputs( @@ -2008,6 +2032,16 @@ fn render_real_distilled_av( vae.use_tiling = false; vae.use_framewise_decoding = false; let decode_start = Instant::now(); + // Chain-stage hook: capture the pre-decode F32 latents so + // `Ltx2Engine::render_chain_stage` can narrow the tail off for the next + // stage's conditioning. Cheap shallow clone (candle tensors are + // Arc-backed). 
A poisoned mutex is ignored here — the outer caller + // detects an empty slot and emits a clear error. + if let Some(slot) = tail_capture { + if let Ok(mut guard) = slot.lock() { + *guard = Some(latents.clone()); + } + } let (_dec_output, video) = vae.decode(&latents.to_dtype(dtype)?, None, false, false)?; if debug_enabled { log_tensor_stats("decoded_video", &video)?; From 548f2fc82d80de423575c4aed5b202747ec013ab Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 18:35:20 -0700 Subject: [PATCH 09/31] feat(server): chain render endpoint with SSE streaming Add POST /api/generate/chain and POST /api/generate/chain/stream for server-side chained LTX-2 video generation. Handler take/restores the engine out of the model cache and runs the full chain in a spawn_blocking so the sync orchestrator never blocks the async runtime. Drives Ltx2ChainOrchestrator through the engine's ChainStageRenderer view, trims accumulated frames to target total from the tail per sign-off, encodes the stitched output (MP4 when the mp4 feature is on, APNG fallback otherwise), and saves to the gallery with a synthesised OutputMetadata. Expose as_chain_renderer() on InferenceEngine (default None), overridden by Ltx2Engine. Relax Ltx2ChainOrchestrator's renderer bound to ?Sized so trait objects compose cleanly. Promote ltx_video::video_enc from pub(crate) to pub so mold-server can reuse encode_mp4/encode_apng/ encode_gif/first_frame_png for chain stitching. Weight-free route tests cover the happy path, the mid-chain failure (502 Bad Gateway), the unsupported-model rejection (422), progress event ordering through the SSE helper, and tail-trim behaviour. Part of render-chain v1 (Phase 2). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- Cargo.lock | 1 + crates/mold-cli/Cargo.toml | 2 +- crates/mold-inference/src/engine.rs | 11 + crates/mold-inference/src/ltx2/chain.rs | 4 +- crates/mold-inference/src/ltx2/pipeline.rs | 4 + crates/mold-inference/src/ltx_video/mod.rs | 5 +- crates/mold-server/Cargo.toml | 5 + crates/mold-server/src/lib.rs | 1 + crates/mold-server/src/routes.rs | 27 +- crates/mold-server/src/routes_chain.rs | 788 +++++++++++++++++++++ 10 files changed, 843 insertions(+), 5 deletions(-) create mode 100644 crates/mold-server/src/routes_chain.rs diff --git a/Cargo.lock b/Cargo.lock index c950b94e..4a76c58b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,6 +3104,7 @@ dependencies = [ "async-trait", "axum", "base64 0.22.1", + "candle-core-mold", "clap", "dirs 5.0.1", "futures", diff --git a/crates/mold-cli/Cargo.toml b/crates/mold-cli/Cargo.toml index a7734a87..07a71f88 100644 --- a/crates/mold-cli/Cargo.toml +++ b/crates/mold-cli/Cargo.toml @@ -22,7 +22,7 @@ discord = ["mold-discord"] expand = ["mold-inference/expand", "mold-server/expand", "mold-tui?/expand"] tui = ["dep:mold-tui"] webp = ["mold-inference/webp"] -mp4 = ["mold-inference/mp4"] +mp4 = ["mold-inference/mp4", "mold-server/mp4"] metrics = ["mold-server/metrics"] [dependencies] diff --git a/crates/mold-inference/src/engine.rs b/crates/mold-inference/src/engine.rs index 949f5089..cddafc7a 100644 --- a/crates/mold-inference/src/engine.rs +++ b/crates/mold-inference/src/engine.rs @@ -35,6 +35,17 @@ pub trait InferenceEngine: Send + Sync { fn model_paths(&self) -> Option<&mold_core::ModelPaths> { None } + + /// Returns a [`ChainStageRenderer`] view of this engine if the family + /// supports chained video generation. Default is `None` — only LTX-2 + /// distilled overrides this in v1. 
+ /// + /// Callers (the server chain route) invoke this once per stage to drive + /// [`crate::ltx2::Ltx2ChainOrchestrator::run`]; engines that don't support + /// chaining return `None` and the caller responds with 422. + fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> { + None + } } /// Restores an `Option` slot even if the current scope unwinds. diff --git a/crates/mold-inference/src/ltx2/chain.rs b/crates/mold-inference/src/ltx2/chain.rs index d40c95de..81eb2b47 100644 --- a/crates/mold-inference/src/ltx2/chain.rs +++ b/crates/mold-inference/src/ltx2/chain.rs @@ -159,11 +159,11 @@ pub struct ChainRunOutput { /// Drives the per-stage render loop for a chained generation. Borrows its /// renderer mutably so the loop can re-enter the engine on the same GPU /// context across stages. -pub struct Ltx2ChainOrchestrator<'a, R: ChainStageRenderer> { +pub struct Ltx2ChainOrchestrator<'a, R: ChainStageRenderer + ?Sized> { renderer: &'a mut R, } -impl<'a, R: ChainStageRenderer> Ltx2ChainOrchestrator<'a, R> { +impl<'a, R: ChainStageRenderer + ?Sized> Ltx2ChainOrchestrator<'a, R> { pub fn new(renderer: &'a mut R) -> Self { Self { renderer } } diff --git a/crates/mold-inference/src/ltx2/pipeline.rs b/crates/mold-inference/src/ltx2/pipeline.rs index 7f419d4d..9c744d8b 100644 --- a/crates/mold-inference/src/ltx2/pipeline.rs +++ b/crates/mold-inference/src/ltx2/pipeline.rs @@ -720,6 +720,10 @@ impl InferenceEngine for Ltx2Engine { fn model_paths(&self) -> Option<&ModelPaths> { Some(&self.paths) } + + fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> { + Some(self) + } } #[cfg(test)] diff --git a/crates/mold-inference/src/ltx_video/mod.rs b/crates/mold-inference/src/ltx_video/mod.rs index 4e37627f..aa01b614 100644 --- a/crates/mold-inference/src/ltx_video/mod.rs +++ b/crates/mold-inference/src/ltx_video/mod.rs @@ -1,5 +1,8 @@ pub(crate) mod latent_upsampler; mod pipeline; -pub(crate) mod video_enc; +// Video 
encoding helpers (GIF/APNG/WebP/MP4 + thumbnail) are used by +// chain stitching in `mold-server`, so the module is public rather than +// crate-private. +pub mod video_enc; pub use pipeline::LtxVideoEngine; diff --git a/crates/mold-server/Cargo.toml b/crates/mold-server/Cargo.toml index 28e76c4b..34e6645f 100644 --- a/crates/mold-server/Cargo.toml +++ b/crates/mold-server/Cargo.toml @@ -25,6 +25,7 @@ default = [] cuda = ["mold-inference/cuda"] metal = ["mold-inference/metal"] expand = ["mold-inference/expand"] +mp4 = ["mold-inference/mp4"] metrics = ["dep:metrics", "dep:metrics-exporter-prometheus"] nvml = ["dep:nvml-wrapper"] @@ -72,3 +73,7 @@ async-stream = "0.3" [dev-dependencies] tempfile = "3" tokio = { version = "1", features = ["full", "test-util"] } +# Chain route tests build a synthetic motion-tail Tensor via the same +# candle APIs the inference crate uses — keep this in lockstep with +# mold-inference's pinned candle-core-mold version. +candle-core = { package = "candle-core-mold", version = "0.9.10" } diff --git a/crates/mold-server/src/lib.rs b/crates/mold-server/src/lib.rs index 97ea6f42..62ef01ca 100644 --- a/crates/mold-server/src/lib.rs +++ b/crates/mold-server/src/lib.rs @@ -13,6 +13,7 @@ pub mod rate_limit; pub mod request_id; pub mod resources; pub mod routes; +pub mod routes_chain; pub mod state; pub mod web_ui; diff --git a/crates/mold-server/src/routes.rs b/crates/mold-server/src/routes.rs index 4ea11cd8..2d6ffc3c 100644 --- a/crates/mold-server/src/routes.rs +++ b/crates/mold-server/src/routes.rs @@ -133,7 +133,19 @@ use crate::queue::clean_error_message; #[derive(OpenApi)] #[openapi( - paths(generate, generate_stream, expand_prompt, list_models, load_model, pull_model_endpoint, unload_model, server_status, health), + paths( + generate, + generate_stream, + expand_prompt, + list_models, + load_model, + pull_model_endpoint, + unload_model, + server_status, + health, + crate::routes_chain::generate_chain, + 
crate::routes_chain::generate_chain_stream, + ), components(schemas( mold_core::GenerateRequest, mold_core::GenerateResponse, @@ -148,6 +160,11 @@ use crate::queue::clean_error_message; mold_core::SseProgressEvent, mold_core::SseCompleteEvent, mold_core::SseErrorEvent, + mold_core::ChainRequest, + mold_core::ChainResponse, + mold_core::ChainStage, + mold_core::ChainProgressEvent, + mold_core::SseChainCompleteEvent, ModelInfoExtended, LoadModelBody, UnloadRequest, @@ -171,6 +188,14 @@ pub fn create_router(state: AppState) -> Router { Router::new() .route("/api/generate", post(generate)) .route("/api/generate/stream", post(generate_stream)) + .route( + "/api/generate/chain", + post(crate::routes_chain::generate_chain), + ) + .route( + "/api/generate/chain/stream", + post(crate::routes_chain::generate_chain_stream), + ) .route("/api/expand", post(expand_prompt)) .route("/api/models", get(list_models)) .route("/api/models/load", post(load_model)) diff --git a/crates/mold-server/src/routes_chain.rs b/crates/mold-server/src/routes_chain.rs new file mode 100644 index 00000000..c8e4ef68 --- /dev/null +++ b/crates/mold-server/src/routes_chain.rs @@ -0,0 +1,788 @@ +//! Server-side chained video generation endpoints. +//! +//! Exposes `POST /api/generate/chain` (synchronous) and +//! `POST /api/generate/chain/stream` (SSE). Both drive +//! [`mold_inference::ltx2::Ltx2ChainOrchestrator`] through an engine's +//! [`mold_inference::ltx2::ChainStageRenderer`] view. +//! +//! Unlike the single-shot generate path (which queues through +//! [`crate::state::QueueHandle`] to keep small GPU jobs FIFO-fair), chains +//! are multi-minute compound jobs — the handler take/restores the engine +//! out of the model cache and runs the full sequence in a +//! [`tokio::task::spawn_blocking`] so the sync orchestrator never blocks +//! the async runtime. While the chain is running the engine is removed +//! from the cache, so concurrent generate/chain requests for the same +//! model cannot race. 
+ +use std::convert::Infallible; + +use axum::{ + extract::State, + response::sse::{Event as SseEvent, KeepAlive, Sse}, + Json, +}; +use base64::Engine as _; +use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainResponse, SseChainCompleteEvent}; +use mold_core::{OutputFormat, OutputMetadata, VideoData}; +use tokio_stream::StreamExt as _; + +use crate::model_cache::CachedEngine; +use crate::model_manager; +use crate::queue::save_video_to_dir; +use crate::routes::ApiError; +use crate::state::AppState; + +/// Internal wire event used by the chain SSE stream before per-event +/// serialization. Separate from [`crate::state::SseMessage`] because chain +/// complete events carry a different payload (`SseChainCompleteEvent`) and +/// progress events are chain-shaped (`ChainProgressEvent`) rather than the +/// single-stage `SseProgressEvent`. +pub(crate) enum ChainSseMessage { + Progress(ChainProgressEvent), + Complete(SseChainCompleteEvent), + Error(String), +} + +fn chain_sse_event(msg: ChainSseMessage) -> SseEvent { + match msg { + ChainSseMessage::Progress(ev) => match serde_json::to_string(&ev) { + Ok(data) => SseEvent::default().event("progress").data(data), + Err(e) => SseEvent::default() + .event("error") + .data(format!(r#"{{"message":"serialize progress: {e}"}}"#)), + }, + ChainSseMessage::Complete(ev) => match serde_json::to_string(&ev) { + Ok(data) => SseEvent::default().event("complete").data(data), + Err(e) => SseEvent::default() + .event("error") + .data(format!(r#"{{"message":"serialize complete: {e}"}}"#)), + }, + ChainSseMessage::Error(message) => SseEvent::default() + .event("error") + .data(serde_json::json!({ "message": message }).to_string()), + } +} + +/// Encode chain frames into bytes for the requested output format. Returns +/// the encoded payload plus a best-effort animated-GIF preview for the +/// gallery. 
+/// +/// MP4 is gated behind the `mp4` feature flag; when the flag is disabled, +/// the handler falls back to APNG so the endpoint still produces a usable +/// animation on every build. +fn encode_chain_output( + frames: &[image::RgbImage], + fps: u32, + format: OutputFormat, +) -> anyhow::Result<(Vec, OutputFormat, Vec)> { + use mold_inference::ltx_video::video_enc; + + // Always produce a GIF preview for the gallery UI. Non-fatal. + let gif_preview = match video_enc::encode_gif(frames, fps) { + Ok(b) => b, + Err(e) => { + tracing::warn!("chain gif preview encode failed: {e:#}"); + Vec::new() + } + }; + + let (bytes, actual_format) = match format { + OutputFormat::Mp4 => { + #[cfg(feature = "mp4")] + { + (video_enc::encode_mp4(frames, fps)?, OutputFormat::Mp4) + } + #[cfg(not(feature = "mp4"))] + { + tracing::warn!( + "chain requested MP4 but server was built without the `mp4` feature — \ + falling back to APNG" + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + } + OutputFormat::Apng => ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ), + OutputFormat::Gif => (video_enc::encode_gif(frames, fps)?, OutputFormat::Gif), + // WebP is always available here because mold-inference's webp + // feature would need to gate at the transitive-dep level; for the + // chain route v1 we fall back to APNG when WebP is requested so + // we don't bind the server crate to another optional dep. + OutputFormat::Webp => { + tracing::warn!( + "chain WebP output is not supported on the server yet — falling back to APNG" + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + other => anyhow::bail!("{other:?} is not a video output format for chain generation"), + }; + + Ok((bytes, actual_format, gif_preview)) +} + +/// Build the `OutputMetadata` for a stitched chain output. 
Pulls chain- +/// level parameters (dimensions, seed, steps) from `req` and the prompt / +/// negative prompt from `stages[0]`. +fn chain_output_metadata(req: &ChainRequest, frame_count: u32) -> OutputMetadata { + let first_stage = req.stages.first(); + OutputMetadata { + prompt: first_stage.map(|s| s.prompt.clone()).unwrap_or_default(), + negative_prompt: first_stage.and_then(|s| s.negative_prompt.clone()), + original_prompt: None, + model: req.model.clone(), + seed: req.seed.unwrap_or(0), + steps: req.steps, + guidance: req.guidance, + width: req.width, + height: req.height, + strength: Some(req.strength), + scheduler: None, + lora: None, + lora_scale: None, + frames: Some(frame_count), + fps: Some(req.fps), + version: mold_core::build_info::version_string().to_string(), + } +} + +/// Trim a frame buffer to the caller's requested total frame count, per +/// the signed-off "trim from tail" decision (2026-04-20). The orchestrator +/// always over-produces to hit or exceed `total_frames`; trimming here +/// keeps the output length deterministic without altering per-stage +/// denoise behaviour. +fn trim_to_total_frames(frames: &mut Vec, total_frames: Option) { + if let Some(target) = total_frames { + let target = target as usize; + if frames.len() > target { + frames.truncate(target); + } + } +} + +/// Produce a PNG thumbnail for the chain output — best-effort, returns +/// an empty `Vec` on failure so the save/response paths still succeed. +fn chain_thumbnail(frames: &[image::RgbImage]) -> Vec { + match mold_inference::ltx_video::video_enc::first_frame_png(frames) { + Ok(b) => b, + Err(e) => { + tracing::warn!("chain thumbnail encode failed: {e:#}"); + Vec::new() + } + } +} + +/// Build a `VideoData` for the `ChainResponse` body. 
+fn build_video_data( + bytes: Vec, + format: OutputFormat, + req: &ChainRequest, + frame_count: u32, + thumbnail: Vec, + gif_preview: Vec, +) -> VideoData { + let duration_ms = if req.fps == 0 { + None + } else { + Some((frame_count as u64 * 1000) / req.fps as u64) + }; + VideoData { + data: bytes, + format, + width: req.width, + height: req.height, + frames: frame_count, + fps: req.fps, + thumbnail, + gif_preview, + has_audio: false, + duration_ms, + audio_sample_rate: None, + audio_channels: None, + } +} + +/// Build the SSE `complete` payload for a finished chain run. Sibling of +/// [`crate::queue::build_sse_complete_event`] — kept in this module so the +/// chain-specific payload can evolve independently from the single-shot +/// one. +fn build_sse_chain_complete_event( + resp: &ChainResponse, + generation_time_ms: u64, +) -> SseChainCompleteEvent { + let b64 = base64::engine::general_purpose::STANDARD; + let video = &resp.video; + SseChainCompleteEvent { + video: b64.encode(&video.data), + format: video.format, + width: video.width, + height: video.height, + frames: video.frames, + fps: video.fps, + thumbnail: if video.thumbnail.is_empty() { + None + } else { + Some(b64.encode(&video.thumbnail)) + }, + gif_preview: if video.gif_preview.is_empty() { + None + } else { + Some(b64.encode(&video.gif_preview)) + }, + has_audio: video.has_audio, + duration_ms: video.duration_ms, + audio_sample_rate: video.audio_sample_rate, + audio_channels: video.audio_channels, + stage_count: resp.stage_count, + gpu: resp.gpu, + generation_time_ms: Some(generation_time_ms), + } +} + +/// Errors surfaced from the chain-run helper. Mapped to appropriate HTTP +/// status codes by the route handlers. +#[derive(Debug)] +enum ChainRunError { + /// Model family doesn't support chain rendering (422). + UnsupportedModel(String), + /// Engine missing from cache after `ensure_model_ready` (500). + CacheMiss(String), + /// Orchestrator returned an error mid-chain (502). 
+ Inference(String), + /// Output encoding / stitch failure (500). + Encode(String), + /// Task panic or join error (500). + Internal(String), +} + +impl From for ApiError { + fn from(err: ChainRunError) -> Self { + match err { + ChainRunError::UnsupportedModel(msg) => ApiError::validation(msg), + ChainRunError::CacheMiss(msg) => ApiError::internal(msg), + ChainRunError::Inference(msg) => { + ApiError::internal_with_status(msg, axum::http::StatusCode::BAD_GATEWAY) + } + ChainRunError::Encode(msg) => ApiError::internal(msg), + ChainRunError::Internal(msg) => ApiError::internal(msg), + } + } +} + +/// Drive the chain to completion. Shared between the non-streaming and SSE +/// paths — the only caller-provided variable is `progress_cb`, which is +/// `None` for the plain JSON endpoint and `Some` for the SSE endpoint. +async fn run_chain( + state: &AppState, + req: ChainRequest, + progress_cb: Option>, +) -> Result<(ChainResponse, u64), ChainRunError> { + // Ensure the model is loaded. Progress forwarding is not plumbed yet — + // load-time events go through the model manager's own tracing. Chain + // stage events (StageStart/DenoiseStep/StageDone/Stitching) come from + // the orchestrator during the blocking task below. + model_manager::ensure_model_ready(state, &req.model, None) + .await + .map_err(|e| ChainRunError::CacheMiss(e.error))?; + + // Take the engine out of the cache so the blocking orchestrator run + // owns it for the full multi-minute chain without holding the async + // mutex guard across an await. Restore when we're done (or on error). 
+ let mut cache = state.model_cache.lock().await; + let cached: CachedEngine = cache.take(&req.model).ok_or_else(|| { + ChainRunError::CacheMiss(format!( + "engine '{}' vanished from cache after ensure_model_ready", + req.model + )) + })?; + drop(cache); + + let req_for_task = req.clone(); + let join_handle = tokio::task::spawn_blocking(move || { + let mut cached = cached; + let mut progress_cb = progress_cb; + let outcome = { + let engine = &mut cached.engine; + match engine.as_chain_renderer() { + Some(renderer) => { + let mut orch = mold_inference::ltx2::Ltx2ChainOrchestrator::new(renderer); + // The orchestrator expects `Option<&mut dyn FnMut(...)>` + // — synthesise that from the optional boxed callback we + // moved into this task. + let result = if let Some(cb) = progress_cb.as_deref_mut() { + orch.run(&req_for_task, Some(cb)) + } else { + orch.run(&req_for_task, None) + }; + result.map_err(|e| ChainRunError::Inference(format!("{e:#}"))) + } + None => Err(ChainRunError::UnsupportedModel(format!( + "model '{}' does not support chained video generation", + req_for_task.model + ))), + } + }; + (cached, outcome) + }); + + let (cached, outcome) = match join_handle.await { + Ok(pair) => pair, + Err(join_err) => { + return Err(ChainRunError::Internal(format!( + "chain orchestrator task failed: {join_err}" + ))); + } + }; + + // Restore the engine to the cache regardless of success/failure so the + // next request can reuse it. 
+ { + let mut cache = state.model_cache.lock().await; + cache.restore(cached); + } + + let chain_output = outcome?; + let stage_count = chain_output.stage_count; + let generation_time_ms = chain_output.generation_time_ms; + let mut frames = chain_output.frames; + trim_to_total_frames(&mut frames, req.total_frames); + + if frames.is_empty() { + return Err(ChainRunError::Encode( + "chain run emitted zero frames after trim".to_string(), + )); + } + + let (bytes, output_format, gif_preview) = + encode_chain_output(&frames, req.fps, req.output_format) + .map_err(|e| ChainRunError::Encode(format!("encode chain output: {e:#}")))?; + let thumbnail = chain_thumbnail(&frames); + let frame_count = frames.len() as u32; + + // Save to the gallery directory (best-effort, non-blocking). + let output_dir = { + let config = state.config.read().await; + if config.is_output_disabled() { + None + } else { + Some(config.effective_output_dir()) + } + }; + if let Some(dir) = output_dir { + let metadata = chain_output_metadata(&req, frame_count); + let bytes_clone = bytes.clone(); + let gif_clone = gif_preview.clone(); + let model = req.model.clone(); + let db = state.metadata_db.clone(); + tokio::task::spawn_blocking(move || { + save_video_to_dir( + &dir, + &bytes_clone, + &gif_clone, + output_format, + &model, + &metadata, + Some(generation_time_ms as i64), + db.as_ref().as_ref(), + ); + }); + } + + let video = build_video_data( + bytes, + output_format, + &req, + frame_count, + thumbnail, + gif_preview, + ); + let response = ChainResponse { + video, + stage_count, + gpu: None, + }; + Ok((response, generation_time_ms)) +} + +/// `POST /api/generate/chain` — synchronous chained video generation. 
+#[utoipa::path( + post, + path = "/api/generate/chain", + tag = "generation", + request_body = mold_core::ChainRequest, + responses( + (status = 200, description = "Stitched chain video", body = mold_core::ChainResponse), + (status = 422, description = "Invalid request or unsupported model"), + (status = 500, description = "Chain render failed"), + (status = 502, description = "Chain render failed mid-stage"), + ) +)] +pub async fn generate_chain( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let req = req + .normalise() + .map_err(|e| ApiError::validation(e.to_string()))?; + + tracing::info!( + model = %req.model, + stages = req.stages.len(), + width = req.width, + height = req.height, + fps = req.fps, + "generate/chain request" + ); + + let (response, _elapsed_ms) = run_chain(&state, req, None).await?; + Ok(Json(response)) +} + +/// `POST /api/generate/chain/stream` — SSE-streamed chain generation. Emits +/// [`ChainProgressEvent`]s as `event: progress` frames while the chain +/// runs, and a single `event: complete` frame with a [`SseChainCompleteEvent`] +/// payload when the stitched output is ready. Mid-chain failure closes the +/// stream with an `event: error` frame carrying the orchestrator message. 
+#[utoipa::path( + post, + path = "/api/generate/chain/stream", + tag = "generation", + request_body = mold_core::ChainRequest, + responses( + (status = 200, description = "SSE event stream with chain progress and completion"), + (status = 422, description = "Invalid request or unsupported model"), + (status = 500, description = "Chain render failed"), + ) +)] +pub async fn generate_chain_stream( + State(state): State, + Json(req): Json, +) -> Result>>, ApiError> { + let req = req + .normalise() + .map_err(|e| ApiError::validation(e.to_string()))?; + + tracing::info!( + model = %req.model, + stages = req.stages.len(), + width = req.width, + height = req.height, + fps = req.fps, + "generate/chain/stream request" + ); + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let state_clone = state.clone(); + let tx_for_task = tx.clone(); + + tokio::spawn(async move { + let tx_for_cb = tx_for_task.clone(); + let cb: Box = Box::new(move |event| { + let _ = tx_for_cb.send(ChainSseMessage::Progress(event)); + }); + match run_chain(&state_clone, req, Some(cb)).await { + Ok((response, elapsed_ms)) => { + let complete = build_sse_chain_complete_event(&response, elapsed_ms); + let _ = tx_for_task.send(ChainSseMessage::Complete(complete)); + } + Err(err) => { + let api_err: ApiError = err.into(); + let _ = tx_for_task.send(ChainSseMessage::Error(api_err.error)); + } + } + // `tx_for_task` is dropped here, closing the channel and finalizing + // the SSE stream after the last complete/error frame. 
+ }); + drop(tx); // ensure only the task holds the sender + + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx) + .map(|msg| Ok::<_, Infallible>(chain_sse_event(msg))); + + Ok(Sse::new(stream).keep_alive( + KeepAlive::new() + .interval(std::time::Duration::from_secs(15)) + .text("ping"), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + use candle_core::{DType, Device, Tensor}; + use image::{Rgb, RgbImage}; + use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage}; + use mold_core::{GenerateRequest, GenerateResponse}; + use mold_inference::ltx2::{ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent}; + use mold_inference::InferenceEngine; + use std::sync::{Arc, Mutex}; + + /// Mock engine that delegates to a simple chain renderer producing + /// deterministic solid-color frames + a zero-valued latent tail. The + /// chain renderer is owned by the engine so `as_chain_renderer` can + /// hand out a `&mut dyn ChainStageRenderer` over it. 
+ struct ChainMockEngine { + loaded: bool, + fail_on_stage: Option, + renderer_calls: Arc>, + } + + impl ChainMockEngine { + fn ready() -> Self { + Self { + loaded: true, + fail_on_stage: None, + renderer_calls: Arc::new(Mutex::new(0)), + } + } + fn failing_at(idx: usize) -> Self { + Self { + loaded: true, + fail_on_stage: Some(idx), + renderer_calls: Arc::new(Mutex::new(0)), + } + } + } + + impl ChainStageRenderer for ChainMockEngine { + fn render_stage( + &mut self, + stage_req: &GenerateRequest, + _carry: Option<&ChainTail>, + _motion_tail_pixel_frames: u32, + _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, + ) -> Result { + let idx = { + let mut calls = self.renderer_calls.lock().unwrap(); + let idx = *calls; + *calls += 1; + idx + }; + if self.fail_on_stage == Some(idx) { + anyhow::bail!("simulated chain failure at stage {idx}"); + } + let frame_count = stage_req.frames.expect("chain stage missing frame count") as usize; + let width = stage_req.width; + let height = stage_req.height; + let mut frames = Vec::with_capacity(frame_count); + for f in 0..frame_count { + let shade = (idx as u8).wrapping_mul(17).wrapping_add(f as u8); + frames.push(RgbImage::from_pixel(width, height, Rgb([shade, 0, 0]))); + } + let last_frame = frames.last().cloned().unwrap(); + let latent = Tensor::zeros( + (1, 128, 1, height as usize / 32, width as usize / 32), + DType::F32, + &Device::Cpu, + )?; + Ok(StageOutcome { + frames, + tail: ChainTail { + frames: 4, + latents: latent, + last_rgb_frame: last_frame, + }, + generation_time_ms: 10, + }) + } + } + + impl InferenceEngine for ChainMockEngine { + fn generate(&mut self, _req: &GenerateRequest) -> Result { + anyhow::bail!("chain mock engine does not support single-shot generate") + } + fn model_name(&self) -> &str { + "ltx-2-19b-distilled:mock" + } + fn is_loaded(&self) -> bool { + self.loaded + } + fn load(&mut self) -> Result<()> { + self.loaded = true; + Ok(()) + } + fn as_chain_renderer( + &mut self, + ) -> 
Option<&mut dyn mold_inference::ltx2::ChainStageRenderer> { + Some(self) + } + } + + /// Build an AppState whose model cache already contains a chain-capable + /// mock engine under the model name the tests pass in their requests. + fn state_with_chain_engine(engine: ChainMockEngine) -> AppState { + AppState::with_engine(engine) + } + + fn chain_req_for_mock(model: &str, stages: u32) -> ChainRequest { + ChainRequest { + model: model.to_string(), + stages: (0..stages) + .map(|_| ChainStage { + prompt: "a cat walking".into(), + frames: 9, + source_image: None, + negative_prompt: None, + seed_offset: None, + }) + .collect(), + motion_tail_frames: 0, // simplifies frame accounting for the mock + width: 64, + height: 64, + fps: 12, + seed: Some(42), + steps: 4, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Apng, // avoid needing the mp4 feature in tests + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } + } + + #[tokio::test] + async fn chain_happy_path_returns_stage_count_and_video() { + let engine = ChainMockEngine::ready(); + let state = state_with_chain_engine(engine); + let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 3); + + let (resp, elapsed_ms) = run_chain(&state, req, None) + .await + .expect("chain run succeeds"); + + assert_eq!(resp.stage_count, 3, "response must report all 3 stages"); + assert_eq!(resp.video.fps, 12); + assert_eq!(resp.video.frames, 9 * 3, "3 stages × 9 frames with tail=0"); + assert_eq!(resp.video.format, OutputFormat::Apng); + assert!(!resp.video.data.is_empty(), "apng bytes written"); + // elapsed_ms is the sum of the mock's reported per-stage time (10ms each). 
+ assert_eq!(elapsed_ms, 30); + } + + #[tokio::test] + async fn chain_stream_emits_progress_then_complete_in_order() { + let engine = ChainMockEngine::ready(); + let state = state_with_chain_engine(engine); + let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 2); + + let collected: Arc>> = Arc::new(Mutex::new(Vec::new())); + let collected_cb = collected.clone(); + let cb: Box = Box::new(move |ev| { + collected_cb.lock().unwrap().push(ev); + }); + let (resp, _) = run_chain(&state, req, Some(cb)) + .await + .expect("chain run succeeds"); + + assert_eq!(resp.stage_count, 2); + let events = collected.lock().unwrap(); + assert!(!events.is_empty(), "progress events must flow"); + assert!( + matches!( + events[0], + ChainProgressEvent::ChainStart { stage_count: 2, .. } + ), + "first event must be ChainStart, got {:?}", + events[0] + ); + assert!( + matches!(events.last().unwrap(), ChainProgressEvent::Stitching { .. }), + "last event must be Stitching, got {:?}", + events.last() + ); + // There must be exactly one StageStart + StageDone per stage. + let stage_starts = events + .iter() + .filter(|e| matches!(e, ChainProgressEvent::StageStart { .. })) + .count(); + let stage_dones = events + .iter() + .filter(|e| matches!(e, ChainProgressEvent::StageDone { .. 
})) + .count(); + assert_eq!(stage_starts, 2); + assert_eq!(stage_dones, 2); + } + + #[tokio::test] + async fn chain_mid_chain_failure_maps_to_bad_gateway() { + let engine = ChainMockEngine::failing_at(1); + let state = state_with_chain_engine(engine); + let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 3); + + let err = run_chain(&state, req, None) + .await + .expect_err("mid-chain failure must bubble up"); + match err { + ChainRunError::Inference(msg) => { + assert!( + msg.contains("simulated chain failure"), + "inference error must carry renderer message, got: {msg}" + ); + } + other => panic!("expected Inference error, got {other:?}"), + } + } + + #[tokio::test] + async fn chain_unsupported_model_rejects_with_validation() { + /// Engine that is fully capable of single-shot generate but refuses + /// chain rendering (mirrors every non-LTX-2 family). + struct NonChainEngine; + impl InferenceEngine for NonChainEngine { + fn generate(&mut self, _req: &GenerateRequest) -> Result { + anyhow::bail!("no single-shot generate in this test either") + } + fn model_name(&self) -> &str { + "flux-dev:q8" + } + fn is_loaded(&self) -> bool { + true + } + fn load(&mut self) -> Result<()> { + Ok(()) + } + // No override for as_chain_renderer — default returns None. 
+ } + + let state = AppState::with_engine(NonChainEngine); + let mut req = chain_req_for_mock("flux-dev:q8", 2); + req.model = "flux-dev:q8".into(); + let err = run_chain(&state, req, None) + .await + .expect_err("non-chain model must fail"); + match err { + ChainRunError::UnsupportedModel(msg) => { + assert!( + msg.contains("does not support chained video generation"), + "unsupported-model error must name the constraint, got: {msg}" + ); + } + other => panic!("expected UnsupportedModel, got {other:?}"), + } + } + + #[tokio::test] + async fn chain_trims_frames_from_tail_when_total_frames_set() { + let engine = ChainMockEngine::ready(); + let state = state_with_chain_engine(engine); + let mut req = chain_req_for_mock("ltx-2-19b-distilled:mock", 2); + // Each stage produces 9 frames with tail=0 → 18 total. Trim to 10. + req.total_frames = Some(10); + + let (resp, _) = run_chain(&state, req, None).await.expect("chain runs"); + assert_eq!( + resp.video.frames, 10, + "total_frames must trim the stitched output length" + ); + } +} From 6ed6b59a93a3f65a1cc9f2b6440e7d701d05d8c5 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 18:54:43 -0700 Subject: [PATCH 10/31] feat(cli): chain rendering for --frames above clip cap When --frames exceeds the model's per-clip cap (97 for LTX-2 distilled), `mold run` now auto-builds a ChainRequest and routes to POST /api/generate/chain/stream (server mode) or runs the Ltx2ChainOrchestrator in-process (--local mode). New flags --clip-frames and --motion-tail let users tune the per-clip length and the motion-tail overlap (default 4 frames of latent carryover between clips). Stacked progress bars render a parent "Chain" bar (total frames) and a wiping per-stage bar (denoise step / total). Both server and local paths share a single encode+save+preview epilogue so output formatting, stdout piping, and gallery save are identical. 
Models outside LTX-2 distilled families error fast when --frames exceeds the single-clip cap rather than silently dropping frames or hitting the server's chain route with a non-chainable model. A pure `decide_chain_routing` helper captures the branching logic so auto- routing is unit-testable without async or network. Part of render-chain v1 (Phase 3). Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-cli/src/commands/chain.rs | 843 +++++++++++++++++++++++ crates/mold-cli/src/commands/generate.rs | 118 ++++ crates/mold-cli/src/commands/mod.rs | 1 + crates/mold-cli/src/commands/run.rs | 4 + crates/mold-cli/src/main.rs | 64 ++ 5 files changed, 1030 insertions(+) create mode 100644 crates/mold-cli/src/commands/chain.rs diff --git a/crates/mold-cli/src/commands/chain.rs b/crates/mold-cli/src/commands/chain.rs new file mode 100644 index 00000000..a0988b72 --- /dev/null +++ b/crates/mold-cli/src/commands/chain.rs @@ -0,0 +1,843 @@ +//! CLI-side render-chain orchestration for LTX-2 distilled models. +//! +//! When `mold run --frames N` exceeds the per-clip cap of the selected model, +//! this module takes over from [`super::generate::run`]: it assembles a +//! [`ChainRequest`] from the user's CLI args and either submits it to a +//! running server via [`MoldClient::generate_chain_stream`] or, in `--local` +//! mode, drives an in-process [`Ltx2ChainOrchestrator`]. +//! +//! Both paths funnel through [`encode_and_save`] so stdout piping, gallery +//! save, metadata DB writes, and preview behaviour match the single-clip +//! path byte-for-byte. 
+ +use std::io::Write; +use std::time::Duration; + +use anyhow::Result; +use colored::Colorize; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use mold_core::chain::{ChainProgressEvent, ChainRequest}; +use mold_core::{Config, MoldClient, OutputFormat, VideoData}; + +use crate::control::CliContext; +use crate::output::{is_piped, status}; +use crate::theme; + +/// Per-clip frame cap for LTX-2 19B/22B distilled. The distilled VAE +/// pipeline maxes at 97 pixel frames (13 latent frames) per clip. +pub const LTX2_DISTILLED_CLIP_CAP: u32 = 97; + +/// Outcome of [`decide_chain_routing`]: either the caller should continue +/// down the single-clip path, build a chain with the given settings, or +/// reject the request because the model family can't be chained. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ChainRoutingDecision { + /// Go through the normal single-clip path; no chaining required. + SingleClip, + /// Submit a chain. `clip_frames` is the clamped per-clip cap. + Chain { clip_frames: u32, motion_tail: u32 }, + /// Model family doesn't support chaining and `frames` exceeds its cap. + Rejected { reason: String }, +} + +/// Pure decision function — given a model family, the user's requested +/// `frames`, and the optional `--clip-frames` override, decide whether to +/// chain, stay single-clip, or reject. +/// +/// The clamp-to-cap behaviour surfaces through the returned `clip_frames` +/// field; callers warn the user via stderr when they had to clamp. 
+pub fn decide_chain_routing(
+    frames: Option<u32>,
+    family: Option<&str>,
+    model: &str,
+    clip_frames_flag: Option<u32>,
+    motion_tail: u32,
+) -> ChainRoutingDecision {
+    let Some(total_frames) = frames else {
+        return ChainRoutingDecision::SingleClip;
+    };
+
+    let is_ltx2_distilled = family == Some("ltx2") && model.contains("distilled");
+
+    if !is_ltx2_distilled {
+        // Non-chainable families: if the requested frame count is within a
+        // conservative single-clip budget, stay on the single-clip path and
+        // let the engine decide if it's acceptable. Otherwise, reject with
+        // a clear message rather than silently over-producing.
+        if total_frames <= LTX2_DISTILLED_CLIP_CAP {
+            return ChainRoutingDecision::SingleClip;
+        }
+        return ChainRoutingDecision::Rejected {
+            reason: format!(
+                "model '{model}' does not support chained video generation \
+                 (only LTX-2 distilled families do); specify --frames <= {} \
+                 per clip for this model",
+                LTX2_DISTILLED_CLIP_CAP,
+            ),
+        };
+    }
+
+    let cap = LTX2_DISTILLED_CLIP_CAP;
+    let effective_clip_frames = clip_frames_flag.unwrap_or(cap).min(cap);
+
+    if total_frames <= effective_clip_frames {
+        return ChainRoutingDecision::SingleClip;
+    }
+
+    if motion_tail >= effective_clip_frames {
+        return ChainRoutingDecision::Rejected {
+            reason: format!(
+                "--motion-tail ({motion_tail}) must be strictly less than \
+                 --clip-frames ({effective_clip_frames}) so every continuation \
+                 emits at least one new frame",
+            ),
+        };
+    }
+
+    ChainRoutingDecision::Chain {
+        clip_frames: effective_clip_frames,
+        motion_tail,
+    }
+}
+
+/// Emit a stderr warning if `--clip-frames` was above the model's cap and
+/// got clamped. The caller already holds the effective (clamped) value; this
+/// helper only emits the warning and returns nothing. 
+pub fn warn_if_clamped(flag: Option, cap: u32) { + if let Some(requested) = flag { + if requested > cap { + crate::output::status!( + "{} --clip-frames {} exceeds model cap {}, clamping to {}", + theme::prefix_warning(), + requested, + cap, + cap, + ); + } + } +} + +/// Caller-supplied inputs for a chain run, bundled so the remote + local +/// paths can share a single helper without a 20-arg function signature. +#[allow(clippy::too_many_arguments)] +pub struct ChainInputs { + pub prompt: String, + pub model: String, + pub width: u32, + pub height: u32, + pub steps: u32, + pub guidance: f64, + pub strength: f64, + pub seed: Option, + pub fps: u32, + pub output_format: OutputFormat, + pub total_frames: u32, + pub clip_frames: u32, + pub motion_tail: u32, + pub source_image: Option>, + pub placement: Option, +} + +impl ChainInputs { + fn to_chain_request(&self) -> ChainRequest { + ChainRequest { + model: self.model.clone(), + stages: Vec::new(), + motion_tail_frames: self.motion_tail, + width: self.width, + height: self.height, + fps: self.fps, + seed: self.seed, + steps: self.steps, + guidance: self.guidance, + strength: self.strength, + output_format: self.output_format, + placement: self.placement.clone(), + prompt: Some(self.prompt.clone()), + total_frames: Some(self.total_frames), + clip_frames: Some(self.clip_frames), + source_image: self.source_image.clone(), + } + } +} + +/// Run a chain end-to-end, dispatching to the server (streaming) or the +/// local orchestrator based on the `local` flag. Handles encoding, save, +/// preview, and final status messages. 
+#[allow(clippy::too_many_arguments)] +pub async fn run_chain( + inputs: ChainInputs, + host: Option, + output: Option, + no_metadata: bool, + preview: bool, + local: bool, + gpus: Option, + t5_variant: Option, + qwen3_variant: Option, + qwen2_variant: Option, + qwen2_text_encoder_mode: Option, + eager: bool, + offload: bool, +) -> Result<()> { + // Validate the auto-expand form before touching the network / GPU so + // obvious mistakes (bad clip_frames math, too many stages) fail fast. + let chain_req = inputs.to_chain_request(); + let normalised = chain_req.clone().normalise()?; + let stage_count = normalised.stages.len() as u32; + + status!( + "{} Chain mode: {} frames → {} stages × {} frames (tail {})", + theme::icon_mode(), + inputs.total_frames, + stage_count, + inputs.clip_frames, + inputs.motion_tail, + ); + + let ctx = CliContext::new(host.as_deref()); + let config = ctx.config().clone(); + let embed_metadata = config.effective_embed_metadata(no_metadata.then_some(false)); + let _ = embed_metadata; // reserved for future metadata-embed work on chain output + + let t0 = std::time::Instant::now(); + let video = if local { + #[cfg(any(feature = "cuda", feature = "metal"))] + { + crate::ui::print_using_local_inference(); + run_chain_local( + &chain_req, + &config, + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ) + .await? + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] + { + let _ = ( + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ); + anyhow::bail!( + "No mold server running and this binary was built without GPU support.\n\ + Either start a server with `mold serve` or rebuild with --features cuda" + ) + } + } else { + run_chain_remote(ctx.client(), &chain_req).await? 
+ }; + + let elapsed_ms = t0.elapsed().as_millis() as u64; + let base_seed = inputs.seed.unwrap_or(0); + + encode_and_save( + &inputs, + &video, + output.as_deref(), + preview, + elapsed_ms, + base_seed, + )?; + + Config::write_last_model(&inputs.model); + Ok(()) +} + +/// Remote chain: streaming SSE with stacked progress bars. +async fn run_chain_remote(client: &MoldClient, req: &ChainRequest) -> Result { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let render = tokio::spawn(render_chain_progress(rx)); + + let stream_result = client.generate_chain_stream(req, tx).await; + let _ = render.await; + + match stream_result { + Ok(Some(resp)) => Ok(resp.video), + Ok(None) => { + // Server predates chain endpoint; fall back to non-streaming. + status!( + "{} Server SSE chain endpoint unavailable, falling back to blocking endpoint", + theme::prefix_warning(), + ); + let resp = client.generate_chain(req).await?; + Ok(resp.video) + } + Err(e) => Err(e), + } +} + +#[cfg(any(feature = "cuda", feature = "metal"))] +#[allow(clippy::too_many_arguments)] +async fn run_chain_local( + chain_req: &ChainRequest, + config: &Config, + gpus: Option, + t5_variant_override: Option, + qwen3_variant_override: Option, + qwen2_variant_override: Option, + qwen2_text_encoder_mode_override: Option, + eager: bool, + offload: bool, +) -> Result { + use mold_core::manifest::find_manifest; + use mold_core::ModelPaths; + use mold_inference::LoadStrategy; + + // Normalise so we have expanded stages locally too. + let req = chain_req.clone().normalise()?; + + // Apply encoder-variant overrides before constructing the engine so the + // factory's auto-select picks them up. + apply_local_engine_env_overrides( + t5_variant_override.as_deref(), + qwen3_variant_override.as_deref(), + qwen2_variant_override.as_deref(), + qwen2_text_encoder_mode_override.as_deref(), + ); + + let model_name = req.model.clone(); + + // Ensure the model is pulled + config rows are in place. 
+ let (paths, effective_config) = if let Some(p) = ModelPaths::resolve(&model_name, config) { + (p, config.clone()) + } else if find_manifest(&model_name).is_some() { + crate::output::status!( + "{} Model '{}' not found locally, pulling...", + theme::icon_info(), + model_name.bold(), + ); + let updated = super::pull::pull_and_configure( + &model_name, + &mold_core::download::PullOptions::default(), + ) + .await?; + let p = ModelPaths::resolve(&model_name, &updated).ok_or_else(|| { + anyhow::anyhow!("model '{model_name}' was pulled but paths could not be resolved") + })?; + (p, updated) + } else { + anyhow::bail!( + "no model paths configured for '{model_name}'. Add [models.{model_name}] \ + to ~/.mold/config.toml or pull via `mold pull {model_name}`." + ); + }; + + let is_eager = eager || std::env::var("MOLD_EAGER").is_ok_and(|v| v == "1"); + let load_strategy = if is_eager { + LoadStrategy::Eager + } else { + LoadStrategy::Sequential + }; + if is_eager { + std::env::set_var("MOLD_EAGER", "1"); + } + let is_offload = offload || std::env::var("MOLD_OFFLOAD").is_ok_and(|v| v == "1"); + + let gpu_selection = match &gpus { + Some(s) => mold_core::types::GpuSelection::parse(s)?, + None => effective_config.gpu_selection(), + }; + let discovered = mold_inference::device::discover_gpus(); + let available = mold_inference::device::filter_gpus(&discovered, &gpu_selection); + let gpu_ordinal = mold_inference::device::select_best_gpu(&available) + .map(|g| g.ordinal) + .unwrap_or(0); + + let mut engine = mold_inference::create_engine( + &model_name, + paths, + &effective_config, + load_strategy, + gpu_ordinal, + is_offload, + )?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let render = tokio::spawn(render_chain_progress(rx)); + + let fps = req.fps; + let output_format = req.output_format; + let total_frames_opt = Some(req.total_frames.unwrap_or(u32::MAX)); + let req_clone = req.clone(); + + let handle = tokio::task::spawn_blocking(move || -> Result { + 
engine.load()?; + let renderer = engine.as_chain_renderer().ok_or_else(|| { + anyhow::anyhow!( + "model '{}' does not support chained video generation \ + (only LTX-2 distilled engines expose a ChainStageRenderer view)", + req_clone.model, + ) + })?; + let mut orch = mold_inference::ltx2::Ltx2ChainOrchestrator::new(renderer); + + let tx = tx; + let mut chain_cb = move |event: ChainProgressEvent| { + let _ = tx.send(event); + }; + let chain_output = orch.run(&req_clone, Some(&mut chain_cb))?; + + let mut frames = chain_output.frames; + if let Some(target) = total_frames_opt { + let target = target as usize; + if frames.len() > target { + frames.truncate(target); + } + } + if frames.is_empty() { + anyhow::bail!("chain run emitted zero frames after trim"); + } + + encode_local_frames(&frames, fps, output_format) + }); + + let result = handle.await??; + let _ = render.await; + Ok(result) +} + +#[cfg(any(feature = "cuda", feature = "metal"))] +fn apply_local_engine_env_overrides( + t5_variant: Option<&str>, + qwen3_variant: Option<&str>, + qwen2_variant: Option<&str>, + qwen2_text_encoder_mode: Option<&str>, +) { + if let Some(v) = t5_variant { + std::env::set_var("MOLD_T5_VARIANT", v); + } + if let Some(v) = qwen3_variant { + std::env::set_var("MOLD_QWEN3_VARIANT", v); + } + if let Some(v) = qwen2_variant { + std::env::set_var("MOLD_QWEN2_VARIANT", v); + } + if let Some(v) = qwen2_text_encoder_mode { + std::env::set_var("MOLD_QWEN2_TEXT_ENCODER_MODE", v); + } +} + +/// Encode stitched frames to the requested container. MP4 is feature-gated; +/// fall back to APNG when the CLI was built without `mp4`. 
+#[cfg(any(feature = "cuda", feature = "metal"))] +fn encode_local_frames( + frames: &[image::RgbImage], + fps: u32, + output_format: OutputFormat, +) -> Result { + use mold_inference::ltx_video::video_enc; + + let gif_preview = video_enc::encode_gif(frames, fps).unwrap_or_default(); + let thumbnail = video_enc::first_frame_png(frames).unwrap_or_default(); + + let (bytes, actual_format) = match output_format { + OutputFormat::Mp4 => { + #[cfg(feature = "mp4")] + { + (video_enc::encode_mp4(frames, fps)?, OutputFormat::Mp4) + } + #[cfg(not(feature = "mp4"))] + { + crate::output::status!( + "{} MP4 requested but this binary was built without --features mp4; \ + falling back to APNG", + theme::prefix_warning(), + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + } + OutputFormat::Apng => ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ), + OutputFormat::Gif => (video_enc::encode_gif(frames, fps)?, OutputFormat::Gif), + OutputFormat::Webp => { + crate::output::status!( + "{} WebP chain output not supported locally yet; falling back to APNG", + theme::prefix_warning(), + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + other => anyhow::bail!("{other:?} is not a video output format for chain generation"), + }; + + let width = frames[0].width(); + let height = frames[0].height(); + let frame_count = frames.len() as u32; + let duration_ms = if fps == 0 { + None + } else { + Some((frame_count as u64 * 1000) / fps as u64) + }; + + Ok(VideoData { + data: bytes, + format: actual_format, + width, + height, + frames: frame_count, + fps, + thumbnail, + gif_preview, + has_audio: false, + duration_ms, + audio_sample_rate: None, + audio_channels: None, + }) +} + +/// Shared epilogue: write the stitched video to stdout/file/gallery and +/// emit a terminal preview if requested. 
+fn encode_and_save( + inputs: &ChainInputs, + video: &VideoData, + output: Option<&str>, + preview: bool, + elapsed_ms: u64, + base_seed: u64, +) -> Result<()> { + let piped = is_piped(); + + if piped && output.is_none() { + let mut stdout = std::io::stdout().lock(); + stdout.write_all(&video.data)?; + stdout.flush()?; + } else { + let filename = match output { + Some("-") => { + let mut stdout = std::io::stdout().lock(); + stdout.write_all(&video.data)?; + stdout.flush()?; + None + } + Some(path) => Some(path.to_string()), + None => { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + Some(mold_core::default_output_filename( + &inputs.model, + timestamp, + video.format.extension(), + 1, + 0, + )) + } + }; + if let Some(ref filename) = filename { + if std::path::Path::new(filename).exists() { + status!("{} Overwriting: {}", theme::icon_alert(), filename); + } + std::fs::write(filename, &video.data)?; + status!( + "{} Saved: {} ({} frames, {}x{}, {} fps)", + theme::icon_done(), + filename.bold(), + video.frames, + video.width, + video.height, + video.fps, + ); + + // Persist to the gallery metadata DB. Build a synthetic + // GenerateRequest so the existing record_local_save helper can + // infer dimensions/seed/steps/etc. without a dedicated chain + // row schema. + let req = synth_generate_request(inputs, video); + crate::metadata_db::record_local_save( + std::path::Path::new(filename), + &req, + inputs.seed.unwrap_or(base_seed), + elapsed_ms, + video.format, + ); + } + } + + if preview && !piped { + // Best-effort: show the gif preview or fall back to the thumbnail + // or the video bytes themselves (GIF/APNG decode as images). 
+ let bytes_for_preview: &[u8] = if !video.gif_preview.is_empty() { + &video.gif_preview + } else if !video.thumbnail.is_empty() { + &video.thumbnail + } else { + &video.data + }; + super::generate::preview_image(bytes_for_preview); + } + + status!( + "{} Done — {} in {:.1}s ({} frames, seed: {})", + theme::icon_done(), + inputs.model.bold(), + elapsed_ms as f64 / 1000.0, + video.frames, + inputs.seed.unwrap_or(base_seed), + ); + + Ok(()) +} + +fn synth_generate_request(inputs: &ChainInputs, video: &VideoData) -> mold_core::GenerateRequest { + mold_core::GenerateRequest { + prompt: inputs.prompt.clone(), + negative_prompt: None, + model: inputs.model.clone(), + width: inputs.width, + height: inputs.height, + steps: inputs.steps, + guidance: inputs.guidance, + seed: inputs.seed, + batch_size: 1, + output_format: video.format, + embed_metadata: Some(false), + scheduler: None, + edit_images: None, + source_image: inputs.source_image.clone(), + strength: inputs.strength, + mask_image: None, + control_image: None, + control_model: None, + control_scale: 1.0, + expand: None, + original_prompt: None, + lora: None, + frames: Some(video.frames), + fps: Some(video.fps), + upscale_model: None, + gif_preview: false, + enable_audio: None, + audio_file: None, + source_video: None, + keyframes: None, + pipeline: None, + loras: None, + retake_range: None, + spatial_upscale: None, + temporal_upscale: None, + placement: inputs.placement.clone(), + } +} + +/// Stacked progress bars for chain render: a parent "Chain" bar covering +/// all pixel frames and a transient per-stage bar covering denoise steps. +async fn render_chain_progress(mut rx: tokio::sync::mpsc::UnboundedReceiver) { + // Always draw to stderr so image bytes piped to stdout stay clean. 
+ let mp = MultiProgress::with_draw_target(ProgressDrawTarget::stderr()); + + let parent = mp.add(ProgressBar::new(0)); + parent.set_style( + ProgressStyle::default_bar() + .template(&format!( + "{{prefix:.{c}}} [{{bar:30.{c}/dim}}] {{pos}}/{{len}} frames {{msg}}", + c = theme::SPINNER_STYLE, + )) + .unwrap() + .progress_chars("━╸─"), + ); + parent.set_prefix("Chain"); + parent.enable_steady_tick(Duration::from_millis(100)); + + let mut stage_bar: Option = None; + let mut stage_count: u32 = 0; + + while let Some(event) = rx.recv().await { + match event { + ChainProgressEvent::ChainStart { + stage_count: sc, + estimated_total_frames, + } => { + stage_count = sc; + parent.set_length(estimated_total_frames as u64); + parent.set_message(format!("(stages {sc})")); + } + ChainProgressEvent::StageStart { stage_idx } => { + if let Some(old) = stage_bar.take() { + old.finish_and_clear(); + } + parent.set_message(format!("stage {}/{}", stage_idx + 1, stage_count)); + let sb = mp.add(ProgressBar::new(0)); + sb.set_style( + ProgressStyle::default_bar() + .template(&format!( + " Stage {{prefix}} [{{bar:30.{c}/dim}}] {{pos}}/{{len}} steps", + c = theme::SPINNER_STYLE, + )) + .unwrap() + .progress_chars("━╸─"), + ); + sb.set_prefix(format!("{}", stage_idx + 1)); + sb.enable_steady_tick(Duration::from_millis(100)); + stage_bar = Some(sb); + } + ChainProgressEvent::DenoiseStep { + stage_idx: _, + step, + total, + } => { + if let Some(ref sb) = stage_bar { + if sb.length().unwrap_or(0) == 0 { + sb.set_length(total as u64); + } + sb.set_position(step as u64); + } + } + ChainProgressEvent::StageDone { + stage_idx: _, + frames_emitted, + } => { + if let Some(sb) = stage_bar.take() { + sb.finish_and_clear(); + } + parent.inc(frames_emitted as u64); + } + ChainProgressEvent::Stitching { total_frames } => { + if let Some(sb) = stage_bar.take() { + sb.finish_and_clear(); + } + parent.set_message(format!("stitching {total_frames} frames…")); + } + } + } + + if let Some(sb) = 
stage_bar.take() { + sb.finish_and_clear(); + } + parent.finish_and_clear(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn routing_single_clip_under_cap() { + let d = decide_chain_routing(Some(97), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!(d, ChainRoutingDecision::SingleClip); + } + + #[test] + fn routing_single_clip_when_frames_absent() { + let d = decide_chain_routing(None, Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!(d, ChainRoutingDecision::SingleClip); + } + + #[test] + fn routing_chain_over_cap_ltx2_distilled() { + let d = decide_chain_routing(Some(200), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 97, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_rejects_non_distilled_over_cap() { + let d = decide_chain_routing(Some(200), Some("flux"), "flux-dev:q4", None, 4); + match d { + ChainRoutingDecision::Rejected { reason } => { + assert!( + reason.contains("does not support chained video"), + "unexpected reason: {reason}" + ); + } + other => panic!("expected Rejected, got {other:?}"), + } + } + + #[test] + fn routing_rejects_non_ltx2_family_over_cap() { + // ltx-video (not ltx2) is not chainable in v1. + let d = decide_chain_routing(Some(200), Some("ltx-video"), "ltx-video:0.9.6", None, 4); + assert!(matches!(d, ChainRoutingDecision::Rejected { .. 
})); + } + + #[test] + fn routing_clip_frames_above_cap_clamps_to_cap() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(200), + 4, + ); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 97, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_clip_frames_under_cap_respected() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(65), + 4, + ); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 65, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_motion_tail_ge_clip_frames_rejects() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(49), + 49, + ); + match d { + ChainRoutingDecision::Rejected { reason } => { + assert!( + reason.contains("--motion-tail"), + "unexpected reason: {reason}" + ); + } + other => panic!("expected Rejected, got {other:?}"), + } + } + + #[test] + fn routing_motion_tail_at_clip_frames_rejects() { + let d = decide_chain_routing(Some(200), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 97); + assert!(matches!(d, ChainRoutingDecision::Rejected { .. })); + } + + #[test] + fn ltx2_distilled_cap_matches_engine_constraint() { + // 97 = 8 * 12 + 1, satisfying the VAE 8k+1 constraint. + assert_eq!(LTX2_DISTILLED_CLIP_CAP % 8, 1); + } +} diff --git a/crates/mold-cli/src/commands/generate.rs b/crates/mold-cli/src/commands/generate.rs index 028fa089..cded5bb1 100644 --- a/crates/mold-cli/src/commands/generate.rs +++ b/crates/mold-cli/src/commands/generate.rs @@ -157,6 +157,11 @@ fn apply_local_engine_env_overrides( pub struct Ltx2Options { pub frames: Option, pub fps: Option, + /// Per-clip cap for chained rendering. `None` = use the model-family default + /// (currently 97 for LTX-2 distilled). Only read when `frames > cap`. + pub clip_frames: Option, + /// Motion-tail overlap between chained clips (pixel frames). 
+ pub motion_tail: u32, pub enable_audio: Option, pub audio_file: Option>, pub source_video: Option>, @@ -210,6 +215,8 @@ pub async fn run( let Ltx2Options { frames, fps, + clip_frames, + motion_tail, enable_audio, audio_file, source_video, @@ -243,6 +250,117 @@ pub async fn run( } else { format }; + + // ── Chain routing ───────────────────────────────────────────────────── + // When --frames exceeds the per-clip cap, auto-build a ChainRequest and + // delegate to the chain helper. Only LTX-2 distilled is chainable in v1; + // other video families error fast rather than silently over-producing. + { + use super::chain::{decide_chain_routing, warn_if_clamped, ChainRoutingDecision}; + let decision = decide_chain_routing( + effective_frames, + family.as_deref(), + model, + clip_frames, + motion_tail, + ); + match decision { + ChainRoutingDecision::SingleClip => { + // Fall through to the existing single-clip path below. + } + ChainRoutingDecision::Rejected { reason } => { + anyhow::bail!(reason); + } + ChainRoutingDecision::Chain { + clip_frames: cf, + motion_tail: mt, + } => { + warn_if_clamped(clip_frames, super::chain::LTX2_DISTILLED_CLIP_CAP); + let (eff_w, eff_h) = effective_dimensions( + &config, + &model_cfg, + family.as_deref(), + width, + height, + source_image.as_deref(), + edit_images.as_deref(), + )?; + let eff_steps = steps.unwrap_or_else(|| model_cfg.effective_steps(&config)); + let eff_guidance = guidance.unwrap_or_else(|| model_cfg.effective_guidance()); + let eff_fps = effective_fps.unwrap_or(24); + let total_frames = effective_frames + .expect("decide_chain_routing only returns Chain when frames is Some"); + + // Chain path doesn't use batch/edit_images/mask/control/loras — + // those are single-clip concepts. If the user set them, warn and + // continue (we don't hard-error to keep the UX lenient). 
+ if batch > 1 { + status!( + "{} --batch has no effect in chain mode; rendering a single stitched video", + theme::icon_warn(), + ); + } + + let inputs = super::chain::ChainInputs { + prompt: prompt.to_string(), + model: model.to_string(), + width: eff_w, + height: eff_h, + steps: eff_steps, + guidance: eff_guidance, + strength, + seed, + fps: eff_fps, + output_format, + total_frames, + clip_frames: cf, + motion_tail: mt, + source_image: source_image.clone(), + placement: placement.clone(), + }; + // Consume otherwise-unused LTX-2 knobs that chain v1 ignores so + // clippy doesn't fire `unused_variables` on the early return. + let _ = ( + &audio_file, + &source_video, + &keyframes, + &pipeline, + &loras, + &retake_range, + &spatial_upscale, + &temporal_upscale, + &enable_audio, + &mask_image, + &control_image, + &control_model, + control_scale, + &negative_prompt, + &original_prompt, + &batch_prompts, + &lora, + &scheduler, + expand, + ); + return super::chain::run_chain( + inputs, + host, + output, + no_metadata, + preview, + local, + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ) + .await; + } + } + } + let piped = is_piped(); // Reject batch > 1 when output goes to stdout (piped with no --output, or --output -) diff --git a/crates/mold-cli/src/commands/mod.rs b/crates/mold-cli/src/commands/mod.rs index ec4bbb27..82a19d16 100644 --- a/crates/mold-cli/src/commands/mod.rs +++ b/crates/mold-cli/src/commands/mod.rs @@ -1,3 +1,4 @@ +pub mod chain; pub mod clean; pub(crate) mod cleanup; pub mod config; diff --git a/crates/mold-cli/src/commands/run.rs b/crates/mold-cli/src/commands/run.rs index ba18d479..85269678 100644 --- a/crates/mold-cli/src/commands/run.rs +++ b/crates/mold-cli/src/commands/run.rs @@ -436,6 +436,8 @@ pub async fn run( batch: u32, frames: Option, fps: Option, + clip_frames: Option, + motion_tail: u32, audio: bool, no_audio: bool, audio_file: Option, @@ -825,6 +827,8 @@ pub async fn run( 
generate::Ltx2Options { frames, fps, + clip_frames, + motion_tail, enable_audio: if audio { Some(true) } else if no_audio { diff --git a/crates/mold-cli/src/main.rs b/crates/mold-cli/src/main.rs index 03293886..f7e616d2 100644 --- a/crates/mold-cli/src/main.rs +++ b/crates/mold-cli/src/main.rs @@ -377,6 +377,9 @@ Examples: /// Number of video frames to generate (video models only, e.g. ltx-video). /// Implies video output mode; output defaults to .gif format. + /// + /// For LTX-2 distilled, values above 97 automatically chain multiple + /// clips at render time (see `--clip-frames` / `--motion-tail`). #[arg(long, help_heading = "Video")] frames: Option, @@ -385,6 +388,19 @@ Examples: #[arg(long, help_heading = "Video")] fps: Option, + /// Per-clip frame cap for chained video. When --frames exceeds this, + /// the CLI splits into multiple chained clips stitched at render time. + /// Defaults to the model's native cap (97 for LTX-2 distilled). + #[arg(long, value_name = "N", help_heading = "Video")] + clip_frames: Option, + + /// Motion-tail overlap between chained clips in pixel frames. Each clip + /// after the first reuses this many trailing latents from the prior + /// clip, trimming the duplicated pixel frames at stitch time. 0 disables + /// latent carryover (simple concat). Default 4. + #[arg(long, value_name = "N", default_value_t = 4, help_heading = "Video")] + motion_tail: u32, + /// Enable synchronized audio for LTX-2 / LTX-2.3 generation. 
#[arg(long, help_heading = "Video", conflicts_with = "no_audio")] audio: bool, @@ -1147,6 +1163,8 @@ async fn run() -> anyhow::Result<()> { batch, frames, fps, + clip_frames, + motion_tail, audio, no_audio, audio_file, @@ -1205,6 +1223,8 @@ async fn run() -> anyhow::Result<()> { batch, frames, fps, + clip_frames, + motion_tail, audio, no_audio, audio_file, @@ -2131,6 +2151,50 @@ mod tests { } } + #[test] + fn run_chain_flags_parse() { + let cli = parse(&[ + "run", + "ltx-2-19b-distilled:fp8", + "a cat", + "--frames", + "200", + "--clip-frames", + "97", + "--motion-tail", + "4", + ]); + match cli.command { + Commands::Run { + frames, + clip_frames, + motion_tail, + .. + } => { + assert_eq!(frames, Some(200)); + assert_eq!(clip_frames, Some(97)); + assert_eq!(motion_tail, 4); + } + _ => panic!("expected Run"), + } + } + + #[test] + fn run_motion_tail_defaults_to_four() { + let cli = parse(&["run", "ltx-2-19b-distilled:fp8", "a cat", "--frames", "200"]); + match cli.command { + Commands::Run { + motion_tail, + clip_frames, + .. 
+ } => { + assert_eq!(motion_tail, 4, "default motion tail must be 4 frames"); + assert_eq!(clip_frames, None); + } + _ => panic!("expected Run"), + } + } + // --- Regression test for issue #190: --version includes git SHA --- #[test] From 62111820d246a9d0cdd44bf763fc7f2f08c3f1e2 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 19:02:50 -0700 Subject: [PATCH 11/31] docs(chain): ltx2 guide, api endpoint, changelog, and skill updates Document render-chain v1 across the four surfaces: a new "Chained video output" section in website/models/ltx2.md explaining the per-clip cap, motion-tail carryover, and the --frames / --clip-frames / --motion-tail CLI contract; request/response/SSE schemas for the new POST /api/generate/chain[/stream] endpoints in website/api/index.md; an Unreleased/Added bullet in CHANGELOG.md covering the feature end-to-end; and the new flags + endpoint in .claude/skills/mold/SKILL.md so OpenClaw and the other AI agents surface chained video correctly. Part of render-chain v1 (Phase 4). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .claude/skills/mold/SKILL.md | 18 +++- CHANGELOG.md | 1 + website/api/index.md | 157 +++++++++++++++++++++++++++++++++++ website/models/ltx2.md | 81 ++++++++++++++++++ 4 files changed, 256 insertions(+), 1 deletion(-) diff --git a/.claude/skills/mold/SKILL.md b/.claude/skills/mold/SKILL.md index 4c8ddeee..97f28b29 100644 --- a/.claude/skills/mold/SKILL.md +++ b/.claude/skills/mold/SKILL.md @@ -178,7 +178,9 @@ mold run ltx-2-19b-distilled:fp8 "lantern-lit cave entrance" --camera-control do **Models:** `ltx-2-19b-dev:fp8`, `ltx-2-19b-distilled:fp8`, `ltx-2.3-22b-dev:fp8`, `ltx-2.3-22b-distilled:fp8` -**Important flags:** `--audio`, `--no-audio`, `--audio-file`, `--video`, repeatable `--keyframe`, repeatable `--lora`, `--pipeline`, `--retake`, `--camera-control`, `--spatial-upscale`, `--temporal-upscale` +**Important flags:** `--audio`, `--no-audio`, `--audio-file`, `--video`, repeatable `--keyframe`, repeatable `--lora`, `--pipeline`, `--retake`, `--camera-control`, `--spatial-upscale`, `--temporal-upscale`, `--clip-frames`, `--motion-tail` + +**Chained (arbitrary-length) video output:** for LTX-2 19B and 22B distilled models, `--frames` above the 97-frame per-clip cap automatically renders multiple clips with a motion-tail of latents carried across each clip boundary, then stitches them into a single MP4. The CLI picks this path transparently — `mold run ltx-2-19b-distilled:fp8 "a cat walking" --frames 400` produces one 400-frame MP4 from 5 chained stages. Advanced callers can override the per-clip length via `--clip-frames N` (must be `8k+1`, clamped to the model cap) and the overlap via `--motion-tail N` (default 4 pixel frames, 0 disables carryover). Chains fail closed on mid-stage failure (no partial output) and run on a single GPU. Other model families reject `--frames > 97` with an actionable error. 
**Current constraints:** `x2` spatial upscaling is wired across the family, `x1.5` spatial upscaling is wired for `ltx-2.3-*`, and `x2` temporal upscaling is wired in the native runtime. Camera-control preset aliases currently auto-resolve the published LTX-2 19B LoRAs only. The family runs through the native Rust stack in `mold-inference`, with CUDA as the supported backend for real local generation, CPU as a correctness-only fallback, and Metal unsupported. On 24 GB Ada GPUs such as the RTX 4090, the validated path stays on the compatible `fp8-cast` mode rather than Hopper-only `fp8-scaled-mm`. The native CUDA matrix is validated across 19B/22B text+audio-video, image-to-video, audio-to-video, keyframe, retake, public IC-LoRA, spatial upscale (`x1.5` / `x2` where published), and temporal upscale (`x2`). When requests go through `mold serve`, the built-in body limit is `64 MiB`, which is enough for common inline source-video and source-audio workflows. @@ -535,6 +537,20 @@ MOLD_HOST=http://gpu-host:7680 mold run "a cat" MOLD_OUTPUT_DIR=/srv/mold/output mold serve ``` +### HTTP API Endpoints + +Core endpoints exposed by `mold serve` (full list + schemas at `/api/docs`): + +- `POST /api/generate` — image/video generation, raw bytes response +- `POST /api/generate/stream` — SSE progress + base64 complete event +- `POST /api/generate/chain` — chained arbitrary-length video (LTX-2 distilled); body is `mold_core::chain::ChainRequest` (canonical `stages[]` or auto-expand `prompt`+`total_frames`+`clip_frames`) +- `POST /api/generate/chain/stream` — same as above, SSE progress with per-stage `denoise_step` events +- `POST /api/expand` — LLM prompt expansion +- `GET /api/models` · `POST /api/models/load` · `POST /api/models/pull` · `DELETE /api/models/unload` +- `GET /api/gallery` · `GET /api/gallery/image/:name` · `GET /api/gallery/thumbnail/:name` · `DELETE /api/gallery/image/:name` +- `POST /api/upscale` · `POST /api/upscale/stream` +- `GET /api/status` · `GET /health` · 
`GET /api/capabilities` + ### Prometheus Metrics When built with the `metrics` feature flag (included in Docker images and Nix builds), the server exposes a `GET /metrics` endpoint in Prometheus text exposition format. This endpoint is excluded from auth and rate limiting for monitoring scrapers. diff --git a/CHANGELOG.md b/CHANGELOG.md index af96fa71..1867930b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Render chain for arbitrary-length LTX-2 distilled video.** `mold run ltx-2-19b-distilled:fp8 "a cat walking" --frames 400` now produces a single stitched MP4 by splitting the request into multiple per-clip renders and carrying a motion-tail of latents across each clip boundary so the continuation stays coherent without a VAE encode/decode round-trip. New server endpoints `POST /api/generate/chain` and `POST /api/generate/chain/stream` (SSE) accept either a canonical `stages[]` body or an auto-expand form (`prompt` + `total_frames` + `clip_frames`) — the wire format is stages-based from day one so the v2 movie-maker UI can author per-stage prompts/keyframes without a breaking change. Request/response/event types live in `crates/mold-core/src/chain.rs` (`ChainRequest`, `ChainResponse`, `ChainProgressEvent`, `SseChainCompleteEvent`); the LTX-2 orchestrator is in `crates/mold-inference/src/ltx2/chain.rs` (`Ltx2ChainOrchestrator`, `ChainTail`); the server routes in `crates/mold-server/src/routes_chain.rs`; and the CLI side in `crates/mold-cli/src/commands/chain.rs`. `mold run` auto-routes to the chain endpoint when `--frames` exceeds the model's per-clip cap (97 for LTX-2 19B/22B distilled); non-distilled families fail fast with an actionable error instead of silently over-producing. New flags `--clip-frames N` (default = model cap) and `--motion-tail N` (default 4, 0 disables carryover) let advanced callers tune the split. 
The orchestrator derives per-stage seeds as `base_seed ^ ((stage_idx as u64) << 32)` so the whole chain reproduces from a single seed without identical-noise artefacts when every stage shares a prompt. Over-production at the final clip is trimmed from the tail (the head carries the user-anchored starting image and is perceptually load-bearing); mid-chain failures fail closed with HTTP 502 and no partial stitch is ever written to the gallery. Chains run on a single GPU — the chain handler bypasses the single-job queue and holds the `ModelCache` lock for the full chain duration (a multi-minute compound operation would otherwise stall the FIFO queue). Both the remote SSE path and the `--local` in-process path funnel through the same orchestrator via `Ltx2Engine::as_chain_renderer`, and `mold run` renders stacked `indicatif` progress bars (parent "Chain" frame counter + per-stage denoise-step bar). v1 is LTX-2 distilled only, single-GPU, and single-prompt; per-stage prompts, keyframes, selective regen, and multi-GPU stage fan-out are v2 movie-maker work. - **In-browser model downloads with queued progress, ETA, cancel, and retry** ([#255](https://github.com/utensils/mold/pull/255)). `ModelPicker.vue` now shows `(X GB)` next to every model — click an undownloaded one to enqueue a pull without leaving the generate flow. A new `DownloadsDrawer` (opened from a TopBar button with an active/queued count badge) shows per-file progress, client-computed ETA (10 s sliding window), and cancel/retry controls. Undownloaded models in the picker switch to inline progress or a "Queued (#N)" chip while their job is alive, and the picker auto-refreshes on `JobDone` so the model becomes selectable without a page reload. 
Server-side: a new single-writer `DownloadQueue` in `AppState` drives the existing `mold_core::download::pull_model_with_callback` one model at a time (files sequential inside a set — HF's CDN is bandwidth-bound, so file-level parallelism would only trip rate limits), with one auto-retry on transient failure. Cancellation aborts the in-flight pull, cleans up partials under `MOLD_MODELS_DIR//` while preserving any `.sha256-verified` markers, and leaves the HF blob cache intact so resume is cheap. The same cleanup runs on terminal failures, not just cancel. New routes: `POST /api/downloads` (idempotent — returns the existing job id on a second enqueue), `DELETE /api/downloads/:id`, `GET /api/downloads` (active + queued + last 20 history), `GET /api/downloads/stream` (SSE multiplex of `DownloadEvent` frames — `Enqueued`, `Started`, `Progress`, `FileDone`, `JobDone`, `JobFailed`, `JobCancelled`). Existing `POST /api/models/pull` becomes a thin compat shim that enqueues via the queue and re-emits the legacy SSE event shape, so the TUI keeps working unchanged. - **Always-visible VRAM + system RAM telemetry on `/generate`** ([#254](https://github.com/utensils/mold/pull/254)). A new `ResourceStrip.vue` docks at the bottom of the Composer sidebar on desktop (and collapses to a `🧠 used · total` chip in the TopBar on narrow viewports), showing one stacked-bar row per discovered GPU plus one for system RAM. Each row breaks usage into `mold` / `other` / `free` on CUDA hosts with per-process attribution (NVML feature-gated as `mold-ai-server` `--features nvml`, `nvidia-smi` subprocess fallback on by default) — on Metal the per-process fields are intentionally `None` and the SPA hides those breakdowns, since macOS doesn't expose per-process GPU attribution without private entitlements. 
Aggregated once per second on the server into a `ResourceSnapshot { hostname, gpus, system_ram }`, exposed as `GET /api/resources` (one-shot; `503` before the first aggregator tick) and `GET /api/resources/stream` (SSE broadcast with 15 s keepalive and the cached snapshot prepended as the first frame so new subscribers don't wait a full second). The aggregator handle is bound to `axum::serve`'s shutdown path so it's aborted on graceful exit. The strip's `useResources` composable is a provide/inject singleton mounted in `App.vue`, and it exposes a `gpuList: ComputedRef` that the new device-placement UI consumes directly. - **Per-component device placement for FLUX, Flux.2, Z-Image, and Qwen-Image** ([#256](https://github.com/utensils/mold/pull/256)). A new `PlacementPanel` disclosure inside the Composer lets users override which device each part of the pipeline runs on. Tier 1 is a single "Text encoders: Auto / CPU / GPU N" dropdown that applies to every model family (SD1.5, SDXL, SD3.5, Wuerstchen, LTX-Video, LTX-2 in addition to the Tier 2 four) — picking CPU reliably moves the text encoder off-GPU so a large transformer can stay on-device without triggering block-level offload. Tier 2 adds per-component selects (transformer, VAE, and family-appropriate encoder slots) for FLUX, Flux.2, Z-Image, and Qwen-Image, where the plumbing is cheapest and the value is clearest. SD3.5 was marked stretch in the design and cut cleanly — the UI correctly hides Advanced for SD3.5 with a tooltip so no user sees an override that silently no-ops. A new `DevicePlacement` serde type (`DeviceRef = Auto | Cpu | Gpu(ordinal)` plus an optional `AdvancedPlacement` sub-struct for per-component overrides) rides as an optional field on `GenerateRequest`; `None` preserves the existing VRAM-aware auto-placement end-to-end. 
A shared `resolve_device()` helper in `mold_inference::device` (and a companion `effective_device_ref()` shared by the four Tier-2 engines) maps each `DeviceRef` variant to a `candle_core::Device`, returning a clean `anyhow::Error` for bad ordinals instead of panicking. Defaults are saved per-model in `[models."name:tag".placement]` (with `MOLD_PLACE_TEXT_ENCODERS`, `MOLD_PLACE_TRANSFORMER`, `MOLD_PLACE_VAE`, `MOLD_PLACE_CLIP_L`, `MOLD_PLACE_CLIP_G`, `MOLD_PLACE_T5`, `MOLD_PLACE_QWEN` env overrides) via a new `PUT /api/config/model/:name/placement` route (with `DELETE` to clear); the route now returns a real `500` when `Config::save()` fails instead of silently lying to the client. The placement UI reads its GPU list from `useResources().gpuList`, so spinning up a mold server on a dual-3090 box auto-populates "GPU 0 · RTX 3090" / "GPU 1 · RTX 3090" in every dropdown without any extra discovery wiring. `mold run` gains matching CLI flags — `--device-text-encoders`, `--device-transformer`, `--device-vae`, `--device-t5`, `--device-clip-l`, `--device-clip-g`, `--device-qwen` — which override env vars and config; flag parse errors surface with the specific flag name so `--device-vae banana` reports `--device-vae: invalid device 'banana' (expected auto|cpu|gpu[:N])` instead of a generic failure. Documented in `website/guide/configuration.md` (new "Per-component device placement" section) and `website/guide/performance.md` (the "CPU text encoders" subsection now points at the CLI flags for deliberate VRAM tuning). diff --git a/website/api/index.md b/website/api/index.md index ff04172f..9c5c7c5a 100644 --- a/website/api/index.md +++ b/website/api/index.md @@ -8,6 +8,8 @@ When running `mold serve`, you get a REST API for remote image generation. 
| -------- | ------------------------------ | ------------------------------------ | | `POST` | `/api/generate` | Generate images from prompt | | `POST` | `/api/generate/stream` | Generate with SSE progress streaming | +| `POST` | `/api/generate/chain` | Chained video generation (LTX-2) | +| `POST` | `/api/generate/chain/stream` | Chained video with SSE progress | | `POST` | `/api/expand` | Expand a prompt using LLM | | `GET` | `/api/models` | List available models | | `POST` | `/api/models/load` | Load/swap the active model | @@ -223,6 +225,161 @@ server internally. RunPod's proxy has a 100-second timeout. Use the SSE streaming endpoint for long generations to keep the connection alive. ::: +## `/api/generate/chain` + +Chained video generation for LTX-2 distilled models. Splits a long video into +N per-clip renders, threads a motion-tail of latents across each clip +boundary, and returns a single stitched MP4. See the +[LTX-2 chained video output guide](/models/ltx2#chained-video-output) for the +user-facing story; this section documents the wire format. + +The request body maps to `mold_core::chain::ChainRequest`; the response body +maps to `mold_core::chain::ChainResponse`. The canonical schema lives in the +interactive docs at `/api/docs` (served by the running mold server) and in the +OpenAPI JSON at `/api/openapi.json`. + +The server accepts either a pre-authored `stages[]` body or the auto-expand +form (single `prompt` + `total_frames` + `clip_frames`). Auto-expand is the +shape `mold run` sends; the canonical `stages[]` shape is reserved for the +forthcoming movie-maker UI that will author per-stage prompts/keyframes. Both +normalise to the same internal `Vec` before any engine work kicks +off. 
+ +**Auto-expand body** (what `mold run --frames N` emits): + +```json +{ + "model": "ltx-2-19b-distilled:fp8", + "prompt": "a cat walking through autumn leaves", + "total_frames": 400, + "clip_frames": 97, + "source_image": "", + "motion_tail_frames": 4, + "width": 1216, + "height": 704, + "fps": 24, + "seed": 42, + "steps": 8, + "guidance": 3.0, + "strength": 1.0, + "output_format": "mp4" +} +``` + +**Canonical body** (what the v2 movie-maker UI will author): + +```json +{ + "model": "ltx-2-19b-distilled:fp8", + "stages": [ + { "prompt": "a cat walking", "frames": 97, "source_image": "" }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 } + ], + "motion_tail_frames": 4, + "width": 1216, + "height": 704, + "fps": 24, + "seed": 42, + "steps": 8, + "guidance": 3.0, + "strength": 1.0, + "output_format": "mp4" +} +``` + +**Response:** + +```json +{ + "video": { + "data": "", + "format": "mp4", + "width": 1216, + "height": 704, + "frames": 400, + "fps": 24, + "thumbnail": "", + "gif_preview": "", + "has_audio": false, + "duration_ms": 16666 + }, + "stage_count": 5, + "gpu": 0 +} +``` + +**Error cases:** + +- `422 Unprocessable Entity` — validation failure (missing `prompt` + + `total_frames` in the auto-expand form, a stage with non-`8k+1` `frames`, + `motion_tail_frames >= clip_frames`, more than 16 stages, etc.). +- `422 Unprocessable Entity` — unsupported model family. Only LTX-2 distilled + engines expose a chain renderer; other families are rejected with an + error that names the constraint. +- `502 Bad Gateway` — a stage errored mid-chain. The whole chain is discarded + and nothing is written to the gallery; v1 is fail-closed and partial + resume is a v2 feature. + +::: tip Queue behaviour +The chain handler deliberately **bypasses the single-job queue**. 
A chain is a +multi-minute compound operation that would stall the FIFO queue for every +other request, so the handler takes the engine out of `ModelCache` for the +full chain duration and restores it on completion (or error). Chains +therefore run one-at-a-time on a given GPU; submit chains to separate GPUs +via `MOLD_GPUS` / `--gpus` if you need parallelism. +::: + +## `/api/generate/chain/stream` + +Same request body as `/api/generate/chain`, with the response delivered as +Server-Sent Events. Progress frames stream as `event: progress` and the +terminal frame is either `event: complete` (success) or `event: error` +(failure; the connection closes after the error frame). + +Progress event payloads map to `mold_core::chain::ChainProgressEvent` variants: + +```text +event: progress +data: {"type":"chain_start","stage_count":5,"estimated_total_frames":485} + +event: progress +data: {"type":"stage_start","stage_idx":0} + +event: progress +data: {"type":"denoise_step","stage_idx":0,"step":1,"total":8} + +event: progress +data: {"type":"stage_done","stage_idx":0,"frames_emitted":97} + +event: progress +data: {"type":"stitching","total_frames":385} + +event: complete +data: {"video":"","format":"mp4","width":1216,"height":704,"frames":400,"fps":24,"thumbnail":"","gif_preview":"","has_audio":false,"duration_ms":16666,"stage_count":5,"gpu":0,"generation_time_ms":226812} +``` + +The `complete` event payload maps to `mold_core::chain::SseChainCompleteEvent`. +Non-denoise engine events (weight loads, cache hits, etc.) are intentionally +not forwarded in v1 — the UX goal is per-stage progress, not per-component +telemetry. 
+ +```bash +curl -N -X POST http://localhost:7680/api/generate/chain/stream \ + -H "Content-Type: application/json" \ + -d '{ + "model": "ltx-2-19b-distilled:fp8", + "prompt": "a cat walking through autumn leaves", + "total_frames": 400, + "clip_frames": 97, + "motion_tail_frames": 4, + "width": 1216, "height": 704, "fps": 24, + "steps": 8, "guidance": 3.0, + "output_format": "mp4" + }' +``` + ## `/api/status` Example response: diff --git a/website/models/ltx2.md b/website/models/ltx2.md index 54d0b1fa..f1561b30 100644 --- a/website/models/ltx2.md +++ b/website/models/ltx2.md @@ -96,6 +96,87 @@ mold run ltx-2.3-22b-distilled:fp8 \ --format mp4 ``` +## Chained video output + +The LTX-2 distilled pipeline maxes out at 97 pixel frames per clip (13 latent +frames after the VAE's 8× temporal compression — `8 × 12 + 1 = 97` satisfies the +`8k+1` frame-grid constraint). For anything longer, mold renders a _chain_: the +request is split into N sub-clips, each generated back-to-back, and stitched +into a single MP4 at the end. mold keeps the last few frames of clip _N_'s +final latents in memory and threads them directly into clip _N+1_'s +conditioning, skipping a VAE encode/decode round-trip so the continuation +stays visually coherent. + +`mold run` routes automatically: when `--frames` is `≤ 97` you stay on the +single-clip path; above 97 the request is rewritten into a chain and dispatched +to the new `/api/generate/chain/stream` endpoint. Chaining is supported for +LTX-2 19B and 22B distilled today. Other model families reject +`--frames > 97` with an actionable error rather than silently over-producing. 
+ +```console +$ mold run ltx-2-19b-distilled:fp8 "a cat walking through autumn leaves" \ + --image cat.png --frames 400 + +→ Chain mode: 400 frames → 5 stages × 97 frames (tail 4) +Chain [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 385/385 frames (stages 5) + Stage 1 [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 8/8 steps + Stage 2 [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 8/8 steps + Stage 3 [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 8/8 steps + Stage 4 [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 8/8 steps + Stage 5 [━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━] 8/8 steps +✓ Saved: mold-ltx-2-19b-distilled-.mp4 (400 frames, 1216x704, 24 fps) +✓ Done — ltx-2-19b-distilled:fp8 in 226.8s (400 frames, seed: 42) +``` + +### Motion-tail carryover + +`--motion-tail N` (default 4) controls how many trailing pixel frames of each +clip are reused as latent-space conditioning for the next. Instead of decoding +the prior clip's last frame back to RGB and re-encoding it through the VAE as +a new `source_image`, mold narrows the final denoise tensor along its time +axis and patchifies those latent tokens directly into the next stage's +`StageVideoConditioning` — so the handoff never leaves latent space. At stitch +time, every stage after the first drops its leading `N` output frames because +those are the overlap region shared with the prior clip. + +- `--motion-tail 0` — hard concatenation, no overlap. Visible seams are common + at clip boundaries; useful when you _want_ discrete shots. +- `--motion-tail 4` — the default. One latent frame of carryover at `fps=24` + gives the transformer enough temporal context to continue motion, object + identity, and lighting across the seam without wasting new frames. +- Higher values buy more seam-smoothing at the cost of fewer fresh pixel + frames per clip. Must stay strictly below `--clip-frames`. 
+ +### Flags + +| Flag | Default | Description | +| ----------------- | ---------------- | ------------------------------------------------------------------------------------ | +| `--frames N` | model default | Total stitched length. Above the per-clip cap (97 for LTX-2 distilled), auto-chains. | +| `--clip-frames N` | model cap (`97`) | Per-clip length. Must be `8k+1`; values above the cap are clamped with a warning. | +| `--motion-tail N` | `4` | Pixel-frame overlap between clips. `0` disables carryover. | + +When the final clip over-produces (stage math rarely lands exactly on +`total_frames`), mold trims from the tail so the user-anchored starting image +at the head stays intact. + +### v1 constraints + +- **LTX-2 19B and 22B distilled only.** Other LTX-2 / LTX-Video variants and + every image-family model reject `--frames` above their single-clip budget. +- **Single GPU per chain.** Every stage runs on the GPU the engine was loaded + onto — multi-GPU stage fan-out is a v2 movie-maker feature. +- **Fail-closed.** If any stage errors, the whole chain returns `502` and + nothing is written to the gallery. There is no partial-resume in v1. +- **Single prompt per chain from the CLI.** The server already accepts + per-stage prompts (see [`POST /api/generate/chain`](/api/#api-generate-chain)), + but `mold run` replicates one prompt across every stage for now. + +The rest of the LTX-2 surface — `--image`, `--audio-file`, `--lora`, +`--camera-control`, `--spatial-upscale`, `--temporal-upscale`, and so on — +applies to chain renders the same way it applies to single-clip renders. An +`--image` supplied on the CLI lands on `stages[0]` and is carried forward by +the motion-tail latents from there. + ## Example Clips Here are a few longer LTX-2 examples rendered with mold. 
The docs page embeds From 766322ebbc84fbcdfce92867e582a27673a167a9 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 21:43:17 -0700 Subject: [PATCH 12/31] fix(cli): pass owned String to create_engine in local chain path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `cuda`/`metal` feature-gated local orchestrator branch in `run_chain_local` passed `&model_name` to `mold_inference::create_engine`, which takes `model_name: String`. Phase 3's verification only ran `cargo check --features preview,discord,expand,tui,webp,mp4` — the feature-matrix omitted `cuda`/`metal`, so CI and the local-default check both missed the mismatch. Caught at rebuild time on killswitch (sm_86 / RTX 3090 dual-GPU build). `cargo check -p mold-ai --features metal,expand` now clean locally. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-cli/src/commands/chain.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/mold-cli/src/commands/chain.rs b/crates/mold-cli/src/commands/chain.rs index a0988b72..4a8796bf 100644 --- a/crates/mold-cli/src/commands/chain.rs +++ b/crates/mold-cli/src/commands/chain.rs @@ -354,7 +354,7 @@ async fn run_chain_local( .unwrap_or(0); let mut engine = mold_inference::create_engine( - &model_name, + model_name, paths, &effective_config, load_strategy, From 41e85f7a2112bc2e2d86960e5ae37bfc823e70e1 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 22:04:21 -0700 Subject: [PATCH 13/31] fix(sd3): truncate CLIP token sequences to 77 with EOS preserved `ClipWithTokenizer::encode_text_to_embedding` padded up to `max_position_embeddings` (77) but never truncated down. Prompts that tokenised to more than 77 CLIP tokens fed an `[1, N, 768]` tensor into `ClipTextTransformer`, where the 77-slot position-embedding broadcast-add blew up with `shape mismatch in broadcast_add, lhs: [1, N, 768], rhs: [1, 77, 768]`. 
The pooled-output slice at `eos_position = tokens.len() - 1` was also out-of-bounds on the same path. Extract the token preparation into a pure `prepare_clip_tokens` helper that truncates to `max_len` (copying the trailing EOS token into the final slot so the pooled branch still reads an EOS-position hidden state) and then pads up to `max_len`. Wire it into both CLIP-L and CLIP-G via the shared `ClipWithTokenizer` path, so every `sd3*` model benefits. Unit-tested weight-free with four cases: short prompt, exact-77, 132-token overlong (matches the observed failure shape), and an empty tokenisation. All four pass; the 132-token test was red before the fix and is green after. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../mold-inference/src/encoders/sd3_clip.rs | 106 +++++++++++++++++- 1 file changed, 101 insertions(+), 5 deletions(-) diff --git a/crates/mold-inference/src/encoders/sd3_clip.rs b/crates/mold-inference/src/encoders/sd3_clip.rs index c396563f..ddb3a89d 100644 --- a/crates/mold-inference/src/encoders/sd3_clip.rs +++ b/crates/mold-inference/src/encoders/sd3_clip.rs @@ -83,18 +83,16 @@ impl ClipWithTokenizer { .ok_or_else(|| anyhow::anyhow!("Failed to tokenize CLIP end-of-text"))?, }; - let mut tokens = self + let raw_tokens = self .tokenizer .encode(prompt, true) .map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))? 
.get_ids() .to_vec(); - let eos_position = tokens.len() - 1; + let (tokens, eos_position) = + prepare_clip_tokens(raw_tokens, self.max_position_embeddings, pad_id); - while tokens.len() < self.max_position_embeddings { - tokens.push(pad_id); - } let tokens = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; let (_text_embeddings, text_embeddings_penultimate) = clip.forward_until_encoder_layer(&tokens, usize::MAX, -2)?; @@ -293,3 +291,101 @@ impl SD3TripleEncoder { self.clip_l.model.is_some() && self.clip_g.model.is_some() && self.t5.model.is_some() } } + +/// Prepare a CLIP token sequence for the fixed position-embedding window. +/// +/// CLIP's position-embedding table holds exactly `max_len` entries, so a token +/// tensor longer than that fails inside candle's `broadcast_add` when the +/// position embeddings are applied. This helper: +/// +/// - Truncates overlong sequences to `max_len`, copying the trailing token +/// (the tokenizer's EOS, assuming `add_special_tokens=true`) into the last +/// slot so the pooled-output path still reads an EOS-position hidden state. +/// - Pads short sequences up to `max_len` with `pad_id`. +/// - Returns the final `tokens` vector and the `eos_position` index the caller +/// uses to slice the pooled output. 
+fn prepare_clip_tokens(mut raw_tokens: Vec, max_len: usize, pad_id: u32) -> (Vec, usize) { + let original_len = raw_tokens.len(); + + if original_len > max_len { + let eos_id = *raw_tokens + .last() + .expect("original_len > max_len implies non-empty"); + raw_tokens.truncate(max_len); + if let Some(last) = raw_tokens.last_mut() { + *last = eos_id; + } + tracing::debug!( + "SD3 CLIP prompt exceeded {} tokens ({} raw); truncated with EOS preserved", + max_len, + original_len, + ); + } + + let eos_position = raw_tokens.len().saturating_sub(1); + + while raw_tokens.len() < max_len { + raw_tokens.push(pad_id); + } + + (raw_tokens, eos_position) +} + +#[cfg(test)] +mod tests { + use super::prepare_clip_tokens; + + const MAX_LEN: usize = 77; + const PAD_ID: u32 = 0; + const EOS_ID: u32 = 49407; + + #[test] + fn pads_short_prompt_to_max_len() { + let raw = vec![49406, 10, 20, 30, EOS_ID]; // 5 tokens, last is EOS + let (tokens, eos) = prepare_clip_tokens(raw, MAX_LEN, PAD_ID); + assert_eq!(tokens.len(), MAX_LEN, "must pad up to max_len"); + assert_eq!(eos, 4, "eos_position tracks the raw EOS slot"); + assert_eq!(tokens[4], EOS_ID, "EOS preserved at original position"); + assert_eq!(tokens[5], PAD_ID, "pads follow the real tokens"); + assert_eq!(*tokens.last().unwrap(), PAD_ID); + } + + #[test] + fn leaves_exact_length_untouched() { + let mut raw: Vec = (1..MAX_LEN as u32).collect(); + raw.push(EOS_ID); + assert_eq!(raw.len(), MAX_LEN); + let (tokens, eos) = prepare_clip_tokens(raw.clone(), MAX_LEN, PAD_ID); + assert_eq!(tokens.len(), MAX_LEN); + assert_eq!(eos, MAX_LEN - 1); + assert_eq!(tokens, raw); + } + + #[test] + fn truncates_overlong_prompt_preserving_eos() { + // 132-token sequence — matches the shapes in the original bug report + // ([1, 132, 768] vs [1, 77, 768]). 
+ let mut raw: Vec = (1..=131).collect(); + raw.push(EOS_ID); + assert_eq!(raw.len(), 132); + + let (tokens, eos) = prepare_clip_tokens(raw, MAX_LEN, PAD_ID); + + assert_eq!(tokens.len(), MAX_LEN, "overlong sequence must be truncated"); + assert_eq!(eos, MAX_LEN - 1, "eos_position must land on the last slot"); + assert_eq!( + tokens[MAX_LEN - 1], + EOS_ID, + "EOS must be preserved in the final slot so pooled output reads EOS hidden state", + ); + } + + #[test] + fn handles_empty_input() { + // Degenerate case: tokenizer somehow returns no ids. Shouldn't panic. + let (tokens, eos) = prepare_clip_tokens(Vec::new(), MAX_LEN, PAD_ID); + assert_eq!(tokens.len(), MAX_LEN); + assert_eq!(eos, 0); + assert!(tokens.iter().all(|t| *t == PAD_ID)); + } +} From adf1ff6f1e7a9eb10d89ae928094f7b1e50ad3f0 Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 22:59:40 -0700 Subject: [PATCH 14/31] fix(multi-gpu): stop LTX-2 and upscaler from nuking GPU 0's CUDA context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LTX-2 and the upscaler hardcoded `Device::new_cuda(0)` and `reclaim_gpu_memory(0)` in their engine bodies, ignoring the `gpu_ordinal` they were dispatched with. On a multi-GPU host that meant dispatching LTX-2 to GPU 1 still destroyed GPU 0's primary CUDA context mid-denoise, which surfaced as a misleading CUDA_ERROR_OUT_OF_MEMORY on the sibling job and then segfaulted inside `cuEventDestroy_v2` when candle's Drop chain unwound. - Thread `gpu_ordinal` through `Ltx2Engine` → `Ltx2RuntimeSession` and `UpscalerEngine` / `create_upscale_engine`; replace all four hardcoded-0 call sites. - Add a thread-local GPU binding (`init_thread_gpu_ordinal`) set by each GPU worker thread; `create_device` and `reclaim_gpu_memory` `debug_assert` the caller's ordinal matches, so any future hardcoded-0 regression panics in debug builds instead of silently corrupting a sibling GPU's context. 
- Update all 4 `create_upscale_engine` callers (CLI, TUI, two in server routes) to pass ordinal 0 explicitly. Server upscaler cache stays pinned to GPU 0 with a comment noting the per-worker cache migration path if multi-GPU upscale becomes interesting. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/mold-cli/src/commands/upscale.rs | 4 ++ crates/mold-inference/src/device.rs | 54 ++++++++++++++++++++ crates/mold-inference/src/factory.rs | 7 ++- crates/mold-inference/src/ltx2/execution.rs | 2 +- crates/mold-inference/src/ltx2/pipeline.rs | 35 ++++++++++--- crates/mold-inference/src/ltx2/runtime.rs | 28 +++++++--- crates/mold-inference/src/upscaler/engine.rs | 18 +++++-- crates/mold-server/src/gpu_worker.rs | 4 ++ crates/mold-server/src/routes.rs | 6 +++ crates/mold-tui/src/app.rs | 1 + 10 files changed, 140 insertions(+), 19 deletions(-) diff --git a/crates/mold-cli/src/commands/upscale.rs b/crates/mold-cli/src/commands/upscale.rs index 7776a45a..a8224924 100644 --- a/crates/mold-cli/src/commands/upscale.rs +++ b/crates/mold-cli/src/commands/upscale.rs @@ -161,10 +161,14 @@ async fn upscale_local( let req_clone = req.clone(); let resp = tokio::task::spawn_blocking(move || -> Result { + // CLI upscale runs locally on the best available GPU (ordinal 0 + // on single-GPU hosts). The multi-GPU server path routes through + // `gpu_worker`, which passes its own ordinal. 
let mut engine = mold_inference::create_upscale_engine( model_name_owned, weights_path, mold_inference::LoadStrategy::Sequential, + 0, )?; // Set up progress callback for stderr diff --git a/crates/mold-inference/src/device.rs b/crates/mold-inference/src/device.rs index a16ca50f..acc12a04 100644 --- a/crates/mold-inference/src/device.rs +++ b/crates/mold-inference/src/device.rs @@ -1,6 +1,57 @@ use crate::engine::LoadStrategy; use crate::progress::ProgressReporter; use mold_core::types::GpuSelection; +use std::cell::Cell; + +// ── Thread-local GPU ordinal guard ───────────────────────────────────────── +// +// Each GPU worker thread is pinned to a single ordinal. We stash that ordinal +// in a thread-local so cross-engine hotpaths (`create_device`, `reclaim_gpu_memory`) +// can debug-assert the caller isn't drifting onto a sibling GPU's context — +// the exact footgun that took the process down on killswitch when LTX-2 had +// `reclaim_gpu_memory(0)` hardcoded and nuked GPU 0's context while SD3.5 +// was still denoising there. +// +// Threads without a bound ordinal (tokio blocking pool, tests) see `None` +// and the assert is skipped. + +thread_local! { + static THREAD_GPU_ORDINAL: Cell> = const { Cell::new(None) }; +} + +/// Bind the current thread to a GPU ordinal. Call once from each GPU worker +/// thread's entry point. Any subsequent `create_device` / `reclaim_gpu_memory` +/// call on this thread must match `ordinal` (debug builds only). +pub fn init_thread_gpu_ordinal(ordinal: usize) { + THREAD_GPU_ORDINAL.with(|c| c.set(Some(ordinal))); +} + +/// Clear the thread's GPU binding. Not strictly needed in production (workers +/// run for the process lifetime) but useful for tests that reuse threads. +pub fn clear_thread_gpu_ordinal() { + THREAD_GPU_ORDINAL.with(|c| c.set(None)); +} + +/// Returns the currently-bound ordinal, if any. 
+pub fn thread_gpu_ordinal() -> Option { + THREAD_GPU_ORDINAL.with(|c| c.get()) +} + +/// Panic in debug builds if `ordinal` doesn't match the thread's bound GPU. +/// A mismatch means a call site is ignoring its engine's `gpu_ordinal` and +/// reaching for another GPU's context — the SD3.5/LTX-2 crash pattern. +#[inline] +fn debug_assert_ordinal_matches_thread(ordinal: usize, context: &'static str) { + if cfg!(debug_assertions) { + if let Some(expected) = thread_gpu_ordinal() { + assert_eq!( + expected, ordinal, + "{context}: ordinal {ordinal} does not match this thread's \ + bound GPU {expected} — hardcoded ordinal regression?" + ); + } + } +} // ── GPU discovery ────────────────────────────────────────────────────────── @@ -107,6 +158,7 @@ pub fn create_device( tracing::info!("CPU forced via MOLD_DEVICE=cpu"); return Ok(Device::Cpu); } + debug_assert_ordinal_matches_thread(ordinal, "create_device"); if candle_core::utils::cuda_is_available() { progress.info(&format!("Using CUDA device {ordinal}")); tracing::info!("Using CUDA device {ordinal}"); @@ -389,6 +441,8 @@ pub fn available_system_memory_bytes() -> Option { pub fn reclaim_gpu_memory(ordinal: usize) { use candle_core::cuda_backend::cudarc::driver::{result, sys}; + debug_assert_ordinal_matches_thread(ordinal, "reclaim_gpu_memory"); + // Synchronize to ensure all async GPU work completes before reset. 
let _ = result::ctx::synchronize(); diff --git a/crates/mold-inference/src/factory.rs b/crates/mold-inference/src/factory.rs index 4a82bfed..7b78df40 100644 --- a/crates/mold-inference/src/factory.rs +++ b/crates/mold-inference/src/factory.rs @@ -181,7 +181,12 @@ pub fn create_engine_with_pool( shared_pool, ))) } - "ltx2" | "ltx-2" => Ok(Box::new(Ltx2Engine::new(model_name, paths, load_strategy))), + "ltx2" | "ltx-2" => Ok(Box::new(Ltx2Engine::new( + model_name, + paths, + load_strategy, + gpu_ordinal, + ))), "wuerstchen" | "wuerstchen-v2" => Ok(Box::new(WuerstchenEngine::new( model_name, paths, diff --git a/crates/mold-inference/src/ltx2/execution.rs b/crates/mold-inference/src/ltx2/execution.rs index b0624a2a..a76318ac 100644 --- a/crates/mold-inference/src/ltx2/execution.rs +++ b/crates/mold-inference/src/ltx2/execution.rs @@ -268,7 +268,7 @@ mod tests { } fn engine(model_name: &str, paths: ModelPaths) -> Ltx2Engine { - Ltx2Engine::new(model_name.to_string(), paths, LoadStrategy::Sequential) + Ltx2Engine::new(model_name.to_string(), paths, LoadStrategy::Sequential, 0) } #[test] diff --git a/crates/mold-inference/src/ltx2/pipeline.rs b/crates/mold-inference/src/ltx2/pipeline.rs index 9c744d8b..871fa87d 100644 --- a/crates/mold-inference/src/ltx2/pipeline.rs +++ b/crates/mold-inference/src/ltx2/pipeline.rs @@ -34,6 +34,11 @@ pub struct Ltx2Engine { native_runtime: Option, on_progress: Option, pending_placement: Option, + /// GPU ordinal this engine is pinned to. Every `Device::new_cuda` and + /// `reclaim_gpu_memory` call must use this ordinal — hardcoding `0` here + /// is what took down the process on killswitch when LTX-2 ran alongside + /// SD3.5 on a multi-GPU host. 
+ gpu_ordinal: usize, } impl Ltx2Engine { @@ -54,7 +59,12 @@ impl Ltx2Engine { } } - pub fn new(model_name: String, paths: ModelPaths, _load_strategy: LoadStrategy) -> Self { + pub fn new( + model_name: String, + paths: ModelPaths, + _load_strategy: LoadStrategy, + gpu_ordinal: usize, + ) -> Self { Self { model_name, paths, @@ -62,6 +72,7 @@ impl Ltx2Engine { native_runtime: None, on_progress: None, pending_placement: None, + gpu_ordinal, } } @@ -78,6 +89,7 @@ impl Ltx2Engine { native_runtime: Some(runtime), on_progress: None, pending_placement: None, + gpu_ordinal: 0, } } @@ -220,7 +232,7 @@ impl Ltx2Engine { match backend { Ltx2Backend::Cuda => { self.info("CUDA detected, using native LTX-2 GPU path"); - let device = Device::new_cuda(0)?; + let device = Device::new_cuda(self.gpu_ordinal)?; configure_native_ltx2_cuda_device(&device)?; Ok(device) } @@ -261,9 +273,16 @@ impl Ltx2Engine { )?; Self::log_timing("pipeline.create_runtime.load_prompt_encoder", load_start); if prompt_device.is_cuda() { - Ok(Ltx2RuntimeSession::new_deferred_cuda(prompt_encoder)) + Ok(Ltx2RuntimeSession::new_deferred_cuda( + prompt_encoder, + self.gpu_ordinal, + )) } else { - Ok(Ltx2RuntimeSession::new(device, prompt_encoder)) + Ok(Ltx2RuntimeSession::new( + device, + prompt_encoder, + self.gpu_ordinal, + )) } } @@ -294,7 +313,7 @@ impl Ltx2Engine { self.info( "Native LTX-2 prompt path ran out of CUDA memory; retrying with CPU fallback", ); - crate::device::reclaim_gpu_memory(0); + crate::device::reclaim_gpu_memory(self.gpu_ordinal); self.load_runtime_session_on_device(plan, Device::Cpu) } Err(err) => Err(err), @@ -1003,7 +1022,7 @@ mod tests { .unwrap(), PaddingSide::Left, ); - Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder) + Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder, 0) } fn request(output_format: OutputFormat, enable_audio: Option) -> GenerateRequest { @@ -1053,6 +1072,7 @@ mod tests { "ltx-2.3-22b-distilled:fp8".to_string(), dummy_paths(), LoadStrategy::Sequential, + 0, 
); let req = GenerateRequest { prompt: "test".to_string(), @@ -1114,6 +1134,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), dummy_paths(), LoadStrategy::Sequential, + 0, ); assert_eq!(engine.request_quantization(), Some("fp8-cast".to_string())); } @@ -1133,6 +1154,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), dummy_paths_with_gemma_root(gemma_dir.path()), LoadStrategy::Sequential, + 0, ); let req = GenerateRequest { prompt: "test".to_string(), @@ -1200,6 +1222,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), paths, LoadStrategy::Sequential, + 0, ); engine.load().unwrap(); diff --git a/crates/mold-inference/src/ltx2/runtime.rs b/crates/mold-inference/src/ltx2/runtime.rs index 5fc0a62b..221087b9 100644 --- a/crates/mold-inference/src/ltx2/runtime.rs +++ b/crates/mold-inference/src/ltx2/runtime.rs @@ -296,22 +296,34 @@ pub struct Ltx2RuntimeSession { /// final latents and forward them to the next chain stage as a /// [`super::chain::ChainTail`]. `None` outside chain flow. pub(crate) tail_capture: Option>>>, + /// GPU ordinal inherited from `Ltx2Engine`. Used for the deferred CUDA + /// device creation in `prepare()` and for post-OOM context reset. 
+ gpu_ordinal: usize, } impl Ltx2RuntimeSession { - pub fn new(device: candle_core::Device, prompt_encoder: NativePromptEncoder) -> Self { + pub fn new( + device: candle_core::Device, + prompt_encoder: NativePromptEncoder, + gpu_ordinal: usize, + ) -> Self { Self { device: Some(device), prompt_encoder: Some(prompt_encoder), tail_capture: None, + gpu_ordinal, } } - pub fn new_deferred_cuda(prompt_encoder: NativePromptEncoder) -> Self { + pub fn new_deferred_cuda( + prompt_encoder: NativePromptEncoder, + gpu_ordinal: usize, + ) -> Self { Self { device: None, prompt_encoder: Some(prompt_encoder), tail_capture: None, + gpu_ordinal, } } @@ -414,8 +426,8 @@ impl Ltx2RuntimeSession { let device_handoff_start = Instant::now(); if prompt_device_is_cuda { if self.device.is_none() { - crate::device::reclaim_gpu_memory(0); - self.device = Some(new_native_cuda_device()?); + crate::device::reclaim_gpu_memory(self.gpu_ordinal); + self.device = Some(new_native_cuda_device(self.gpu_ordinal)?); } else if let Some(device) = self.device.as_ref() { if device.is_cuda() { device.synchronize()?; @@ -864,8 +876,8 @@ fn overlay_alpha(overlay: &ConditioningOverlay, frame_idx: u32, total_frames: u3 } #[cfg(feature = "cuda")] -fn new_native_cuda_device() -> Result { - let device = candle_core::Device::new_cuda(0)?; +fn new_native_cuda_device(ordinal: usize) -> Result { + let device = candle_core::Device::new_cuda(ordinal)?; let cuda = device.as_cuda_device()?; if cuda.is_event_tracking() { unsafe { @@ -876,7 +888,7 @@ fn new_native_cuda_device() -> Result { } #[cfg(not(feature = "cuda"))] -fn new_native_cuda_device() -> Result { +fn new_native_cuda_device(_ordinal: usize) -> Result { anyhow::bail!("CUDA backend is unavailable in this build") } @@ -4990,7 +5002,7 @@ mod tests { .unwrap(), PaddingSide::Left, ); - Ltx2RuntimeSession::new(candle_core::Device::Cpu, prompt_encoder) + Ltx2RuntimeSession::new(candle_core::Device::Cpu, prompt_encoder, 0) } fn build_plan( diff --git 
a/crates/mold-inference/src/upscaler/engine.rs b/crates/mold-inference/src/upscaler/engine.rs index 89590da3..eb741180 100644 --- a/crates/mold-inference/src/upscaler/engine.rs +++ b/crates/mold-inference/src/upscaler/engine.rs @@ -75,16 +75,26 @@ pub struct UpscalerEngine { loaded: Option, progress: ProgressReporter, load_strategy: LoadStrategy, + /// GPU ordinal this engine is pinned to. Same multi-GPU footgun as + /// `Ltx2Engine::gpu_ordinal` — hardcoding `0` would corrupt a sibling + /// GPU's CUDA context on unload. + gpu_ordinal: usize, } impl UpscalerEngine { - pub fn new(name: String, weights_path: PathBuf, load_strategy: LoadStrategy) -> Self { + pub fn new( + name: String, + weights_path: PathBuf, + load_strategy: LoadStrategy, + gpu_ordinal: usize, + ) -> Self { Self { name, weights_path, loaded: None, progress: ProgressReporter::default(), load_strategy, + gpu_ordinal, } } @@ -240,7 +250,7 @@ impl UpscaleEngine for UpscalerEngine { let load_start = Instant::now(); self.progress.stage_start("Loading upscaler model"); - let device = create_device(0, &self.progress)?; + let device = create_device(self.gpu_ordinal, &self.progress)?; // Determine dtype: prefer F16 on GPU, F32 on CPU let dtype = if matches!(device, Device::Cpu) { @@ -317,7 +327,7 @@ impl UpscaleEngine for UpscalerEngine { fn unload(&mut self) { if self.loaded.is_some() { self.loaded = None; - crate::reclaim_gpu_memory(0); + crate::reclaim_gpu_memory(self.gpu_ordinal); tracing::info!("Upscaler model unloaded: {}", self.name); } } @@ -340,6 +350,7 @@ pub fn create_upscale_engine( model_name: String, weights_path: PathBuf, load_strategy: LoadStrategy, + gpu_ordinal: usize, ) -> Result> { if !weights_path.exists() { bail!("upscaler weights not found: {}", weights_path.display()); @@ -348,5 +359,6 @@ pub fn create_upscale_engine( model_name, weights_path, load_strategy, + gpu_ordinal, ))) } diff --git a/crates/mold-server/src/gpu_worker.rs b/crates/mold-server/src/gpu_worker.rs index 
a2a5fe72..104f8d38 100644 --- a/crates/mold-server/src/gpu_worker.rs +++ b/crates/mold-server/src/gpu_worker.rs @@ -21,6 +21,10 @@ pub fn spawn_gpu_thread( std::thread::Builder::new() .name(format!("gpu-worker-{}", worker.gpu.ordinal)) .spawn(move || { + // Bind this thread to its GPU ordinal so `create_device` / + // `reclaim_gpu_memory` can debug-assert callers don't drift onto + // a sibling GPU's context. See device::init_thread_gpu_ordinal. + mold_inference::device::init_thread_gpu_ordinal(worker.gpu.ordinal); tracing::info!( gpu = worker.gpu.ordinal, name = %worker.gpu.name, diff --git a/crates/mold-server/src/routes.rs b/crates/mold-server/src/routes.rs index 2d6ffc3c..ab971832 100644 --- a/crates/mold-server/src/routes.rs +++ b/crates/mold-server/src/routes.rs @@ -656,10 +656,15 @@ async fn upscale( .as_ref() .is_none_or(|e| e.model_name() != model_name_owned); if needs_new { + // Server-side upscaler cache is process-global today and + // intentionally pinned to GPU 0 (matches prior behavior). + // If multi-GPU upscale becomes interesting, migrate this to + // a per-worker cache on `GpuWorker` and route via the pool. 
let new_engine = mold_inference::create_upscale_engine( model_name_owned, weights_path, mold_inference::LoadStrategy::Eager, + 0, )?; *cache = Some(new_engine); } @@ -824,6 +829,7 @@ async fn upscale_stream( model_name_owned, weights_path, mold_inference::LoadStrategy::Eager, + 0, ) { Ok(new_engine) => { *cache = Some(new_engine); diff --git a/crates/mold-tui/src/app.rs b/crates/mold-tui/src/app.rs index 0ec7661f..16659f9e 100644 --- a/crates/mold-tui/src/app.rs +++ b/crates/mold-tui/src/app.rs @@ -1452,6 +1452,7 @@ impl App { model_name_local.clone(), weights_path, mold_inference::LoadStrategy::Eager, + 0, )?; engine.set_on_progress(Box::new(move |event| { From 24437ee1a2034854b5a8dffc379a8bd61dc5105c Mon Sep 17 00:00:00 2001 From: Jeffrey Dilley Date: Mon, 20 Apr 2026 23:34:57 -0700 Subject: [PATCH 15/31] feat(web): hide-mode toggle + multi-select delete in gallery Adds two bulk-UX affordances to the web gallery SPA. Hide-mode toggle blurs every tile behind a dark shroud with a per-tile "Reveal" for single peeks; the global preference persists in localStorage, peeks don't. Select mode enables click-to-toggle, shift-click range, and drag-marquee selection with a floating action bar for Select all / Clear / Delete selected / Delete all. Bulk deletes parallelize via Promise.allSettled and partial failures surface a rollup. Select button is gated on capabilities.gallery.can_delete so servers without MOLD_GALLERY_ALLOW_DELETE=1 don't expose dead UI. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- web/src/components/GalleryCard.vue | 133 ++++++++++++++- web/src/components/GalleryFeed.vue | 179 +++++++++++++++++++- web/src/components/TopBar.vue | 121 ++++++++++++- web/src/pages/GalleryPage.vue | 263 +++++++++++++++++++++++++++++ 4 files changed, 674 insertions(+), 22 deletions(-) diff --git a/web/src/components/GalleryCard.vue b/web/src/components/GalleryCard.vue index 3ce44c42..355fc4ea 100644 --- a/web/src/components/GalleryCard.vue +++ b/web/src/components/GalleryCard.vue @@ -20,14 +20,36 @@ const props = withDefaults( // the header toggle, subsequent videos entering the viewport pick up // the preference automatically. muted?: boolean; + // Multi-select state. When `selectMode` is true, clicks toggle the + // selection instead of opening the detail drawer. + selectMode?: boolean; + selected?: boolean; + // Hide mode renders a blurred overlay over the media until the user + // clicks the reveal button (per-item) or flips the global toggle. 
+ hideMode?: boolean; + revealed?: boolean; }>(), - { variant: "grid", muted: true }, + { + variant: "grid", + muted: true, + selectMode: false, + selected: false, + hideMode: false, + revealed: false, + }, ); const emit = defineEmits<{ (e: "open", item: GalleryImage): void; + ( + e: "toggle-select", + payload: { item: GalleryImage; shift: boolean; meta: boolean }, + ): void; + (e: "reveal", item: GalleryImage): void; }>(); +const isHidden = computed(() => props.hideMode && !props.revealed); + /* * Lifecycle * --------- @@ -129,26 +151,64 @@ function onVideoError() { stage.value = "broken"; } -function openDetail() { +function onCardClick(evt: MouseEvent) { + if (props.selectMode) { + emit("toggle-select", { + item: props.item, + shift: evt.shiftKey, + meta: evt.metaKey || evt.ctrlKey, + }); + return; + } emit("open", props.item); } + +function onCardKey(evt: KeyboardEvent) { + if (props.selectMode) { + emit("toggle-select", { + item: props.item, + shift: evt.shiftKey, + meta: evt.metaKey || evt.ctrlKey, + }); + return; + } + emit("open", props.item); +} + +function onReveal(evt: Event) { + evt.stopPropagation(); + emit("reveal", props.item); +}