diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..441dc758 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,10 @@ +[alias] +build-dev = "build --profile dev-fast" +run-dev = "run --profile dev-fast" + +[build] +incremental = true + +[target.x86_64-unknown-linux-gnu] +linker = "clang" +rustflags = ["-C", "link-arg=-fuse-ld=lld"] diff --git a/.claude/skills/mold/SKILL.md b/.claude/skills/mold/SKILL.md index 4c8ddeee..97f28b29 100644 --- a/.claude/skills/mold/SKILL.md +++ b/.claude/skills/mold/SKILL.md @@ -178,7 +178,9 @@ mold run ltx-2-19b-distilled:fp8 "lantern-lit cave entrance" --camera-control do **Models:** `ltx-2-19b-dev:fp8`, `ltx-2-19b-distilled:fp8`, `ltx-2.3-22b-dev:fp8`, `ltx-2.3-22b-distilled:fp8` -**Important flags:** `--audio`, `--no-audio`, `--audio-file`, `--video`, repeatable `--keyframe`, repeatable `--lora`, `--pipeline`, `--retake`, `--camera-control`, `--spatial-upscale`, `--temporal-upscale` +**Important flags:** `--audio`, `--no-audio`, `--audio-file`, `--video`, repeatable `--keyframe`, repeatable `--lora`, `--pipeline`, `--retake`, `--camera-control`, `--spatial-upscale`, `--temporal-upscale`, `--clip-frames`, `--motion-tail` + +**Chained (arbitrary-length) video output:** for LTX-2 19B and 22B distilled models, `--frames` above the 97-frame per-clip cap automatically renders multiple clips with a motion-tail of latents carried across each clip boundary, then stitches them into a single MP4. The CLI picks this path transparently — `mold run ltx-2-19b-distilled:fp8 "a cat walking" --frames 400` produces one 400-frame MP4 from 5 chained stages. Advanced callers can override the per-clip length via `--clip-frames N` (must be `8k+1`, clamped to the model cap) and the overlap via `--motion-tail N` (default 4 pixel frames, 0 disables carryover). Chains fail closed on mid-stage failure (no partial output) and run on a single GPU. Other model families reject `--frames > 97` with an actionable error. 
**Current constraints:** `x2` spatial upscaling is wired across the family, `x1.5` spatial upscaling is wired for `ltx-2.3-*`, and `x2` temporal upscaling is wired in the native runtime. Camera-control preset aliases currently auto-resolve the published LTX-2 19B LoRAs only. The family runs through the native Rust stack in `mold-inference`, with CUDA as the supported backend for real local generation, CPU as a correctness-only fallback, and Metal unsupported. On 24 GB Ada GPUs such as the RTX 4090, the validated path stays on the compatible `fp8-cast` mode rather than Hopper-only `fp8-scaled-mm`. The native CUDA matrix is validated across 19B/22B text+audio-video, image-to-video, audio-to-video, keyframe, retake, public IC-LoRA, spatial upscale (`x1.5` / `x2` where published), and temporal upscale (`x2`). When requests go through `mold serve`, the built-in body limit is `64 MiB`, which is enough for common inline source-video and source-audio workflows. @@ -535,6 +537,20 @@ MOLD_HOST=http://gpu-host:7680 mold run "a cat" MOLD_OUTPUT_DIR=/srv/mold/output mold serve ``` +### HTTP API Endpoints + +Core endpoints exposed by `mold serve` (full list + schemas at `/api/docs`): + +- `POST /api/generate` — image/video generation, raw bytes response +- `POST /api/generate/stream` — SSE progress + base64 complete event +- `POST /api/generate/chain` — chained arbitrary-length video (LTX-2 distilled); body is `mold_core::chain::ChainRequest` (canonical `stages[]` or auto-expand `prompt`+`total_frames`+`clip_frames`) +- `POST /api/generate/chain/stream` — same as above, SSE progress with per-stage `denoise_step` events +- `POST /api/expand` — LLM prompt expansion +- `GET /api/models` · `POST /api/models/load` · `POST /api/models/pull` · `DELETE /api/models/unload` +- `GET /api/gallery` · `GET /api/gallery/image/:name` · `GET /api/gallery/thumbnail/:name` · `DELETE /api/gallery/image/:name` +- `POST /api/upscale` · `POST /api/upscale/stream` +- `GET /api/status` · `GET /health` · 
`GET /api/capabilities` + ### Prometheus Metrics When built with the `metrics` feature flag (included in Docker images and Nix builds), the server exposes a `GET /metrics` endpoint in Prometheus text exposition format. This endpoint is excluded from auth and rate limiting for monitoring scrapers. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 02a51666..d013dafe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,6 +8,7 @@ on: env: CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: "1" # Note: CUDA builds are not run in CI — they require a GPU host with NixOS + CUDA. # CI only checks non-CUDA compilation, lints, and tests. @@ -56,60 +57,76 @@ jobs: working-directory: web run: bun run build - fmt: + rust: runs-on: ubuntu-latest + env: + RUSTC_WRAPPER: sccache + SCCACHE_GHA_ENABLED: "true" steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: - components: rustfmt - - run: cargo fmt --all -- --check - - check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - uses: dtolnay/rust-toolchain@stable + components: rustfmt,clippy + - name: Install Rust build deps + run: sudo apt-get update && sudo apt-get install -y clang lld nasm libwebp-dev + - uses: mozilla-actions/sccache-action@v0.0.7 + - name: Probe sccache (disable on cache outage) + shell: bash + run: | + set +e + printf 'fn main() {}\n' > /tmp/sccache_probe.rs + sccache rustc --edition 2021 -- /tmp/sccache_probe.rs -o /tmp/sccache_probe + status=$? 
+ rm -f /tmp/sccache_probe /tmp/sccache_probe.rs + if [ $status -ne 0 ]; then + echo "sccache probe failed — disabling RUSTC_WRAPPER for this job" + echo "RUSTC_WRAPPER=" >> "$GITHUB_ENV" + sccache --stop-server >/dev/null 2>&1 || true + fi - uses: Swatinem/rust-cache@v2 with: shared-key: workspace-default save-if: ${{ github.ref == 'refs/heads/main' }} + - name: Format + run: cargo fmt --all -- --check - name: Check run: cargo check --workspace - - clippy: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - uses: dtolnay/rust-toolchain@stable - with: - components: clippy - - uses: Swatinem/rust-cache@v2 - with: - shared-key: workspace-default - save-if: ${{ github.ref == 'refs/heads/main' }} - name: Clippy - run: cargo clippy --workspace -- -D warnings - - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 - with: - shared-key: workspace-default - save-if: ${{ github.ref == 'refs/heads/main' }} + run: cargo clippy --workspace --all-targets -- -D warnings - name: Test run: cargo test --workspace + - name: Check with all optional features + run: cargo check -p mold-ai --features preview,discord,expand,tui,webp,mp4 coverage: runs-on: ubuntu-latest + env: + RUSTC_WRAPPER: sccache + SCCACHE_GHA_ENABLED: "true" steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable + - name: Install Rust build deps + run: sudo apt-get update && sudo apt-get install -y clang lld nasm libwebp-dev + + - uses: mozilla-actions/sccache-action@v0.0.7 + + - name: Probe sccache (disable on cache outage) + shell: bash + run: | + set +e + printf 'fn main() {}\n' > /tmp/sccache_probe.rs + sccache rustc --edition 2021 -- /tmp/sccache_probe.rs -o /tmp/sccache_probe + status=$? 
+ rm -f /tmp/sccache_probe /tmp/sccache_probe.rs + if [ $status -ne 0 ]; then + echo "sccache probe failed — disabling RUSTC_WRAPPER for this job" + echo "RUSTC_WRAPPER=" >> "$GITHUB_ENV" + sccache --stop-server >/dev/null 2>&1 || true + fi + - uses: Swatinem/rust-cache@v2 with: shared-key: workspace-llvm-cov @@ -127,17 +144,3 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - - check-features: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 - with: - shared-key: workspace-all-features - save-if: ${{ github.ref == 'refs/heads/main' }} - - name: Install system deps for optional features - run: sudo apt-get update && sudo apt-get install -y nasm libwebp-dev - - name: Check with all features (preview, discord, expand, tui, webp, mp4) - run: cargo check -p mold-ai --features preview,discord,expand,tui,webp,mp4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e4831e94..09edd513 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,11 +12,17 @@ concurrency: jobs: build-macos: runs-on: macos-14 + env: + CARGO_INCREMENTAL: "1" + RUSTC_WRAPPER: sccache + SCCACHE_GHA_ENABLED: "true" steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable + - uses: mozilla-actions/sccache-action@v0.0.7 + - uses: Swatinem/rust-cache@v2 with: key: macos-release @@ -39,11 +45,15 @@ jobs: # Ada Lovelace (RTX 40-series, sm_89) build-linux-sm89: runs-on: ubuntu-latest + env: + CARGO_INCREMENTAL: "1" + RUSTC_WRAPPER: sccache + SCCACHE_GHA_ENABLED: "true" steps: - uses: actions/checkout@v6 - - name: Install nasm (required by openh264 source build for mp4 feature) - run: sudo apt-get update && sudo apt-get install -y nasm + - name: Install build deps + run: sudo apt-get update && sudo apt-get install -y clang lld nasm - uses: Jimver/cuda-toolkit@v0.2.35 with: @@ -53,6 +63,8 @@ jobs: - uses: 
dtolnay/rust-toolchain@stable + - uses: mozilla-actions/sccache-action@v0.0.7 + - uses: Swatinem/rust-cache@v2 with: key: linux-cuda-sm89-release @@ -75,11 +87,15 @@ jobs: # Blackwell (RTX 50-series, sm_120) build-linux-sm120: runs-on: ubuntu-latest + env: + CARGO_INCREMENTAL: "1" + RUSTC_WRAPPER: sccache + SCCACHE_GHA_ENABLED: "true" steps: - uses: actions/checkout@v6 - - name: Install nasm (required by openh264 source build for mp4 feature) - run: sudo apt-get update && sudo apt-get install -y nasm + - name: Install build deps + run: sudo apt-get update && sudo apt-get install -y clang lld nasm - uses: Jimver/cuda-toolkit@v0.2.35 with: @@ -89,6 +105,8 @@ jobs: - uses: dtolnay/rust-toolchain@stable + - uses: mozilla-actions/sccache-action@v0.0.7 + - uses: Swatinem/rust-cache@v2 with: key: linux-cuda-sm120-release diff --git a/.gitignore b/.gitignore index f0f1b13c..ac06ff17 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ # Claude Code .claude/worktrees/ .claude/scheduled_tasks.lock +.worktrees/ .playwright-mcp/ .direnv/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 51e20f6e..caec3a3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **Fast local build profile and web bundle helper**: the workspace now defines a `dev-fast` Cargo profile (`thin` LTO, `codegen-units = 16`, incremental on, debuginfo retained) plus `scripts/ensure-web-dist.sh`, which only rebuilds `web/dist` when the SPA inputs changed. The Nix devshell's default `build`, `build-server`, `mold`, `serve`, and `generate` commands now use that profile and embed the real web gallery by default instead of falling back to the placeholder stub. 
+ +### Changed + +- **LTX-2 chain continuations hold identity far better across multi-clip videos** by threading the starting image into every stage as a soft long-range anchor, re-encoding the motion tail from decoded pixels rather than raw latents, and bumping the default motion-tail window from 9 → 17 pixel frames. Previously the first clip looked on-model but everything after it drifted into "really strange" territory because (1) `build_auto_expand_stages` + `build_stage_generate_request` dropped the `source_image` on every stage past index 0, so the model had no persistent identity reference — each continuation was anchored only to the drifted last frame of the prior clip and errors compounded stage-over-stage; (2) the carryover was a narrowed slice of the emitting stage's final latent tensor, where the last two latent slots encoded pixel frames near the *end* of stage N but got pinned at the *start* of stage N+1 at the receiving clip's RoPE positions, so the LTX-2 VAE's causal-first-slot / 8-pixel-continuation-slot convention was misaligned with the pinned content (`runtime.rs`'s `causal_first_frame_rgb` splice only patched the one causal slot, leaving a backward-pointing jump at slot 1); and (3) only ≈0.4 s of pixel context was carried across the boundary, which is nowhere near enough for the denoiser to reconstruct scene / lighting / subject without help. 
Fixed by (a) propagating `source_image` to every `ChainStage` in `crates/mold-core/src/chain.rs` and through `build_stage_generate_request` in `crates/mold-inference/src/ltx2/chain.rs`, with `Ltx2Engine::render_chain_stage` re-routing the staged image into the `VideoTokenAppendCondition` append path at a non-zero frame (`motion_tail_pixel_frames`) with soft `CHAIN_SOFT_ANCHOR_STRENGTH = 0.4` — the frame-0 replacement slot stays owned by the motion-tail pin, so the append tokens act as a durable cross-attention anchor for appearance without freezing any pixels; (b) replacing `ChainTail::{latents, last_rgb_frame}` and `StagedLatent::{latents, causal_first_frame_rgb}` with a single `tail_rgb_frames: Vec`, which the receiving `maybe_load_stage_video_conditioning` VAE-encodes fresh into proper-slot-semantics latents on the receiving clip's own time axis (slot 0 = causal 1 pixel, slots 1+ = 8-pixel continuation, all monotonic forward-in-time with correct RoPE); and (c) bumping `DEFAULT_MOTION_TAIL` in `web/src/lib/chainRouting.ts` and the `--motion-tail` CLI default in `crates/mold-cli/src/main.rs` from 9 to 17 pixel frames (three latent frames: causal + two continuation, ≈0.7 s at 24 fps) for enough hard-pinned context that the first free latent frame has plausible neighbours to reason against. The pre-VAE-decode `tail_capture` mutex plumbing in `Ltx2RuntimeSession` is kept (marked `#[allow(dead_code)]`) for future quality-diagnostic tooling; the production chain path no longer arms it. Existing chain orchestrator + `chainRouting` + `chain_client` tests continue to pass; `normalise_preserves_first_stage_image` is renamed to `normalise_preserves_starting_image_across_all_stages` and flipped to assert every stage carries the image, and `chain_only_stage0_carries_source_image` is renamed to `chain_propagates_source_image_to_every_stage` for the same reason. 
+- **Nix/crane build caching now matches the release feature set**: `flake.nix` feeds the same `mold-ai` feature list into `craneLib.buildDepsOnly` that the packaged binary uses, so dependency artifacts are reused instead of being recompiled in the final package build. Local devshell defaults also stop compiling the full optional feature set unless explicitly requested, while `build-release` and `build-ltx2` still produce the all-features binary. +- **Rust builds now default to `sccache` and a faster Linux linker path in supported build environments**: the repo ships `.cargo/config.toml` aliases for `dev-fast`, Linux builds link via `clang` + `lld`, the devshell includes `sccache`, and the CI/release workflows install the matching linker/tooling. +- **CI now reuses one warmed Rust target dir instead of rebuilding in three separate jobs**: `fmt`, `check`, `clippy`, `test`, and the all-features `cargo check` now run in a single `rust` job with `sccache`, reducing repeated workspace compilation while keeping the existing `coverage`, `docs`, and `web` jobs separate. + +- **Web UI auto-promotes long LTX-2 distilled video requests to the chain endpoint** so the SPA can render arbitrary-length clips without the user manually hitting `mold run` on the CLI. Previously, requesting `frames > 97` (the LTX-2 19B/22B distilled per-clip cap) from the in-browser generate composer POSTed straight to `/api/generate/stream`, the engine dutifully tried to render the full request in one pass, transformer denoise fit fine (~10 GB residual after the Gemma 3 text encoder dropped), and then VAE decode of the full-length latent stack exceeded the 24 GB 3090 budget and `CUDA_ERROR_OUT_OF_MEMORY` killed the job minutes from the end — repeatable three times in a row on a 241-frame 512×512 img2v request before the symptom was traced back to the missing client-side routing. 
A new pure `decideChainRouting` helper in `web/src/lib/chainRouting.ts` mirrors the CLI's `decide_chain_routing` (`crates/mold-cli/src/commands/chain.rs`): it checks the selected model's family, promotes to the chain endpoint when `frames > LTX2_DISTILLED_CLIP_CAP` for an `ltx2`-family distilled model, rejects cleanly for non-chainable video families that exceed the per-clip budget, and stays `single` otherwise. `useGenerateStream.submit` now accepts the decision, dispatches to either `/api/generate/stream` or `/api/generate/chain/stream`, and folds chain-shaped `ChainProgressEvent`s into the existing `JobProgress` so `RunningJobCard` renders a familiar "Denoising clip K/N · step X/Y" readout without any per-event UI changes. Completion uses a shape-shifter (`chainCompleteToSingle`) to transform `SseChainCompleteEvent` into the canonical `SseCompleteEvent` the gallery detail drawer expects — the `seed_used` fallback to `req.seed ?? 0` loses the auto-open-on-complete affordance for chain runs (chains derive per-stage seeds from the base seed, so there isn't a single seed to match against gallery metadata) but the gallery still refreshes and the new video surfaces at the top. The Composer shows a brand-tinted pill ("Will render as N chained clips of 97 frames (motion-tail 4) — expect this to take substantially longer than a single clip") whenever the routing decision resolves to `chain`, and a red error pill when it resolves to `reject`; `onSubmit` also hard-blocks rejected requests with an `alert()` before building the wire body. All eight routing helper branches and the existing 75 generate-form tests pass (`bun run test`). 
+ ### Fixed +- **Native LTX-2 chain continuations with a staged source-image anchor no longer hit `shape mismatch in broadcast_mul` mid-denoise.** Continuation stages re-route the carried `source_image` into a soft appended anchor (`CHAIN_SOFT_ANCHOR_STRENGTH = 0.4`) so identity stays present without freezing frame-0 tokens, but `reapply_stage_video_conditioning()` in `crates/mold-inference/src/ltx2/runtime.rs` was incorrectly dropping appended conditioning tokens unless `strength >= 1.0`. That shrank the live video-token sequence after the first sampler step while the cached clean-latent tensor and denoise mask still reflected the longer conditioned shape, producing errors like `lhs: [1, 2717, 4096], rhs: [1, 2926, 4096]` on the next `blend_conditioned_denoised()` call. The reapply path now keeps appended tokens present for the full denoise loop regardless of strength; strength continues to be enforced by the denoise mask, and a regression test covers the soft-anchor case. +- **Chained LTX-2 video generation no longer errors on stage 2 with "native LTX-2 prompt encoder is unavailable".** `Ltx2RuntimeSession::prepare` in `crates/mold-inference/src/ltx2/runtime.rs` intentionally consumes the prompt encoder on first call via `self.prompt_encoder.take()` so the Gemma 3 encoder's ~10 GB of VRAM can be freed for the transformer (same drop-and-reload pattern that FLUX uses for T5 and Z-Image uses for Qwen3). That invariant is fine for single-clip generation because the pipeline drops the whole runtime session after each call (`generate_inner` takes but never restores `native_runtime`), so the next request builds a fresh session with a fresh encoder. But `render_chain_stage` *does* restore the session so the transformer and VAE stay warm across stages — and stage 2+ then hit `.context("native LTX-2 prompt encoder is unavailable")?` because the encoder slot is already empty. Three back-to-back failures on a 241-frame img2v chain before it was traced. 
Fixed by caching the last `NativePromptEncoding` output on the session (keyed by the plan's full `EncodedPromptPair` plus the unconditional flag), so same-prompt follow-ups skip the encoder entirely; chain v1 replicates a single prompt across all stages so the cache is always warm from stage 1 onward. A new `Ltx2RuntimeSession::can_reuse_for(&plan)` helper lets `Ltx2Engine` detect when a persisted session carries a consumed encoder *and* a different prompt — in that case the engine drops it and builds a fresh session, which is the only way to re-encode from scratch. The existing `runtime_session_prepare_consumes_prompt_encoder` test is updated to reflect the new semantic (same-prompt second call succeeds via cache), and a new `runtime_session_prepare_rejects_encoder_reuse_with_different_prompt` test locks in the fresh-session-required branch so future refactors can't silently regress. +- **Chained LTX-2 generation now serializes concurrent chain requests instead of racing on the model cache.** `routes_chain::run_chain` in `crates/mold-server/src/routes_chain.rs` deliberately bypasses the normal generation queue and holds the engine out of `model_cache` for the full multi-minute chain run — a design tradeoff documented in CLAUDE.md to keep the transformer warm across stages without blocking single-clip requests. But that design had no serialization across *chain* requests: a second chain arriving while the first was running called `ensure_model_ready` → saw the cache empty (request A had taken the engine) → tried to `create_and_load_engine` a second copy → then request B's `cache.take()` surfaced the cryptic "engine '…' vanished from cache after ensure_model_ready" error. Manifested the moment the web UI auto-promoted multiple long-frame LTX-2 requests to chain mode. Fixed by adding `chain_lock: Arc<Mutex<()>>` to `AppState` and acquiring it at the top of `run_chain` before `ensure_model_ready`. 
The lock is held for the entire chain, so concurrent chain requests queue naturally; single-clip requests continue to flow through the normal generation queue unchanged. `ensure_model_ready`'s global `model_load_lock` still prevents concurrent loads from racing; this new lock adds a second level of serialization specific to chain's "hold engine out of cache" pattern. + +- **Multi-GPU worker affinity now holds end-to-end for queued generation, prompt expansion, and upscaling.** The generation dispatcher no longer rejects work just because the tiny per-worker channels are full; jobs stay pending in the configured global queue until a worker can accept them. Explicit placement GPU ordinals are now validated against the active worker pool and may target only one worker GPU per request/config entry, so per-component overrides can no longer silently allocate on a sibling card while auto-placement heuristics continue reading VRAM from the bound worker. Busy workers keep advertising their active model during the cache `take()` window, so follow-up requests queue behind the warm copy instead of spuriously reloading elsewhere. Server-side local prompt expansion now honors the selected GPU set (and prefers an explicitly requested worker GPU when present), Qwen-Image offload budgeting reads the worker's real ordinal instead of GPU 0, and multi-GPU upscaling now routes through the pool instead of a process-global GPU-0 singleton. Disconnected queued jobs are skipped before expensive work begins, multi-GPU `/api/status` now reports real prompt hashes/timestamps for active generations, the TUI info panel reads per-GPU status, and `MoldClient` can target model unloads by GPU/model instead of only exposing the legacy global unload. 
+- **LTX-2 image-to-video no longer locks the first latent frame to a noisy ghost of the source image at `strength < 1.0`.** In `run_real_distilled_stage` (`crates/mold-inference/src/ltx2/runtime.rs`) the "clean reference" that the per-step denoise-mask blend pulls the conditioned tokens toward was sourced by cloning `video_latents` *after* `apply_stage_video_conditioning` had already soft-blended the first-latent-frame positions with the initial noise (`noise*(1-s) + source*s`). Used as the clean target, that pre-blended tensor pinned the first latent to a noisy copy of the image instead of the pure image at every step — so i2v runs with `--strength 0.75` (the CLI default) produced a first frame that was 25 % noise + 75 % image rather than the source image. A new helper `clean_latents_for_conditioning` re-applies the replacements with strength 1.0 on top of the post-apply tensor so replacement positions hold pure source image tokens while appended keyframe tokens and pure-noise regions pass through unchanged. `strength = 1.0` and pure-T2V paths are bit-for-bit identical to before. Covered by two new regression tests (`clean_latents_replace_soft_blended_positions_with_pure_source`, `clean_latents_passthrough_when_no_replacements`). +- **city96-format FLUX fine-tune GGUFs now accept `flux-krea:q{8,6,4}` as a valid reference.** `find_flux_reference_gguf` in `crates/mold-inference/src/flux/pipeline.rs` previously hardcoded the candidate list to `flux-dev:q{8,6,4}` (plus schnell for schnell targets), so a box that already had `flux-krea:q8` on disk — a dev-family QuantStack GGUF with the full embedding set — would still error on first `ultrareal-v4:q8` generation and force the user to download a redundant ~12 GB `flux-dev:q8` reference. Krea is now probed after base flux-dev; the existing `gguf_has_guidance` check still gates acceptance so nothing is assumed about completeness. Regression test `find_flux_reference_accepts_krea_when_no_base_dev` covers the new path. 
- **city96-format FLUX fine-tune GGUFs now fail with an honest, actionable error when no dev reference is downloaded, and surface the dependency at pull time instead of inside `ensure_gguf_embeddings`.** Community fine-tune GGUFs (e.g. the `silveroxides/ultrareal-fine-tune-GGUF` tree that powers `ultrareal-v4:q{8,5,4}`) ship only the diffusion blocks and expect the base FLUX input embedding layers (`img_in`, `time_in`, `vector_in`, `guidance_in`) to be patched in from a separately-downloaded flux-dev reference. Two bugs made this fail confusingly: (1) `find_flux_reference_gguf` in `crates/mold-inference/src/flux/pipeline.rs` returned the first candidate with `img_in.weight`, which let `flux-schnell:q8` pass the probe even though schnell is distilled without `guidance_in` — the subsequent patch loop bailed with `reference GGUF (.../flux-schnell-q8/flux1-schnell-Q8_0.gguf) is also missing required tensor 'guidance_in.in_layer.weight'`, making it look like schnell itself was broken. (2) The manifest didn't express the dependency at all, so the first indication a user had that `mold pull ultrareal-v4:q8` wasn't self-sufficient was an HTTP 500 on their first generation. Fixed by (a) adding a `needs_guidance: bool` parameter to `find_flux_reference_gguf` that skips schnell candidates for dev-family targets and verifies candidates contain `guidance_in.in_layer.weight` before accepting them, (b) rewriting both error messages so the source model is named and the reference path is shown as a filename rather than a full path, and (c) adding a pull-time probe in `crates/mold-core/src/download.rs` (`warn_if_flux_gguf_needs_reference`) that scans the first 4 MiB of any downloaded `.gguf` transformer for `img_in.weight`, and prints a one-line warning via the download callback when the GGUF is incomplete and no suitable dev reference is already on disk. Works for both the CLI (`pull_model`) and server (`pull_model_with_callback`) paths. 
New regression test `find_flux_reference_skips_schnell_when_dev_needed` covers the reference-picker behaviour. - **Prompt expansion can no longer OOM on a multi-GPU box with a tight main GPU.** `LocalExpander` previously hardcoded `gpu_ordinal: 0` and gated placement with a static 2 GB VRAM threshold — on a dual-GPU system with a busy main card it fell back to CPU unnecessarily, and on a q8/bf16 expand model (4+ GB weights) the 2 GB threshold under-budgeted activations so the GPU placement check could pass and the load then OOM. The expander now sizes its budget dynamically (`model_size + 2 GB activations`, matching the T5/Qwen3 pattern) and cascades through devices: main GPU → remaining GPUs in ordinal order → CPU, with `preflight_memory_check()` as the final hard-fail guard when system RAM can't hold the model either. Unified-memory Metal placements also run the RAM preflight (Metal allocations draw from the same pool). Device selection logic is factored into a pure `select_expand_device(gpus, threshold, is_metal) -> ExpandPlacement` helper with unit tests for every branch. @@ -16,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Render chain for arbitrary-length LTX-2 distilled video.** `mold run ltx-2-19b-distilled:fp8 "a cat walking" --frames 400` now produces a single stitched MP4 by splitting the request into multiple per-clip renders and carrying a motion-tail of latents across each clip boundary so the continuation stays coherent without a VAE encode/decode round-trip. New server endpoints `POST /api/generate/chain` and `POST /api/generate/chain/stream` (SSE) accept either a canonical `stages[]` body or an auto-expand form (`prompt` + `total_frames` + `clip_frames`) — the wire format is stages-based from day one so the v2 movie-maker UI can author per-stage prompts/keyframes without a breaking change. 
Request/response/event types live in `crates/mold-core/src/chain.rs` (`ChainRequest`, `ChainResponse`, `ChainProgressEvent`, `SseChainCompleteEvent`); the LTX-2 orchestrator is in `crates/mold-inference/src/ltx2/chain.rs` (`Ltx2ChainOrchestrator`, `ChainTail`); the server routes in `crates/mold-server/src/routes_chain.rs`; and the CLI side in `crates/mold-cli/src/commands/chain.rs`. `mold run` auto-routes to the chain endpoint when `--frames` exceeds the model's per-clip cap (97 for LTX-2 19B/22B distilled); non-distilled families fail fast with an actionable error instead of silently over-producing. New flags `--clip-frames N` (default = model cap) and `--motion-tail N` (default 4, 0 disables carryover) let advanced callers tune the split. The orchestrator derives per-stage seeds as `base_seed ^ ((stage_idx as u64) << 32)` so the whole chain reproduces from a single seed without identical-noise artefacts when every stage shares a prompt. Over-production at the final clip is trimmed from the tail (the head carries the user-anchored starting image and is perceptually load-bearing); mid-chain failures fail closed with HTTP 502 and no partial stitch is ever written to the gallery. Chains run on a single GPU — the chain handler bypasses the single-job queue and holds the `ModelCache` lock for the full chain duration (a multi-minute compound operation would otherwise stall the FIFO queue). Both the remote SSE path and the `--local` in-process path funnel through the same orchestrator via `Ltx2Engine::as_chain_renderer`, and `mold run` renders stacked `indicatif` progress bars (parent "Chain" frame counter + per-stage denoise-step bar). v1 is LTX-2 distilled only, single-GPU, and single-prompt; per-stage prompts, keyframes, selective regen, and multi-GPU stage fan-out are v2 movie-maker work. - **In-browser model downloads with queued progress, ETA, cancel, and retry** ([#255](https://github.com/utensils/mold/pull/255)). 
`ModelPicker.vue` now shows `(X GB)` next to every model — click an undownloaded one to enqueue a pull without leaving the generate flow. A new `DownloadsDrawer` (opened from a TopBar button with an active/queued count badge) shows per-file progress, client-computed ETA (10 s sliding window), and cancel/retry controls. Undownloaded models in the picker switch to inline progress or a "Queued (#N)" chip while their job is alive, and the picker auto-refreshes on `JobDone` so the model becomes selectable without a page reload. Server-side: a new single-writer `DownloadQueue` in `AppState` drives the existing `mold_core::download::pull_model_with_callback` one model at a time (files sequential inside a set — HF's CDN is bandwidth-bound, so file-level parallelism would only trip rate limits), with one auto-retry on transient failure. Cancellation aborts the in-flight pull, cleans up partials under `MOLD_MODELS_DIR/<model>/` while preserving any `.sha256-verified` markers, and leaves the HF blob cache intact so resume is cheap. The same cleanup runs on terminal failures, not just cancel. New routes: `POST /api/downloads` (idempotent — returns the existing job id on a second enqueue), `DELETE /api/downloads/:id`, `GET /api/downloads` (active + queued + last 20 history), `GET /api/downloads/stream` (SSE multiplex of `DownloadEvent` frames — `Enqueued`, `Started`, `Progress`, `FileDone`, `JobDone`, `JobFailed`, `JobCancelled`). Existing `POST /api/models/pull` becomes a thin compat shim that enqueues via the queue and re-emits the legacy SSE event shape, so the TUI keeps working unchanged. +- **Always-visible VRAM + system RAM telemetry on `/generate`** ([#254](https://github.com/utensils/mold/pull/254)). A new `ResourceStrip.vue` docks at the bottom of the Composer sidebar on desktop (and collapses to a `🧠 used · total` chip in the TopBar on narrow viewports), showing one stacked-bar row per discovered GPU plus one for system RAM. 
Each row breaks usage into `mold` / `other` / `free` on CUDA hosts with per-process attribution (NVML feature-gated as `mold-ai-server` `--features nvml`, `nvidia-smi` subprocess fallback on by default) — on Metal the per-process fields are intentionally `None` and the SPA hides those breakdowns, since macOS doesn't expose per-process GPU attribution without private entitlements. Aggregated once per second on the server into a `ResourceSnapshot { hostname, gpus, system_ram }`, exposed as `GET /api/resources` (one-shot; `503` before the first aggregator tick) and `GET /api/resources/stream` (SSE broadcast with 15 s keepalive and the cached snapshot prepended as the first frame so new subscribers don't wait a full second). The aggregator handle is bound to `axum::serve`'s shutdown path so it's aborted on graceful exit. The strip's `useResources` composable is a provide/inject singleton mounted in `App.vue`, and it exposes a `gpuList: ComputedRef` that the new device-placement UI consumes directly. - **Per-component device placement for FLUX, Flux.2, Z-Image, and Qwen-Image** ([#256](https://github.com/utensils/mold/pull/256)). A new `PlacementPanel` disclosure inside the Composer lets users override which device each part of the pipeline runs on. Tier 1 is a single "Text encoders: Auto / CPU / GPU N" dropdown that applies to every model family (SD1.5, SDXL, SD3.5, Wuerstchen, LTX-Video, LTX-2 in addition to the Tier 2 four) — picking CPU reliably moves the text encoder off-GPU so a large transformer can stay on-device without triggering block-level offload. Tier 2 adds per-component selects (transformer, VAE, and family-appropriate encoder slots) for FLUX, Flux.2, Z-Image, and Qwen-Image, where the plumbing is cheapest and the value is clearest. SD3.5 was marked stretch in the design and cut cleanly — the UI correctly hides Advanced for SD3.5 with a tooltip so no user sees an override that silently no-ops. 
A new `DevicePlacement` serde type (`DeviceRef = Auto | Cpu | Gpu(ordinal)` plus an optional `AdvancedPlacement` sub-struct for per-component overrides) rides as an optional field on `GenerateRequest`; `None` preserves the existing VRAM-aware auto-placement end-to-end. A shared `resolve_device()` helper in `mold_inference::device` (and a companion `effective_device_ref()` shared by the four Tier-2 engines) maps each `DeviceRef` variant to a `candle_core::Device`, returning a clean `anyhow::Error` for bad ordinals instead of panicking. Defaults are saved per-model in `[models."name:tag".placement]` (with `MOLD_PLACE_TEXT_ENCODERS`, `MOLD_PLACE_TRANSFORMER`, `MOLD_PLACE_VAE`, `MOLD_PLACE_CLIP_L`, `MOLD_PLACE_CLIP_G`, `MOLD_PLACE_T5`, `MOLD_PLACE_QWEN` env overrides) via a new `PUT /api/config/model/:name/placement` route (with `DELETE` to clear); the route now returns a real `500` when `Config::save()` fails instead of silently lying to the client. The placement UI reads its GPU list from `useResources().gpuList`, so spinning up a mold server on a dual-3090 box auto-populates "GPU 0 · RTX 3090" / "GPU 1 · RTX 3090" in every dropdown without any extra discovery wiring. `mold run` gains matching CLI flags — `--device-text-encoders`, `--device-transformer`, `--device-vae`, `--device-t5`, `--device-clip-l`, `--device-clip-g`, `--device-qwen` — which override env vars and config; flag parse errors surface with the specific flag name so `--device-vae banana` reports `--device-vae: invalid device 'banana' (expected auto|cpu|gpu[:N])` instead of a generic failure. Documented in `website/guide/configuration.md` (new "Per-component device placement" section) and `website/guide/performance.md` (the "CPU text encoders" subsection now points at the CLI flags for deliberate VRAM tuning). 
diff --git a/CLAUDE.md b/CLAUDE.md index 0c24df48..12b2e7f5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,10 +24,11 @@ nix flake check # Validate formatting + fla | Category | Command | Description | |----------|---------|-------------| -| build | `build` | `cargo build` (debug, all crates) | -| build | `build-release` | `cargo build --release` | -| build | `build-server` | `cargo build -p mold-ai --features {cuda\|metal}` (single binary with GPU) | -| build | `build-discord` | `cargo build -p mold-ai --features discord` | +| build | `build` | Fast local `mold` build (`cargo build --profile dev-fast -p mold-ai`) with the web bundle embedded | +| build | `build-workspace` | `cargo build` (debug, all crates) | +| build | `build-release` | Shipping `cargo build --release -p mold-ai --features {gpu},preview,discord,expand,tui,webp,mp4,metrics` | +| build | `build-server` | Fast local single-binary server build with GPU + preview + expand and embedded web UI | +| build | `build-discord` | Fast local `cargo build --profile dev-fast -p mold-ai --features discord` | | check | `check` | `cargo check` | | check | `clippy` | `cargo clippy` | | check | `run-tests` | `cargo test` | @@ -43,6 +44,7 @@ nix flake check # Validate formatting + fla ```bash cargo build # Debug build (all crates) +cargo build --profile dev-fast # Fast local optimized build cargo build --release # Release build cargo build -p mold-ai # Just the CLI cargo build -p mold-ai --features cuda # CLI with CUDA (includes serve) @@ -54,8 +56,8 @@ cargo test -p mold-ai-core # Single crate ./scripts/coverage.sh # Test coverage summary ./scripts/coverage.sh --html # HTML coverage report ./scripts/fetch-tokenizers.sh # Pre-download tokenizer files -cargo run -p mold-ai -- run "a cat" # Generate image -cargo run -p mold-ai -- serve # Start server +./scripts/ensure-web-dist.sh && cargo run --profile dev-fast -p mold-ai --features metal,preview,expand -- run "a cat" +./scripts/ensure-web-dist.sh && cargo run --profile 
dev-fast -p mold-ai --features metal,preview,expand -- serve cargo run -p mold-ai-inference --features dev-bins --bin ltx2_review -- clip.mp4 ``` @@ -65,12 +67,8 @@ CI runs on every push and PR (`.github/workflows/ci.yml`). All jobs must pass: | Job | What it checks | |-----|----------------| -| `fmt` | `cargo fmt --all -- --check` | -| `check` | `cargo check --workspace` | -| `clippy` | `cargo clippy --workspace -- -D warnings` | -| `test` | `cargo test --workspace` | +| `rust` | `cargo fmt --all -- --check && cargo check --workspace && cargo clippy --workspace --all-targets -- -D warnings && cargo test --workspace && cargo check -p mold-ai --features preview,discord,expand,tui,webp,mp4` | | `coverage` | `cargo llvm-cov` → Codecov upload | -| `check-features` | `cargo check -p mold-ai --features preview,discord,expand,tui,webp,mp4` (all optional features) | | `docs` | `bun run fmt:check && bun run verify && bun run build` in `website/` | > **Note:** `mold-inference` and `mold-server` have `[lib] test = false` in their `Cargo.toml` files. The test harness for these crates links against candle/CUDA which triggers heavy model weight initialization (~32GB RAM, 40+ min hang). The `mold-server` binary target also has `test = false`. Unit tests in `mold-core` and `mold-cli` run normally. If you add tests to `mold-inference` or `mold-server`, run them with `cargo test -p <crate> --lib` after temporarily removing the `test = false` flag. @@ -266,7 +264,7 @@ Location: `~/.config/mold/config.toml` (XDG) or `~/.mold/config.toml` (legacy **Documentation site**: VitePress 1.6 + Tailwind CSS v4 + bun in `website/`. Dev server: `cd website && bun install && bun run dev -- --host 0.0.0.0`. Build: `bun run build`. Deployed to GitHub Pages via `.github/workflows/pages.yml` on push to `main` (website/** paths). Base path is `/mold/` (served at `utensils.github.io/mold/`). -**Web gallery UI** (separate from the docs site): Vue 3 + Vite 7 + Tailwind CSS v4.2 SPA in `web/`.
The SPA is **embedded directly into the `mold` binary at compile time** via [`rust-embed`](https://crates.io/crates/rust-embed), so `nix build` (or `cargo build` after running `cd web && bun run build`) produces a single-file server that serves the gallery with zero runtime filesystem dependency. `crates/mold-server/build.rs` resolves the bundle from one of three sources, in order: `$MOLD_WEB_DIST` (set by the Nix flake's `mold-web` derivation — built reproducibly via [bun2nix](https://github.com/nix-community/bun2nix) from `web/bun.lock` → `web/bun.nix`), `/web/dist`, or a generated placeholder stub (`$OUT_DIR/web-stub/__mold_placeholder`). The stub is detected at runtime and swapped for the inline "mold is running" page so a bare `cargo build` still produces a working binary. For SPA hot-iteration without Rust recompiles, `MOLD_WEB_DIR` (and the legacy `$XDG_DATA_HOME/mold/web`, `~/.mold/web`, `/web`, `./web/dist` candidates) still take precedence over the embedded bundle — so `bun run dev` or a local `web/dist` can be swapped in without rebuilding Rust. Dev server: `bun run dev` (proxies `/api` + `/health` to `http://localhost:7680`; override with `MOLD_API_ORIGIN`). See `crates/mold-server/src/web_ui.rs` for the resolver + embed wiring and `web/README.md` for the frontend stack. +**Web gallery UI** (separate from the docs site): Vue 3 + Vite 7 + Tailwind CSS v4.2 SPA in `web/`. The SPA is **embedded directly into the `mold` binary at compile time** via [`rust-embed`](https://crates.io/crates/rust-embed), so `nix build` produces a single-file server that serves the gallery with zero runtime filesystem dependency. Local devshell `build`, `build-server`, `mold`, `serve`, and `generate` commands now call `./scripts/ensure-web-dist.sh` first, so `target/dev-fast/mold` normally includes the real SPA by default instead of the placeholder stub. 
`crates/mold-server/build.rs` resolves the bundle from one of three sources, in order: `$MOLD_WEB_DIST` (set by the Nix flake's `mold-web` derivation — built reproducibly via [bun2nix](https://github.com/nix-community/bun2nix) from `web/bun.lock` → `web/bun.nix`), `/web/dist`, or a generated placeholder stub (`$OUT_DIR/web-stub/__mold_placeholder`). The stub is detected at runtime and swapped for the inline "mold is running" page so a bare `cargo build` still produces a working binary. For SPA hot-iteration without Rust recompiles, `MOLD_WEB_DIR` (and the legacy `$XDG_DATA_HOME/mold/web`, `~/.mold/web`, `/web`, `./web/dist` candidates) still take precedence over the embedded bundle — so `bun run dev` or a local `web/dist` can be swapped in without rebuilding Rust. Dev server: `bun run dev` (proxies `/api` + `/health` to `http://localhost:7680`; override with `MOLD_API_ORIGIN`). See `crates/mold-server/src/web_ui.rs` for the resolver + embed wiring and `web/README.md` for the frontend stack. **Gallery metadata DB at `MOLD_HOME/mold.db`** (override with `MOLD_DB_PATH`, disable with `MOLD_DB_DISABLE=1`). The `mold-db` crate (rusqlite + bundled SQLite, WAL mode) holds one row per saved file with the full `OutputMetadata` plus `file_mtime_ms`, `file_size_bytes`, `generation_time_ms`, `backend` (`cuda`/`metal`/`cpu`), `hostname`, `format`, and a `source` column (`server` / `cli` / `backfill`). 
Both surfaces write rows after a successful save: diff --git a/Cargo.lock b/Cargo.lock index c950b94e..4a76c58b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,6 +3104,7 @@ dependencies = [ "async-trait", "axum", "base64 0.22.1", + "candle-core-mold", "clap", "dirs 5.0.1", "futures", diff --git a/Cargo.toml b/Cargo.toml index c667ed3f..62a98080 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,3 +30,11 @@ name = "mold" lto = "fat" codegen-units = 1 strip = true + +[profile.dev-fast] +inherits = "release" +lto = "thin" +codegen-units = 16 +incremental = true +debug = 1 +strip = false diff --git a/README.md b/README.md index 9a7c2714..5d87a70d 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,10 @@ nix run github:utensils/mold#mold-sm120 -- run "a cat" # Blackwell / RTX ### From source ```bash -cargo build --release -p mold-ai --features cuda # Linux (NVIDIA) -cargo build --release -p mold-ai --features metal # macOS (Apple Silicon) +./scripts/ensure-web-dist.sh && cargo build --profile dev-fast -p mold-ai --features cuda # Linux (NVIDIA), fast local build +./scripts/ensure-web-dist.sh && cargo build --profile dev-fast -p mold-ai --features metal # macOS (Apple Silicon), fast local build +cargo build --release -p mold-ai --features cuda # Linux (NVIDIA), shipping build +cargo build --release -p mold-ai --features metal # macOS (Apple Silicon), shipping build ``` Add `preview`, `expand`, `discord`, or `tui` to the features list as needed. 
diff --git a/crates/mold-cli/Cargo.toml b/crates/mold-cli/Cargo.toml index a7734a87..07a71f88 100644 --- a/crates/mold-cli/Cargo.toml +++ b/crates/mold-cli/Cargo.toml @@ -22,7 +22,7 @@ discord = ["mold-discord"] expand = ["mold-inference/expand", "mold-server/expand", "mold-tui?/expand"] tui = ["dep:mold-tui"] webp = ["mold-inference/webp"] -mp4 = ["mold-inference/mp4"] +mp4 = ["mold-inference/mp4", "mold-server/mp4"] metrics = ["mold-server/metrics"] [dependencies] diff --git a/crates/mold-cli/src/commands/chain.rs b/crates/mold-cli/src/commands/chain.rs new file mode 100644 index 00000000..4a8796bf --- /dev/null +++ b/crates/mold-cli/src/commands/chain.rs @@ -0,0 +1,843 @@ +//! CLI-side render-chain orchestration for LTX-2 distilled models. +//! +//! When `mold run --frames N` exceeds the per-clip cap of the selected model, +//! this module takes over from [`super::generate::run`]: it assembles a +//! [`ChainRequest`] from the user's CLI args and either submits it to a +//! running server via [`MoldClient::generate_chain_stream`] or, in `--local` +//! mode, drives an in-process [`Ltx2ChainOrchestrator`]. +//! +//! Both paths funnel through [`encode_and_save`] so stdout piping, gallery +//! save, metadata DB writes, and preview behaviour match the single-clip +//! path byte-for-byte. + +use std::io::Write; +use std::time::Duration; + +use anyhow::Result; +use colored::Colorize; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use mold_core::chain::{ChainProgressEvent, ChainRequest}; +use mold_core::{Config, MoldClient, OutputFormat, VideoData}; + +use crate::control::CliContext; +use crate::output::{is_piped, status}; +use crate::theme; + +/// Per-clip frame cap for LTX-2 19B/22B distilled. The distilled VAE +/// pipeline maxes at 97 pixel frames (13 latent frames) per clip. 
+pub const LTX2_DISTILLED_CLIP_CAP: u32 = 97; + +/// Outcome of [`decide_chain_routing`]: either the caller should continue +/// down the single-clip path, build a chain with the given settings, or +/// reject the request because the model family can't be chained. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ChainRoutingDecision { + /// Go through the normal single-clip path; no chaining required. + SingleClip, + /// Submit a chain. `clip_frames` is the clamped per-clip cap. + Chain { clip_frames: u32, motion_tail: u32 }, + /// Model family doesn't support chaining and `frames` exceeds its cap. + Rejected { reason: String }, +} + +/// Pure decision function — given a model family, the user's requested +/// `frames`, and the optional `--clip-frames` override, decide whether to +/// chain, stay single-clip, or reject. +/// +/// The clamp-to-cap behaviour surfaces through the returned `clip_frames` +/// field; callers warn the user via stderr when they had to clamp. +pub fn decide_chain_routing( + frames: Option<u32>, + family: Option<&str>, + model: &str, + clip_frames_flag: Option<u32>, + motion_tail: u32, +) -> ChainRoutingDecision { + let Some(total_frames) = frames else { + return ChainRoutingDecision::SingleClip; + }; + + let is_ltx2_distilled = family == Some("ltx2") && model.contains("distilled"); + + if !is_ltx2_distilled { + // Non-chainable families: if the requested frame count is within a + // conservative single-clip budget, stay on the single-clip path and + // let the engine decide if it's acceptable. Otherwise, reject with + // a clear message rather than silently over-producing.
+ if total_frames <= LTX2_DISTILLED_CLIP_CAP { + return ChainRoutingDecision::SingleClip; + } + return ChainRoutingDecision::Rejected { + reason: format!( + "model '{model}' does not support chained video generation \ + (only LTX-2 distilled families do); specify --frames <= {} \ + per clip for this model", + LTX2_DISTILLED_CLIP_CAP, + ), + }; + } + + let cap = LTX2_DISTILLED_CLIP_CAP; + let effective_clip_frames = clip_frames_flag.unwrap_or(cap).min(cap); + + if total_frames <= effective_clip_frames { + return ChainRoutingDecision::SingleClip; + } + + if motion_tail >= effective_clip_frames { + return ChainRoutingDecision::Rejected { + reason: format!( + "--motion-tail ({motion_tail}) must be strictly less than \ + --clip-frames ({effective_clip_frames}) so every continuation \ + emits at least one new frame", + ), + }; + } + + ChainRoutingDecision::Chain { + clip_frames: effective_clip_frames, + motion_tail, + } +} + +/// Emit a stderr warning if `--clip-frames` was above the model's cap and +/// got clamped. The caller already holds the clamped effective value. +pub fn warn_if_clamped(flag: Option<u32>, cap: u32) { + if let Some(requested) = flag { + if requested > cap { + crate::output::status!( + "{} --clip-frames {} exceeds model cap {}, clamping to {}", + theme::prefix_warning(), + requested, + cap, + cap, + ); + } + } +} + +/// Caller-supplied inputs for a chain run, bundled so the remote + local +/// paths can share a single helper without a 20-arg function signature.
+#[allow(clippy::too_many_arguments)] +pub struct ChainInputs { + pub prompt: String, + pub model: String, + pub width: u32, + pub height: u32, + pub steps: u32, + pub guidance: f64, + pub strength: f64, + pub seed: Option, + pub fps: u32, + pub output_format: OutputFormat, + pub total_frames: u32, + pub clip_frames: u32, + pub motion_tail: u32, + pub source_image: Option>, + pub placement: Option, +} + +impl ChainInputs { + fn to_chain_request(&self) -> ChainRequest { + ChainRequest { + model: self.model.clone(), + stages: Vec::new(), + motion_tail_frames: self.motion_tail, + width: self.width, + height: self.height, + fps: self.fps, + seed: self.seed, + steps: self.steps, + guidance: self.guidance, + strength: self.strength, + output_format: self.output_format, + placement: self.placement.clone(), + prompt: Some(self.prompt.clone()), + total_frames: Some(self.total_frames), + clip_frames: Some(self.clip_frames), + source_image: self.source_image.clone(), + } + } +} + +/// Run a chain end-to-end, dispatching to the server (streaming) or the +/// local orchestrator based on the `local` flag. Handles encoding, save, +/// preview, and final status messages. +#[allow(clippy::too_many_arguments)] +pub async fn run_chain( + inputs: ChainInputs, + host: Option, + output: Option, + no_metadata: bool, + preview: bool, + local: bool, + gpus: Option, + t5_variant: Option, + qwen3_variant: Option, + qwen2_variant: Option, + qwen2_text_encoder_mode: Option, + eager: bool, + offload: bool, +) -> Result<()> { + // Validate the auto-expand form before touching the network / GPU so + // obvious mistakes (bad clip_frames math, too many stages) fail fast. 
+ let chain_req = inputs.to_chain_request(); + let normalised = chain_req.clone().normalise()?; + let stage_count = normalised.stages.len() as u32; + + status!( + "{} Chain mode: {} frames → {} stages × {} frames (tail {})", + theme::icon_mode(), + inputs.total_frames, + stage_count, + inputs.clip_frames, + inputs.motion_tail, + ); + + let ctx = CliContext::new(host.as_deref()); + let config = ctx.config().clone(); + let embed_metadata = config.effective_embed_metadata(no_metadata.then_some(false)); + let _ = embed_metadata; // reserved for future metadata-embed work on chain output + + let t0 = std::time::Instant::now(); + let video = if local { + #[cfg(any(feature = "cuda", feature = "metal"))] + { + crate::ui::print_using_local_inference(); + run_chain_local( + &chain_req, + &config, + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ) + .await? + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] + { + let _ = ( + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ); + anyhow::bail!( + "No mold server running and this binary was built without GPU support.\n\ + Either start a server with `mold serve` or rebuild with --features cuda" + ) + } + } else { + run_chain_remote(ctx.client(), &chain_req).await? + }; + + let elapsed_ms = t0.elapsed().as_millis() as u64; + let base_seed = inputs.seed.unwrap_or(0); + + encode_and_save( + &inputs, + &video, + output.as_deref(), + preview, + elapsed_ms, + base_seed, + )?; + + Config::write_last_model(&inputs.model); + Ok(()) +} + +/// Remote chain: streaming SSE with stacked progress bars. 
+async fn run_chain_remote(client: &MoldClient, req: &ChainRequest) -> Result { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let render = tokio::spawn(render_chain_progress(rx)); + + let stream_result = client.generate_chain_stream(req, tx).await; + let _ = render.await; + + match stream_result { + Ok(Some(resp)) => Ok(resp.video), + Ok(None) => { + // Server predates chain endpoint; fall back to non-streaming. + status!( + "{} Server SSE chain endpoint unavailable, falling back to blocking endpoint", + theme::prefix_warning(), + ); + let resp = client.generate_chain(req).await?; + Ok(resp.video) + } + Err(e) => Err(e), + } +} + +#[cfg(any(feature = "cuda", feature = "metal"))] +#[allow(clippy::too_many_arguments)] +async fn run_chain_local( + chain_req: &ChainRequest, + config: &Config, + gpus: Option, + t5_variant_override: Option, + qwen3_variant_override: Option, + qwen2_variant_override: Option, + qwen2_text_encoder_mode_override: Option, + eager: bool, + offload: bool, +) -> Result { + use mold_core::manifest::find_manifest; + use mold_core::ModelPaths; + use mold_inference::LoadStrategy; + + // Normalise so we have expanded stages locally too. + let req = chain_req.clone().normalise()?; + + // Apply encoder-variant overrides before constructing the engine so the + // factory's auto-select picks them up. + apply_local_engine_env_overrides( + t5_variant_override.as_deref(), + qwen3_variant_override.as_deref(), + qwen2_variant_override.as_deref(), + qwen2_text_encoder_mode_override.as_deref(), + ); + + let model_name = req.model.clone(); + + // Ensure the model is pulled + config rows are in place. 
+ let (paths, effective_config) = if let Some(p) = ModelPaths::resolve(&model_name, config) { + (p, config.clone()) + } else if find_manifest(&model_name).is_some() { + crate::output::status!( + "{} Model '{}' not found locally, pulling...", + theme::icon_info(), + model_name.bold(), + ); + let updated = super::pull::pull_and_configure( + &model_name, + &mold_core::download::PullOptions::default(), + ) + .await?; + let p = ModelPaths::resolve(&model_name, &updated).ok_or_else(|| { + anyhow::anyhow!("model '{model_name}' was pulled but paths could not be resolved") + })?; + (p, updated) + } else { + anyhow::bail!( + "no model paths configured for '{model_name}'. Add [models.{model_name}] \ + to ~/.mold/config.toml or pull via `mold pull {model_name}`." + ); + }; + + let is_eager = eager || std::env::var("MOLD_EAGER").is_ok_and(|v| v == "1"); + let load_strategy = if is_eager { + LoadStrategy::Eager + } else { + LoadStrategy::Sequential + }; + if is_eager { + std::env::set_var("MOLD_EAGER", "1"); + } + let is_offload = offload || std::env::var("MOLD_OFFLOAD").is_ok_and(|v| v == "1"); + + let gpu_selection = match &gpus { + Some(s) => mold_core::types::GpuSelection::parse(s)?, + None => effective_config.gpu_selection(), + }; + let discovered = mold_inference::device::discover_gpus(); + let available = mold_inference::device::filter_gpus(&discovered, &gpu_selection); + let gpu_ordinal = mold_inference::device::select_best_gpu(&available) + .map(|g| g.ordinal) + .unwrap_or(0); + + let mut engine = mold_inference::create_engine( + model_name, + paths, + &effective_config, + load_strategy, + gpu_ordinal, + is_offload, + )?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let render = tokio::spawn(render_chain_progress(rx)); + + let fps = req.fps; + let output_format = req.output_format; + let total_frames_opt = Some(req.total_frames.unwrap_or(u32::MAX)); + let req_clone = req.clone(); + + let handle = tokio::task::spawn_blocking(move || -> Result { + 
engine.load()?; + let renderer = engine.as_chain_renderer().ok_or_else(|| { + anyhow::anyhow!( + "model '{}' does not support chained video generation \ + (only LTX-2 distilled engines expose a ChainStageRenderer view)", + req_clone.model, + ) + })?; + let mut orch = mold_inference::ltx2::Ltx2ChainOrchestrator::new(renderer); + + let tx = tx; + let mut chain_cb = move |event: ChainProgressEvent| { + let _ = tx.send(event); + }; + let chain_output = orch.run(&req_clone, Some(&mut chain_cb))?; + + let mut frames = chain_output.frames; + if let Some(target) = total_frames_opt { + let target = target as usize; + if frames.len() > target { + frames.truncate(target); + } + } + if frames.is_empty() { + anyhow::bail!("chain run emitted zero frames after trim"); + } + + encode_local_frames(&frames, fps, output_format) + }); + + let result = handle.await??; + let _ = render.await; + Ok(result) +} + +#[cfg(any(feature = "cuda", feature = "metal"))] +fn apply_local_engine_env_overrides( + t5_variant: Option<&str>, + qwen3_variant: Option<&str>, + qwen2_variant: Option<&str>, + qwen2_text_encoder_mode: Option<&str>, +) { + if let Some(v) = t5_variant { + std::env::set_var("MOLD_T5_VARIANT", v); + } + if let Some(v) = qwen3_variant { + std::env::set_var("MOLD_QWEN3_VARIANT", v); + } + if let Some(v) = qwen2_variant { + std::env::set_var("MOLD_QWEN2_VARIANT", v); + } + if let Some(v) = qwen2_text_encoder_mode { + std::env::set_var("MOLD_QWEN2_TEXT_ENCODER_MODE", v); + } +} + +/// Encode stitched frames to the requested container. MP4 is feature-gated; +/// fall back to APNG when the CLI was built without `mp4`. 
+#[cfg(any(feature = "cuda", feature = "metal"))] +fn encode_local_frames( + frames: &[image::RgbImage], + fps: u32, + output_format: OutputFormat, +) -> Result { + use mold_inference::ltx_video::video_enc; + + let gif_preview = video_enc::encode_gif(frames, fps).unwrap_or_default(); + let thumbnail = video_enc::first_frame_png(frames).unwrap_or_default(); + + let (bytes, actual_format) = match output_format { + OutputFormat::Mp4 => { + #[cfg(feature = "mp4")] + { + (video_enc::encode_mp4(frames, fps)?, OutputFormat::Mp4) + } + #[cfg(not(feature = "mp4"))] + { + crate::output::status!( + "{} MP4 requested but this binary was built without --features mp4; \ + falling back to APNG", + theme::prefix_warning(), + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + } + OutputFormat::Apng => ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ), + OutputFormat::Gif => (video_enc::encode_gif(frames, fps)?, OutputFormat::Gif), + OutputFormat::Webp => { + crate::output::status!( + "{} WebP chain output not supported locally yet; falling back to APNG", + theme::prefix_warning(), + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + other => anyhow::bail!("{other:?} is not a video output format for chain generation"), + }; + + let width = frames[0].width(); + let height = frames[0].height(); + let frame_count = frames.len() as u32; + let duration_ms = if fps == 0 { + None + } else { + Some((frame_count as u64 * 1000) / fps as u64) + }; + + Ok(VideoData { + data: bytes, + format: actual_format, + width, + height, + frames: frame_count, + fps, + thumbnail, + gif_preview, + has_audio: false, + duration_ms, + audio_sample_rate: None, + audio_channels: None, + }) +} + +/// Shared epilogue: write the stitched video to stdout/file/gallery and +/// emit a terminal preview if requested. 
+fn encode_and_save( + inputs: &ChainInputs, + video: &VideoData, + output: Option<&str>, + preview: bool, + elapsed_ms: u64, + base_seed: u64, +) -> Result<()> { + let piped = is_piped(); + + if piped && output.is_none() { + let mut stdout = std::io::stdout().lock(); + stdout.write_all(&video.data)?; + stdout.flush()?; + } else { + let filename = match output { + Some("-") => { + let mut stdout = std::io::stdout().lock(); + stdout.write_all(&video.data)?; + stdout.flush()?; + None + } + Some(path) => Some(path.to_string()), + None => { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + Some(mold_core::default_output_filename( + &inputs.model, + timestamp, + video.format.extension(), + 1, + 0, + )) + } + }; + if let Some(ref filename) = filename { + if std::path::Path::new(filename).exists() { + status!("{} Overwriting: {}", theme::icon_alert(), filename); + } + std::fs::write(filename, &video.data)?; + status!( + "{} Saved: {} ({} frames, {}x{}, {} fps)", + theme::icon_done(), + filename.bold(), + video.frames, + video.width, + video.height, + video.fps, + ); + + // Persist to the gallery metadata DB. Build a synthetic + // GenerateRequest so the existing record_local_save helper can + // infer dimensions/seed/steps/etc. without a dedicated chain + // row schema. + let req = synth_generate_request(inputs, video); + crate::metadata_db::record_local_save( + std::path::Path::new(filename), + &req, + inputs.seed.unwrap_or(base_seed), + elapsed_ms, + video.format, + ); + } + } + + if preview && !piped { + // Best-effort: show the gif preview or fall back to the thumbnail + // or the video bytes themselves (GIF/APNG decode as images). 
+ let bytes_for_preview: &[u8] = if !video.gif_preview.is_empty() { + &video.gif_preview + } else if !video.thumbnail.is_empty() { + &video.thumbnail + } else { + &video.data + }; + super::generate::preview_image(bytes_for_preview); + } + + status!( + "{} Done — {} in {:.1}s ({} frames, seed: {})", + theme::icon_done(), + inputs.model.bold(), + elapsed_ms as f64 / 1000.0, + video.frames, + inputs.seed.unwrap_or(base_seed), + ); + + Ok(()) +} + +fn synth_generate_request(inputs: &ChainInputs, video: &VideoData) -> mold_core::GenerateRequest { + mold_core::GenerateRequest { + prompt: inputs.prompt.clone(), + negative_prompt: None, + model: inputs.model.clone(), + width: inputs.width, + height: inputs.height, + steps: inputs.steps, + guidance: inputs.guidance, + seed: inputs.seed, + batch_size: 1, + output_format: video.format, + embed_metadata: Some(false), + scheduler: None, + edit_images: None, + source_image: inputs.source_image.clone(), + strength: inputs.strength, + mask_image: None, + control_image: None, + control_model: None, + control_scale: 1.0, + expand: None, + original_prompt: None, + lora: None, + frames: Some(video.frames), + fps: Some(video.fps), + upscale_model: None, + gif_preview: false, + enable_audio: None, + audio_file: None, + source_video: None, + keyframes: None, + pipeline: None, + loras: None, + retake_range: None, + spatial_upscale: None, + temporal_upscale: None, + placement: inputs.placement.clone(), + } +} + +/// Stacked progress bars for chain render: a parent "Chain" bar covering +/// all pixel frames and a transient per-stage bar covering denoise steps. +async fn render_chain_progress(mut rx: tokio::sync::mpsc::UnboundedReceiver) { + // Always draw to stderr so image bytes piped to stdout stay clean. 
+ let mp = MultiProgress::with_draw_target(ProgressDrawTarget::stderr()); + + let parent = mp.add(ProgressBar::new(0)); + parent.set_style( + ProgressStyle::default_bar() + .template(&format!( + "{{prefix:.{c}}} [{{bar:30.{c}/dim}}] {{pos}}/{{len}} frames {{msg}}", + c = theme::SPINNER_STYLE, + )) + .unwrap() + .progress_chars("━╸─"), + ); + parent.set_prefix("Chain"); + parent.enable_steady_tick(Duration::from_millis(100)); + + let mut stage_bar: Option = None; + let mut stage_count: u32 = 0; + + while let Some(event) = rx.recv().await { + match event { + ChainProgressEvent::ChainStart { + stage_count: sc, + estimated_total_frames, + } => { + stage_count = sc; + parent.set_length(estimated_total_frames as u64); + parent.set_message(format!("(stages {sc})")); + } + ChainProgressEvent::StageStart { stage_idx } => { + if let Some(old) = stage_bar.take() { + old.finish_and_clear(); + } + parent.set_message(format!("stage {}/{}", stage_idx + 1, stage_count)); + let sb = mp.add(ProgressBar::new(0)); + sb.set_style( + ProgressStyle::default_bar() + .template(&format!( + " Stage {{prefix}} [{{bar:30.{c}/dim}}] {{pos}}/{{len}} steps", + c = theme::SPINNER_STYLE, + )) + .unwrap() + .progress_chars("━╸─"), + ); + sb.set_prefix(format!("{}", stage_idx + 1)); + sb.enable_steady_tick(Duration::from_millis(100)); + stage_bar = Some(sb); + } + ChainProgressEvent::DenoiseStep { + stage_idx: _, + step, + total, + } => { + if let Some(ref sb) = stage_bar { + if sb.length().unwrap_or(0) == 0 { + sb.set_length(total as u64); + } + sb.set_position(step as u64); + } + } + ChainProgressEvent::StageDone { + stage_idx: _, + frames_emitted, + } => { + if let Some(sb) = stage_bar.take() { + sb.finish_and_clear(); + } + parent.inc(frames_emitted as u64); + } + ChainProgressEvent::Stitching { total_frames } => { + if let Some(sb) = stage_bar.take() { + sb.finish_and_clear(); + } + parent.set_message(format!("stitching {total_frames} frames…")); + } + } + } + + if let Some(sb) = 
stage_bar.take() { + sb.finish_and_clear(); + } + parent.finish_and_clear(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn routing_single_clip_under_cap() { + let d = decide_chain_routing(Some(97), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!(d, ChainRoutingDecision::SingleClip); + } + + #[test] + fn routing_single_clip_when_frames_absent() { + let d = decide_chain_routing(None, Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!(d, ChainRoutingDecision::SingleClip); + } + + #[test] + fn routing_chain_over_cap_ltx2_distilled() { + let d = decide_chain_routing(Some(200), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 4); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 97, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_rejects_non_distilled_over_cap() { + let d = decide_chain_routing(Some(200), Some("flux"), "flux-dev:q4", None, 4); + match d { + ChainRoutingDecision::Rejected { reason } => { + assert!( + reason.contains("does not support chained video"), + "unexpected reason: {reason}" + ); + } + other => panic!("expected Rejected, got {other:?}"), + } + } + + #[test] + fn routing_rejects_non_ltx2_family_over_cap() { + // ltx-video (not ltx2) is not chainable in v1. + let d = decide_chain_routing(Some(200), Some("ltx-video"), "ltx-video:0.9.6", None, 4); + assert!(matches!(d, ChainRoutingDecision::Rejected { .. 
})); + } + + #[test] + fn routing_clip_frames_above_cap_clamps_to_cap() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(200), + 4, + ); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 97, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_clip_frames_under_cap_respected() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(65), + 4, + ); + assert_eq!( + d, + ChainRoutingDecision::Chain { + clip_frames: 65, + motion_tail: 4, + }, + ); + } + + #[test] + fn routing_motion_tail_ge_clip_frames_rejects() { + let d = decide_chain_routing( + Some(300), + Some("ltx2"), + "ltx-2-19b-distilled:fp8", + Some(49), + 49, + ); + match d { + ChainRoutingDecision::Rejected { reason } => { + assert!( + reason.contains("--motion-tail"), + "unexpected reason: {reason}" + ); + } + other => panic!("expected Rejected, got {other:?}"), + } + } + + #[test] + fn routing_motion_tail_at_clip_frames_rejects() { + let d = decide_chain_routing(Some(200), Some("ltx2"), "ltx-2-19b-distilled:fp8", None, 97); + assert!(matches!(d, ChainRoutingDecision::Rejected { .. })); + } + + #[test] + fn ltx2_distilled_cap_matches_engine_constraint() { + // 97 = 8 * 12 + 1, satisfying the VAE 8k+1 constraint. + assert_eq!(LTX2_DISTILLED_CLIP_CAP % 8, 1); + } +} diff --git a/crates/mold-cli/src/commands/generate.rs b/crates/mold-cli/src/commands/generate.rs index 028fa089..cded5bb1 100644 --- a/crates/mold-cli/src/commands/generate.rs +++ b/crates/mold-cli/src/commands/generate.rs @@ -157,6 +157,11 @@ fn apply_local_engine_env_overrides( pub struct Ltx2Options { pub frames: Option, pub fps: Option, + /// Per-clip cap for chained rendering. `None` = use the model-family default + /// (currently 97 for LTX-2 distilled). Only read when `frames > cap`. + pub clip_frames: Option, + /// Motion-tail overlap between chained clips (pixel frames). 
+ pub motion_tail: u32, pub enable_audio: Option, pub audio_file: Option>, pub source_video: Option>, @@ -210,6 +215,8 @@ pub async fn run( let Ltx2Options { frames, fps, + clip_frames, + motion_tail, enable_audio, audio_file, source_video, @@ -243,6 +250,117 @@ pub async fn run( } else { format }; + + // ── Chain routing ───────────────────────────────────────────────────── + // When --frames exceeds the per-clip cap, auto-build a ChainRequest and + // delegate to the chain helper. Only LTX-2 distilled is chainable in v1; + // other video families error fast rather than silently over-producing. + { + use super::chain::{decide_chain_routing, warn_if_clamped, ChainRoutingDecision}; + let decision = decide_chain_routing( + effective_frames, + family.as_deref(), + model, + clip_frames, + motion_tail, + ); + match decision { + ChainRoutingDecision::SingleClip => { + // Fall through to the existing single-clip path below. + } + ChainRoutingDecision::Rejected { reason } => { + anyhow::bail!(reason); + } + ChainRoutingDecision::Chain { + clip_frames: cf, + motion_tail: mt, + } => { + warn_if_clamped(clip_frames, super::chain::LTX2_DISTILLED_CLIP_CAP); + let (eff_w, eff_h) = effective_dimensions( + &config, + &model_cfg, + family.as_deref(), + width, + height, + source_image.as_deref(), + edit_images.as_deref(), + )?; + let eff_steps = steps.unwrap_or_else(|| model_cfg.effective_steps(&config)); + let eff_guidance = guidance.unwrap_or_else(|| model_cfg.effective_guidance()); + let eff_fps = effective_fps.unwrap_or(24); + let total_frames = effective_frames + .expect("decide_chain_routing only returns Chain when frames is Some"); + + // Chain path doesn't use batch/edit_images/mask/control/loras — + // those are single-clip concepts. If the user set them, warn and + // continue (we don't hard-error to keep the UX lenient). 
+ if batch > 1 { + status!( + "{} --batch has no effect in chain mode; rendering a single stitched video", + theme::icon_warn(), + ); + } + + let inputs = super::chain::ChainInputs { + prompt: prompt.to_string(), + model: model.to_string(), + width: eff_w, + height: eff_h, + steps: eff_steps, + guidance: eff_guidance, + strength, + seed, + fps: eff_fps, + output_format, + total_frames, + clip_frames: cf, + motion_tail: mt, + source_image: source_image.clone(), + placement: placement.clone(), + }; + // Consume otherwise-unused LTX-2 knobs that chain v1 ignores so + // clippy doesn't fire `unused_variables` on the early return. + let _ = ( + &audio_file, + &source_video, + &keyframes, + &pipeline, + &loras, + &retake_range, + &spatial_upscale, + &temporal_upscale, + &enable_audio, + &mask_image, + &control_image, + &control_model, + control_scale, + &negative_prompt, + &original_prompt, + &batch_prompts, + &lora, + &scheduler, + expand, + ); + return super::chain::run_chain( + inputs, + host, + output, + no_metadata, + preview, + local, + gpus, + t5_variant, + qwen3_variant, + qwen2_variant, + qwen2_text_encoder_mode, + eager, + offload, + ) + .await; + } + } + } + let piped = is_piped(); // Reject batch > 1 when output goes to stdout (piped with no --output, or --output -) diff --git a/crates/mold-cli/src/commands/mod.rs b/crates/mold-cli/src/commands/mod.rs index ec4bbb27..82a19d16 100644 --- a/crates/mold-cli/src/commands/mod.rs +++ b/crates/mold-cli/src/commands/mod.rs @@ -1,3 +1,4 @@ +pub mod chain; pub mod clean; pub(crate) mod cleanup; pub mod config; diff --git a/crates/mold-cli/src/commands/run.rs b/crates/mold-cli/src/commands/run.rs index ba18d479..85269678 100644 --- a/crates/mold-cli/src/commands/run.rs +++ b/crates/mold-cli/src/commands/run.rs @@ -436,6 +436,8 @@ pub async fn run( batch: u32, frames: Option, fps: Option, + clip_frames: Option, + motion_tail: u32, audio: bool, no_audio: bool, audio_file: Option, @@ -825,6 +827,8 @@ pub async fn run( 
generate::Ltx2Options { frames, fps, + clip_frames, + motion_tail, enable_audio: if audio { Some(true) } else if no_audio { diff --git a/crates/mold-cli/src/commands/upscale.rs b/crates/mold-cli/src/commands/upscale.rs index 7776a45a..043ec624 100644 --- a/crates/mold-cli/src/commands/upscale.rs +++ b/crates/mold-cli/src/commands/upscale.rs @@ -159,12 +159,19 @@ async fn upscale_local( // Create engine and run upscaling in a blocking thread let model_name_owned = model_name.clone(); let req_clone = req.clone(); + let best_gpu_ordinal = + mold_inference::device::select_best_gpu(&mold_inference::device::discover_gpus()) + .map(|g| g.ordinal) + .unwrap_or(0); let resp = tokio::task::spawn_blocking(move || -> Result { + // Local upscale should target the GPU with the most free VRAM instead + // of hardcoding ordinal 0 on multi-GPU hosts. let mut engine = mold_inference::create_upscale_engine( model_name_owned, weights_path, mold_inference::LoadStrategy::Sequential, + best_gpu_ordinal, )?; // Set up progress callback for stderr diff --git a/crates/mold-cli/src/main.rs b/crates/mold-cli/src/main.rs index 03293886..811a6d20 100644 --- a/crates/mold-cli/src/main.rs +++ b/crates/mold-cli/src/main.rs @@ -377,6 +377,9 @@ Examples: /// Number of video frames to generate (video models only, e.g. ltx-video). /// Implies video output mode; output defaults to .gif format. + /// + /// For LTX-2 distilled, values above 97 automatically chain multiple + /// clips at render time (see `--clip-frames` / `--motion-tail`). #[arg(long, help_heading = "Video")] frames: Option, @@ -385,6 +388,22 @@ Examples: #[arg(long, help_heading = "Video")] fps: Option, + /// Per-clip frame cap for chained video. When --frames exceeds this, + /// the CLI splits into multiple chained clips stitched at render time. + /// Defaults to the model's native cap (97 for LTX-2 distilled). 
+ #[arg(long, value_name = "N", help_heading = "Video")] + clip_frames: Option, + + /// Motion-tail overlap between chained clips in pixel frames. Each clip + /// after the first reuses this many trailing latents from the prior + /// clip, trimming the duplicated pixel frames at stitch time. 0 disables + /// latent carryover (simple concat). Default 17 — three LTX-2 latent + /// frames of carryover at the 8× causal temporal compression (causal- + /// first slot + two continuation slots, ≈0.7 s at 24 fps), enough hard- + /// pinned pixel context to keep scene identity coherent across clips. + #[arg(long, value_name = "N", default_value_t = 17, help_heading = "Video")] + motion_tail: u32, + /// Enable synchronized audio for LTX-2 / LTX-2.3 generation. #[arg(long, help_heading = "Video", conflicts_with = "no_audio")] audio: bool, @@ -1147,6 +1166,8 @@ async fn run() -> anyhow::Result<()> { batch, frames, fps, + clip_frames, + motion_tail, audio, no_audio, audio_file, @@ -1205,6 +1226,8 @@ async fn run() -> anyhow::Result<()> { batch, frames, fps, + clip_frames, + motion_tail, audio, no_audio, audio_file, @@ -2131,6 +2154,53 @@ mod tests { } } + #[test] + fn run_chain_flags_parse() { + let cli = parse(&[ + "run", + "ltx-2-19b-distilled:fp8", + "a cat", + "--frames", + "200", + "--clip-frames", + "97", + "--motion-tail", + "4", + ]); + match cli.command { + Commands::Run { + frames, + clip_frames, + motion_tail, + .. + } => { + assert_eq!(frames, Some(200)); + assert_eq!(clip_frames, Some(97)); + assert_eq!(motion_tail, 4); + } + _ => panic!("expected Run"), + } + } + + #[test] + fn run_motion_tail_defaults_to_seventeen() { + let cli = parse(&["run", "ltx-2-19b-distilled:fp8", "a cat", "--frames", "200"]); + match cli.command { + Commands::Run { + motion_tail, + clip_frames, + .. 
+ } => { + assert_eq!( + motion_tail, 17, + "default motion tail must be 17 frames (three LTX-2 latent frames: causal + two continuation)" + ); + assert_eq!(clip_frames, None); + } + _ => panic!("expected Run"), + } + } + // --- Regression test for issue #190: --version includes git SHA --- #[test] diff --git a/crates/mold-core/src/chain.rs b/crates/mold-core/src/chain.rs new file mode 100644 index 00000000..994d46bf --- /dev/null +++ b/crates/mold-core/src/chain.rs @@ -0,0 +1,728 @@ +//! Wire types for server-side chained video generation. +//! +//! A *chain* is a sequence of per-clip render stages stitched into a single +//! output video. The v1 CLI UX is single-prompt + arbitrary length, but the +//! wire format is stages-based from day one so the eventual movie-maker +//! (multi-prompt, keyframes, selective regen) can author stages by hand +//! without a breaking change. +//! +//! The server only ever sees the canonical [`ChainRequest`] shape — a +//! `Vec`. Callers can either build that directly or use the +//! auto-expand form (`prompt` + `total_frames` + `clip_frames`), which +//! [`ChainRequest::normalise`] collapses into stages. +//! +//! See `tasks/render-chain-v1-plan.md` for the full design rationale. + +use serde::{Deserialize, Serialize}; + +use crate::error::{MoldError, Result}; +use crate::types::{DevicePlacement, OutputFormat, VideoData}; + +/// A single rendered clip in a chain. Concatenated in order with motion-tail +/// trimming on continuations (stages with `idx >= 1` drop the leading +/// `motion_tail_frames` pixel frames of their output because those duplicate +/// the tail of the previous stage that the engine carried across as +/// latent-space conditioning). +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ChainStage { + /// Prompt used for this stage. In v1 all stages receive the same prompt + /// (auto-expand form replicates it); the movie-maker UI in v2 will let + /// users author per-stage prompts. 
+    #[schema(example = "a cat walking through autumn leaves")]
+    pub prompt: String,
+
+    /// Frame count for this stage. Must be `8k+1` (LTX-2 pipeline constraint:
+    /// 9, 17, 25, …, 97).
+    #[schema(example = 97)]
+    pub frames: u32,
+
+    /// Optional starting image (raw PNG/JPEG bytes, base64 in JSON).
+    /// Auto-expand replicates it across every stage: stage 0 uses it as the
+    /// i2v replacement at frame 0, and continuation stages use it as a soft
+    /// identity anchor (see `Ltx2Engine::render_chain_stage`) on top of the
+    /// prior stage's motion-tail latents.
+    #[serde(
+        default,
+        skip_serializing_if = "Option::is_none",
+        with = "crate::types::base64_opt"
+    )]
+    pub source_image: Option>,
+
+    /// Optional negative prompt for CFG-based stages. v1 LTX-2 ignores this
+    /// (the distilled family doesn't use CFG); the field is reserved so the
+    /// movie-maker can round-trip it without re-migrating the wire format.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub negative_prompt: Option,
+
+    /// Optional per-stage seed offset. `None` in v1 — the orchestrator
+    /// derives each stage's seed from the chain's base seed. Reserved as the
+    /// v2 movie-maker override hook for "regenerate just this stage with a
+    /// different seed".
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub seed_offset: Option,
+}
+
+/// Chained generation request. Server accepts either the canonical form
+/// (`stages` non-empty) or the auto-expand form (`prompt` + `total_frames` +
+/// `clip_frames`); [`ChainRequest::normalise`] collapses the latter into the
+/// former so downstream code only deals with `stages`.
+#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
+pub struct ChainRequest {
+    #[schema(example = "ltx-2-19b-distilled:fp8")]
+    pub model: String,
+
+    /// Canonical stages list. Empty triggers auto-expand from
+    /// `prompt`/`total_frames`/`clip_frames`.
+    #[serde(default)]
+    pub stages: Vec,
+
+    /// Pixel frames of motion-tail overlap between consecutive stages.
+    /// `0` = no overlap (simple concat). `>0` = the final K pixel frames of
+    /// stage N's latents are threaded into stage N+1's conditioning, and
+    /// stage N+1's leading K output frames are dropped at stitch time.
+    ///
+    /// Defaults to `4` on the wire, so this default only applies to
+    /// hand-authored API requests; the CLI passes its own `--motion-tail`
+    /// value (default 17) explicitly. Must be strictly less than each
+    /// stage's `frames`.
+    #[serde(default = "default_motion_tail_frames")]
+    #[schema(example = 4)]
+    pub motion_tail_frames: u32,
+
+    #[schema(example = 1216)]
+    pub width: u32,
+    #[schema(example = 704)]
+    pub height: u32,
+    #[serde(default = "default_fps")]
+    #[schema(example = 24)]
+    pub fps: u32,
+
+    /// Chain base seed. Per-stage seeds are derived as
+    /// `base_seed ^ ((stage_idx as u64) << 32)` by the orchestrator so the
+    /// whole chain is reproducible from a single seed value.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = 42)]
+    pub seed: Option,
+
+    #[schema(example = 8)]
+    pub steps: u32,
+
+    #[schema(example = 3.0)]
+    pub guidance: f64,
+
+    /// Denoising strength for `stages[0].source_image`. Ignored when the
+    /// first stage has no source image. Continuation stages are always
+    /// full-strength conditioned via motion-tail latents.
+    #[serde(default = "default_strength")]
+    #[schema(example = 1.0)]
+    pub strength: f64,
+
+    #[serde(default = "default_output_format")]
+    pub output_format: OutputFormat,
+
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub placement: Option,
+
+    // ── Auto-expand form ────────────────────────────────────────────────
+    // These are only read when `stages` is empty; `normalise` clears them
+    // after expansion so the canonical form only ever carries `stages`.
+    /// Auto-expand: single prompt replicated across all stages.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub prompt: Option,
+
+    /// Auto-expand: total pixel frames the stitched output should cover.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub total_frames: Option,
+
+    /// Auto-expand: per-clip frame count.
Defaults to `97` (LTX-2 19B/22B + /// distilled cap). Must be `8k+1`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub clip_frames: Option, + + /// Auto-expand: starting image for `stages[0]`. + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "crate::types::base64_opt" + )] + pub source_image: Option>, +} + +/// Response from a chained generation request. The `video` is the stitched +/// output; individual per-stage clips are not returned. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ChainResponse { + pub video: VideoData, + /// Number of stages that actually ran (matches `request.stages.len()` + /// after normalisation). + #[schema(example = 5)] + pub stage_count: u32, + /// GPU ordinal that handled the chain (multi-GPU servers only). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub gpu: Option, +} + +/// SSE completion event for a successful chain run. Streamed as the final +/// `data:` frame under the `event: complete` SSE type. The payload is +/// base64-encoded to stay JSON-safe; clients decode it into `VideoData`. +/// +/// This is a sibling to [`crate::types::SseCompleteEvent`] rather than an +/// extension so image/video vs. chain completion shapes stay independent +/// and can evolve separately. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct SseChainCompleteEvent { + /// Base64-encoded stitched video bytes (format per `format` field). + pub video: String, + pub format: OutputFormat, + #[schema(example = 1216)] + pub width: u32, + #[schema(example = 704)] + pub height: u32, + #[schema(example = 400)] + pub frames: u32, + #[schema(example = 24)] + pub fps: u32, + /// Base64-encoded first-frame PNG thumbnail. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thumbnail: Option, + /// Base64-encoded animated GIF preview (always emitted for gallery UI). 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub gif_preview: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub has_audio: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub audio_sample_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub audio_channels: Option, + /// Number of stages that ran end-to-end. + #[schema(example = 5)] + pub stage_count: u32, + /// GPU ordinal that handled the chain (multi-GPU only). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub gpu: Option, + /// Wall-clock elapsed time across all stages + stitching. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub generation_time_ms: Option, +} + +/// Chain-specific SSE progress event. Streamed as `data:` JSON frames from +/// `POST /api/generate/chain/stream` under the `event: progress` SSE type. +/// +/// Per-stage denoise steps are wrapped with `stage_idx` so consumers can +/// render stacked progress bars (overall chain + per-stage) without a +/// separate subscription. Non-denoise engine events (weight load, cache +/// hits, etc.) are intentionally not forwarded through this enum in v1 — +/// they're scoped to individual stages and the UX goal for v1 is per-stage +/// progress, not per-component telemetry. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChainProgressEvent { + /// Emitted once at the start of the chain, after normalisation. Gives + /// consumers the final stage count and the target pre-trim frame total + /// so they can size progress bars up front. + ChainStart { + stage_count: u32, + estimated_total_frames: u32, + }, + /// Stage `stage_idx` (0-indexed) has started its denoise loop. + StageStart { stage_idx: u32 }, + /// Per-step denoise progress for the active stage. 
+ DenoiseStep { + stage_idx: u32, + step: u32, + total: u32, + }, + /// Stage finished generating; `frames_emitted` is the raw clip frame + /// count before motion-tail trim at stitch time. + StageDone { stage_idx: u32, frames_emitted: u32 }, + /// All stages complete; stitching/encoding the final MP4. + Stitching { total_frames: u32 }, +} + +fn default_motion_tail_frames() -> u32 { + 4 +} + +fn default_fps() -> u32 { + 24 +} + +fn default_strength() -> f64 { + 1.0 +} + +fn default_output_format() -> OutputFormat { + OutputFormat::Mp4 +} + +/// Maximum number of stages the v1 orchestrator will accept in a single +/// chain. 16 × 97-frame clips ≈ 1552 frames ≈ 64 s at 24 fps — comfortably +/// past the 400-frame target without risking runaway jobs. +pub const MAX_CHAIN_STAGES: usize = 16; + +impl ChainRequest { + /// Collapse the auto-expand form into a canonical `Vec` and + /// validate the result. Called once on the server side immediately after + /// JSON parsing, before any engine work kicks off. + /// + /// Post-conditions on a successful return: + /// - `self.stages` is non-empty. + /// - Each stage's `frames` is `8k+1` and `> 0`. + /// - `self.stages.len() <= MAX_CHAIN_STAGES`. + /// - All auto-expand fields are `None` (caller must use `self.stages`). 
+ pub fn normalise(mut self) -> Result { + if self.stages.is_empty() { + let prompt = self.prompt.take().ok_or_else(|| { + MoldError::Validation( + "chain request needs either stages[] or prompt + total_frames".into(), + ) + })?; + let total_frames = self.total_frames.ok_or_else(|| { + MoldError::Validation("chain auto-expand requires total_frames".into()) + })?; + if total_frames == 0 { + return Err(MoldError::Validation( + "chain total_frames must be > 0".into(), + )); + } + let clip_frames = self.clip_frames.unwrap_or(97); + if clip_frames == 0 { + return Err(MoldError::Validation( + "chain clip_frames must be > 0".into(), + )); + } + if !is_ltx2_frame_count(clip_frames) { + return Err(MoldError::Validation(format!( + "chain clip_frames ({clip_frames}) must be 8k+1 (9, 17, 25, …, 97)", + ))); + } + let motion_tail = self.motion_tail_frames; + if motion_tail >= clip_frames { + return Err(MoldError::Validation(format!( + "motion_tail_frames ({motion_tail}) must be strictly less than clip_frames ({clip_frames})", + ))); + } + + let source_image = self.source_image.take(); + self.stages = build_auto_expand_stages( + &prompt, + total_frames, + clip_frames, + motion_tail, + source_image, + )?; + } + + if self.stages.is_empty() { + return Err(MoldError::Validation("chain request has no stages".into())); + } + if self.stages.len() > MAX_CHAIN_STAGES { + return Err(MoldError::Validation(format!( + "chain request has {} stages; maximum is {}", + self.stages.len(), + MAX_CHAIN_STAGES, + ))); + } + for (idx, stage) in self.stages.iter().enumerate() { + if stage.frames == 0 { + return Err(MoldError::Validation(format!("stage {idx} has 0 frames",))); + } + if !is_ltx2_frame_count(stage.frames) { + return Err(MoldError::Validation(format!( + "stage {idx} has {} frames; LTX-2 requires 8k+1 (9, 17, 25, …, 97)", + stage.frames, + ))); + } + if self.motion_tail_frames >= stage.frames { + return Err(MoldError::Validation(format!( + "motion_tail_frames ({}) must be strictly less 
than stage {idx}'s frames ({})", + self.motion_tail_frames, stage.frames, + ))); + } + } + + // Canonicalise: clear auto-expand fields so downstream code only + // ever reads from `stages`. + self.prompt = None; + self.total_frames = None; + self.clip_frames = None; + self.source_image = None; + + Ok(self) + } +} + +/// Returns `true` iff `n` has the form `8k + 1` for some non-negative integer +/// `k` (1, 9, 17, 25, …). The LTX-2 pipeline has this constraint on pixel +/// frame counts due to the VAE's 8× temporal compression with a causal first +/// frame. +fn is_ltx2_frame_count(n: u32) -> bool { + n % 8 == 1 +} + +/// Compute the stage count and per-stage frame allocation for the auto- +/// expand form, matching Phase 1.4's stitch math: +/// +/// - Stage 0 contributes `clip_frames` pixel frames. +/// - Each continuation contributes `clip_frames - motion_tail_frames` new +/// frames (the leading `motion_tail_frames` are dropped at stitch time +/// because they duplicate the prior stage's latent tail). +/// +/// Returns enough stages so the stitched total reaches at least +/// `total_frames`; over-production is trimmed from the tail at stitch time +/// per the signed-off decision 2026-04-20. +fn build_auto_expand_stages( + prompt: &str, + total_frames: u32, + clip_frames: u32, + motion_tail_frames: u32, + source_image: Option>, +) -> Result> { + let (stage_count, per_stage_frames) = if total_frames <= clip_frames { + // Single stage: match the user's requested length exactly so we + // don't render 97 frames and throw most of them away. The frame + // count will still be validated as 8k+1 by the caller. + (1u32, total_frames) + } else { + let effective = clip_frames - motion_tail_frames; + // effective > 0 because the caller has already ensured + // motion_tail_frames < clip_frames. 
+ let remainder = total_frames - clip_frames; + let count = 1 + remainder.div_ceil(effective); + (count, clip_frames) + }; + + let count_usize = stage_count as usize; + if count_usize > MAX_CHAIN_STAGES { + return Err(MoldError::Validation(format!( + "auto-expand would produce {stage_count} stages; maximum is {MAX_CHAIN_STAGES} \ + (try reducing total_frames or increasing clip_frames)", + ))); + } + + let mut stages = Vec::with_capacity(count_usize); + for _ in 0..stage_count { + // Every stage carries the starting image: stage 0 uses it as the + // i2v replacement at frame 0, and continuation stages use it as a + // soft identity anchor through the append path (see + // `Ltx2Engine::render_chain_stage`). Keeping a durable reference + // across stages is what stops scene/identity drift past the first + // clip, whose effects were traced in render-chain v1 as the + // dominant cause of "strange" continuations — the motion tail + // alone only carries ~0.7 s of pixel context, nowhere near enough + // for the model to remember the scene across an 8-stage chain. + stages.push(ChainStage { + prompt: prompt.to_string(), + frames: per_stage_frames, + source_image: source_image.clone(), + negative_prompt: None, + seed_offset: None, + }); + } + Ok(stages) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a minimal auto-expand request with the given knobs. All other + /// fields use their v1 defaults so tests can focus on the logic under + /// exercise. 
+ fn auto_expand_request( + prompt: &str, + total_frames: u32, + clip_frames: u32, + motion_tail_frames: u32, + source_image: Option>, + ) -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages: Vec::new(), + motion_tail_frames, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: Some(prompt.into()), + total_frames: Some(total_frames), + clip_frames: Some(clip_frames), + source_image, + } + } + + fn canonical_request(stages: Vec, motion_tail_frames: u32) -> ChainRequest { + ChainRequest { + model: "ltx-2-19b-distilled:fp8".into(), + stages, + motion_tail_frames, + width: 1216, + height: 704, + fps: 24, + seed: Some(42), + steps: 8, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Mp4, + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } + } + + fn make_stage(frames: u32) -> ChainStage { + ChainStage { + prompt: "test".into(), + frames, + source_image: None, + negative_prompt: None, + seed_offset: None, + } + } + + #[test] + fn normalise_splits_single_prompt_into_stages() { + // total=400, clip=97, tail=4 → effective=93, remainder=303, + // N = 1 + ceil(303/93) = 1 + 4 = 5 stages of 97 frames each. + // Stitched = 97 + 4*93 = 469, which will be trimmed to 400 at + // stitch time (per the signed-off "trim from tail" decision). + let normalised = auto_expand_request("a cat walking", 400, 97, 4, None) + .normalise() + .expect("normalise should succeed"); + + assert_eq!( + normalised.stages.len(), + 5, + "400/97 with a 4-frame motion tail should expand to 5 stages", + ); + for stage in &normalised.stages { + assert_eq!(stage.frames, 97); + assert_eq!(stage.prompt, "a cat walking"); + assert!(stage.seed_offset.is_none()); + } + // Auto-expand fields are cleared post-normalisation. 
+ assert!(normalised.prompt.is_none()); + assert!(normalised.total_frames.is_none()); + assert!(normalised.clip_frames.is_none()); + assert!(normalised.source_image.is_none()); + } + + #[test] + fn normalise_preserves_starting_image_across_all_stages() { + let png = vec![0x89, 0x50, 0x4e, 0x47, 0xde, 0xad, 0xbe, 0xef]; + let normalised = auto_expand_request("test", 200, 97, 4, Some(png.clone())) + .normalise() + .expect("normalise should succeed"); + + assert!(normalised.stages.len() >= 2); + for (idx, stage) in normalised.stages.iter().enumerate() { + // Every stage must carry the starting image. Stage 0 uses it + // as the i2v replacement at frame 0; continuations use it as a + // soft identity anchor through the append path so scene and + // subject identity stay coherent past the motion-tail window. + assert_eq!( + stage.source_image.as_deref(), + Some(png.as_slice()), + "stage {idx} must carry the starting image for cross-stage identity anchoring", + ); + } + } + + #[test] + fn normalise_rejects_empty() { + let mut req = canonical_request(Vec::new(), 4); + // No auto-expand fields either. + req.prompt = None; + req.total_frames = None; + + let err = req.normalise().expect_err("empty chain should fail"); + assert!( + matches!(err, MoldError::Validation(_)), + "empty chain should be a validation error, got {err:?}", + ); + } + + #[test] + fn normalise_rejects_non_8k1_frames() { + // Canonical form with a stage whose frames violates the 8k+1 + // constraint. + let req = canonical_request(vec![make_stage(50)], 4); + let err = req.normalise().expect_err("non-8k+1 frames should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")), + "error must mention the 8k+1 constraint", + ); + } + + #[test] + fn normalise_accepts_canonical_form_unchanged() { + // Caller already built stages; normalise should validate and clear + // the (already-empty) auto-expand fields without touching stages. 
+ let stages = vec![make_stage(97), make_stage(97), make_stage(97)]; + let normalised = canonical_request(stages.clone(), 4) + .normalise() + .expect("valid canonical form should pass"); + assert_eq!(normalised.stages.len(), 3); + for (left, right) in normalised.stages.iter().zip(stages.iter()) { + assert_eq!(left.frames, right.frames); + assert_eq!(left.prompt, right.prompt); + } + } + + #[test] + fn normalise_single_stage_when_total_leq_clip() { + // total=9 fits in one clip; don't render a full 97-frame stage and + // throw most of it away. + let normalised = auto_expand_request("short", 9, 97, 4, None) + .normalise() + .expect("short single-clip chain should pass"); + assert_eq!(normalised.stages.len(), 1); + assert_eq!(normalised.stages[0].frames, 9); + } + + #[test] + fn normalise_rejects_too_many_stages() { + // 17 canonical stages exceeds MAX_CHAIN_STAGES (16). + let stages = (0..17).map(|_| make_stage(97)).collect(); + let err = canonical_request(stages, 4) + .normalise() + .expect_err("17-stage chain should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("maximum")), + "error must mention the max-stages cap", + ); + } + + #[test] + fn normalise_rejects_auto_expand_too_long() { + // 16 × 97 = 1552 max stitched frames before trim; asking for + // 4000 frames should blow the guardrail. + let err = auto_expand_request("too long", 4000, 97, 4, None) + .normalise() + .expect_err("runaway auto-expand should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("stages")), + "error must name the stage count guardrail", + ); + } + + #[test] + fn normalise_rejects_motion_tail_ge_clip() { + // motion_tail must leave at least one new frame per continuation. 
+ let err = auto_expand_request("bad tail", 200, 97, 97, None) + .normalise() + .expect_err("motion_tail >= clip should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("motion_tail_frames")), + "error must name motion_tail_frames", + ); + } + + #[test] + fn normalise_rejects_missing_total_frames_in_auto_expand() { + let mut req = canonical_request(Vec::new(), 4); + req.prompt = Some("missing total".into()); + // total_frames omitted. + let err = req + .normalise() + .expect_err("missing total_frames should fail"); + assert!( + matches!(err, MoldError::Validation(msg) if msg.contains("total_frames")), + "error must name total_frames", + ); + } + + #[test] + fn is_ltx2_frame_count_matches_8k_plus_1() { + for valid in [1u32, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97] { + assert!( + is_ltx2_frame_count(valid), + "{valid} should be a valid LTX-2 frame count", + ); + } + for invalid in [0u32, 2, 8, 10, 16, 50, 96, 98, 100] { + assert!( + !is_ltx2_frame_count(invalid), + "{invalid} must not pass the 8k+1 check", + ); + } + } + + #[test] + fn chain_progress_event_roundtrips_json_with_snake_case_tags() { + let cases = [ + ( + ChainProgressEvent::ChainStart { + stage_count: 5, + estimated_total_frames: 469, + }, + r#""type":"chain_start""#, + ), + ( + ChainProgressEvent::StageStart { stage_idx: 0 }, + r#""type":"stage_start""#, + ), + ( + ChainProgressEvent::DenoiseStep { + stage_idx: 2, + step: 4, + total: 8, + }, + r#""type":"denoise_step""#, + ), + ( + ChainProgressEvent::StageDone { + stage_idx: 3, + frames_emitted: 97, + }, + r#""type":"stage_done""#, + ), + ( + ChainProgressEvent::Stitching { total_frames: 400 }, + r#""type":"stitching""#, + ), + ]; + for (event, expected_tag) in cases { + let json = serde_json::to_string(&event).expect("serialize"); + assert!( + json.contains(expected_tag), + "missing snake_case tag {expected_tag} in {json}", + ); + let roundtrip: ChainProgressEvent = serde_json::from_str(&json).expect("deserialize"); 
+ assert_eq!(roundtrip, event, "roundtrip must preserve payload"); + } + } + + #[test] + fn build_stages_math_matches_stitch_budget() { + // Auto-expand must produce enough stages that the stitch delivers + // at least `total_frames` pixel frames. Stitch math: + // delivered = clip_frames + (N - 1) * (clip_frames - motion_tail) + let cases = [ + (400u32, 97u32, 4u32, 5u32), // 97 + 4*93 = 469 ≥ 400 + (200, 97, 4, 3), // 97 + 2*93 = 283 ≥ 200 + (97, 97, 4, 1), // single clip hits 97 exactly + (300, 97, 0, 4), // zero tail, 4*97 = 388 ≥ 300 + ]; + for (total, clip, tail, expected_n) in cases { + let req = auto_expand_request("m", total, clip, tail, None) + .normalise() + .expect("valid auto-expand should normalise"); + assert_eq!( + req.stages.len() as u32, + expected_n, + "expected {expected_n} stages for total={total}, clip={clip}, tail={tail}", + ); + let delivered = clip + (expected_n - 1) * (clip - tail); + assert!( + delivered >= total, + "{expected_n} stages deliver {delivered} frames but {total} were requested", + ); + } + } +} diff --git a/crates/mold-core/src/client.rs b/crates/mold-core/src/client.rs index e900739f..66835d5d 100644 --- a/crates/mold-core/src/client.rs +++ b/crates/mold-core/src/client.rs @@ -1,3 +1,4 @@ +use crate::chain::{ChainProgressEvent, ChainRequest, ChainResponse, SseChainCompleteEvent}; use crate::error::MoldError; use crate::types::{ ExpandRequest, ExpandResponse, GalleryImage, GenerateRequest, GenerateResponse, ImageData, @@ -313,6 +314,137 @@ impl MoldClient { anyhow::bail!("SSE stream ended without complete event") } + /// Submit a chained video generation request (non-streaming). + /// + /// The server normalises the auto-expand form into stages, runs each + /// stage sequentially with motion-tail latent carryover, stitches the + /// result into a single video, and returns a [`ChainResponse`]. Large + /// chains take minutes — prefer [`Self::generate_chain_stream`] for + /// interactive clients that want progress updates. 
+    pub async fn generate_chain(&self, req: &ChainRequest) -> Result<ChainResponse> {
+        let resp = self
+            .client
+            .post(format!("{}/api/generate/chain", self.base_url))
+            .json(req)
+            .send()
+            .await?;
+
+        if resp.status() == reqwest::StatusCode::NOT_FOUND {
+            let body = resp.text().await.unwrap_or_default();
+            if body.is_empty() {
+                anyhow::bail!("chain endpoint not found — server predates render-chain v1");
+            }
+            return Err(MoldError::ModelNotFound(body).into());
+        }
+        if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
+            let body = resp.text().await.unwrap_or_default();
+            return Err(MoldError::Validation(format!("validation error: {body}")).into());
+        }
+        if resp.status().is_client_error() || resp.status().is_server_error() {
+            let status = resp.status();
+            let body = resp.text().await.unwrap_or_default();
+            anyhow::bail!("server error {status}: {body}");
+        }
+
+        let chain: ChainResponse = resp.json().await?;
+        Ok(chain)
+    }
+
+    /// Submit a chained video generation request with SSE progress streaming.
+    ///
+    /// Returns:
+    /// - `Ok(Some(response))` — streaming succeeded and the `complete` event
+    ///   carried the stitched video.
+    /// - `Ok(None)` — server doesn't have the chain endpoint (empty 404).
+    ///   Callers can fall back to [`Self::generate_chain`] or error.
+    /// - `Err(_)` — validation, model-not-found, or mid-stream server error.
+    pub async fn generate_chain_stream(
+        &self,
+        req: &ChainRequest,
+        progress_tx: tokio::sync::mpsc::UnboundedSender<ChainProgressEvent>,
+    ) -> Result<Option<ChainResponse>> {
+        let mut resp = self
+            .client
+            .post(format!("{}/api/generate/chain/stream", self.base_url))
+            .json(req)
+            .send()
+            .await?;
+
+        if resp.status() == reqwest::StatusCode::NOT_FOUND {
+            let body = resp.text().await.unwrap_or_default();
+            if body.is_empty() {
+                return Ok(None);
+            }
+            return Err(MoldError::ModelNotFound(body).into());
+        }
+        if resp.status() == reqwest::StatusCode::UNPROCESSABLE_ENTITY {
+            let body = resp.text().await.unwrap_or_default();
+            return Err(MoldError::Validation(format!("validation error: {body}")).into());
+        }
+        if resp.status().is_client_error() || resp.status().is_server_error() {
+            let status = resp.status();
+            let body = resp.text().await.unwrap_or_default();
+            anyhow::bail!("server error {status}: {body}");
+        }
+
+        let b64 = base64::engine::general_purpose::STANDARD;
+        let mut buffer = String::new();
+        while let Some(chunk) = resp.chunk().await? 
{
+            buffer.push_str(&String::from_utf8_lossy(&chunk));
+
+            while let Some(event_text) = next_sse_event(&mut buffer) {
+                let (event_type, data) = parse_sse_event(&event_text);
+                match event_type.as_str() {
+                    "progress" => {
+                        if let Ok(p) = serde_json::from_str::<ChainProgressEvent>(&data) {
+                            let _ = progress_tx.send(p);
+                        }
+                    }
+                    "complete" => {
+                        let complete: SseChainCompleteEvent = serde_json::from_str(&data)?;
+                        let payload = b64.decode(&complete.video)?;
+                        let thumbnail = complete
+                            .thumbnail
+                            .as_deref()
+                            .and_then(|s| b64.decode(s).ok())
+                            .unwrap_or_default();
+                        let gif_preview = complete
+                            .gif_preview
+                            .as_deref()
+                            .and_then(|s| b64.decode(s).ok())
+                            .unwrap_or_default();
+                        let video = VideoData {
+                            data: payload,
+                            format: complete.format,
+                            width: complete.width,
+                            height: complete.height,
+                            frames: complete.frames,
+                            fps: complete.fps,
+                            thumbnail,
+                            gif_preview,
+                            has_audio: complete.has_audio,
+                            duration_ms: complete.duration_ms,
+                            audio_sample_rate: complete.audio_sample_rate,
+                            audio_channels: complete.audio_channels,
+                        };
+                        return Ok(Some(ChainResponse {
+                            video,
+                            stage_count: complete.stage_count,
+                            gpu: complete.gpu,
+                        }));
+                    }
+                    "error" => {
+                        let error: SseErrorEvent = serde_json::from_str(&data)?;
+                        anyhow::bail!("server error: {}", error.message);
+                    }
+                    _ => {}
+                }
+            }
+        }
+
+        anyhow::bail!("chain SSE stream ended without complete event")
+    }
+
     /// Ask the server to pull (download) a model. Blocks until the download
     /// completes on the server side. The server updates its in-memory config
     /// so subsequent generate/load requests can find the model.
@@ -412,14 +544,27 @@ impl MoldClient {
     }
 
     pub async fn unload_model(&self) -> Result<String> {
-        let resp = self
+        self.unload_model_target(None, None).await
+    }
+
+    pub async fn unload_model_target(
+        &self,
+        model: Option<&str>,
+        gpu: Option<usize>,
+    ) -> Result<String> {
+        let req = serde_json::json!({
+            "model": model,
+            "gpu": gpu,
+        });
+        let builder = self
             .client
-            .delete(format!("{}/api/models/unload", self.base_url))
-            .send()
-            .await?
-            .error_for_status()?
-            .text()
-            .await?;
+            .delete(format!("{}/api/models/unload", self.base_url));
+        let builder = if model.is_some() || gpu.is_some() {
+            builder.json(&req)
+        } else {
+            builder
+        };
+        let resp = builder.send().await?.error_for_status()?.text().await?;
         Ok(resp)
     }
diff --git a/crates/mold-core/src/download.rs b/crates/mold-core/src/download.rs
index 7f30665e..94610be7 100644
--- a/crates/mold-core/src/download.rs
+++ b/crates/mold-core/src/download.rs
@@ -1448,7 +1448,7 @@ mod tests {
         buf.extend_from_slice(b"GGUF");
         // Pad a couple hundred bytes of synthetic header bytes, then include
         // every tensor name as a plain UTF-8 substring so the scanner finds it.
- buf.extend(std::iter::repeat(0u8).take(256)); + buf.extend(std::iter::repeat_n(0u8, 256)); for name in tensor_names { buf.extend_from_slice(name.as_bytes()); buf.push(0); diff --git a/crates/mold-core/src/lib.rs b/crates/mold-core/src/lib.rs index a16bb81f..9da6a5e2 100644 --- a/crates/mold-core/src/lib.rs +++ b/crates/mold-core/src/lib.rs @@ -1,5 +1,6 @@ pub mod build_info; pub mod catalog; +pub mod chain; pub mod client; pub mod config; pub mod control; @@ -18,6 +19,10 @@ mod config_test; mod test_support; pub use catalog::build_model_catalog; +pub use chain::{ + ChainProgressEvent, ChainRequest, ChainResponse, ChainStage, SseChainCompleteEvent, + MAX_CHAIN_STAGES, +}; pub use client::MoldClient; pub use config::{ parse_device_ref_str, Config, DefaultModelResolution, DefaultModelSource, LoggingConfig, diff --git a/crates/mold-core/src/placement_test.rs b/crates/mold-core/src/placement_test.rs index 800540a1..bf3e81c2 100644 --- a/crates/mold-core/src/placement_test.rs +++ b/crates/mold-core/src/placement_test.rs @@ -163,16 +163,18 @@ fn generate_request_without_placement_is_none() { #[test] fn model_config_serializes_placement_section() { use crate::config::{Config, ModelConfig}; - let mut mc = ModelConfig::default(); - mc.placement = Some(DevicePlacement { - text_encoders: DeviceRef::Cpu, - advanced: Some(AdvancedPlacement { - transformer: DeviceRef::gpu(0), - vae: DeviceRef::Cpu, - t5: Some(DeviceRef::Cpu), - ..Default::default() + let mc = ModelConfig { + placement: Some(DevicePlacement { + text_encoders: DeviceRef::Cpu, + advanced: Some(AdvancedPlacement { + transformer: DeviceRef::gpu(0), + vae: DeviceRef::Cpu, + t5: Some(DeviceRef::Cpu), + ..Default::default() + }), }), - }); + ..Default::default() + }; let mut cfg = Config::default(); cfg.models.insert("flux-dev:q4".to_string(), mc); diff --git a/crates/mold-core/src/types.rs b/crates/mold-core/src/types.rs index ade380e1..f737c776 100644 --- a/crates/mold-core/src/types.rs +++ 
b/crates/mold-core/src/types.rs
@@ -1,7 +1,7 @@
 use serde::{Deserialize, Serialize};
 
 /// Serde helpers for `Option<Vec<u8>>` as base64 in JSON.
-mod base64_opt {
+pub(crate) mod base64_opt {
     use base64::Engine as _;
     use serde::{Deserialize, Deserializer, Serializer};
 
@@ -1006,6 +1006,11 @@ pub struct ResourceSnapshot {
     pub timestamp: i64,
     pub gpus: Vec<GpuSnapshot>,
     pub system_ram: RamSnapshot,
+    /// System-wide CPU utilization (averaged across all cores). `None` when
+    /// the aggregator hasn't had two samples yet (CPU usage is computed from
+    /// deltas — the first snapshot always reports zero).
+    #[serde(default)]
+    pub cpu: Option<CpuSnapshot>,
 }
 
 /// Per-GPU memory snapshot.
@@ -1021,6 +1026,18 @@ pub struct GpuSnapshot {
     pub vram_used_by_mold: Option<u64>,
     /// `vram_used - vram_used_by_mold`. `None` whenever `vram_used_by_mold` is.
     pub vram_used_by_other: Option<u64>,
+    /// GPU core utilization in percent (0-100). `None` on Metal and on the
+    /// `nvidia-smi` fallback path — only NVML exposes this cheaply.
+    #[serde(default)]
+    pub gpu_utilization: Option<u32>,
+}
+
+/// Aggregate CPU snapshot. `usage_percent` is a 0-100 average across every
+/// logical core.
+#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
+pub struct CpuSnapshot {
+    pub cores: u16,
+    pub usage_percent: f32,
 }
 
 /// System RAM snapshot. Per-process fields are always populated (via sysinfo).
@@ -2233,6 +2250,7 @@ mod tests { vram_used: 14_200_000_000, vram_used_by_mold: Some(10_100_000_000), vram_used_by_other: Some(4_100_000_000), + gpu_utilization: Some(42), }], system_ram: RamSnapshot { total: 64_000_000_000, @@ -2240,6 +2258,10 @@ mod tests { used_by_mold: 22_100_000_000, used_by_other: 16_300_000_000, }, + cpu: Some(CpuSnapshot { + cores: 16, + usage_percent: 27.5, + }), }; let json = serde_json::to_string(&snap).unwrap(); let back: ResourceSnapshot = serde_json::from_str(&json).unwrap(); @@ -2269,6 +2291,7 @@ mod tests { vram_used: 38_000_000_000, vram_used_by_mold: None, vram_used_by_other: None, + gpu_utilization: None, }; let json = serde_json::to_string(&snap).unwrap(); // Both fields are present as `null` (not elided) so the SPA can diff --git a/crates/mold-core/tests/chain_client.rs b/crates/mold-core/tests/chain_client.rs new file mode 100644 index 00000000..dc06d1bb --- /dev/null +++ b/crates/mold-core/tests/chain_client.rs @@ -0,0 +1,234 @@ +//! Integration tests for `MoldClient::generate_chain{,_stream}` using +//! `wiremock` to simulate the `/api/generate/chain` server endpoints. +//! +//! These tests pin the HTTP surface (method, path, JSON request body) and +//! verify error translation (422 → Validation, 404 empty → None on stream, +//! 404 with body → ModelNotFound). They do NOT exercise real LTX-2 work — +//! the server side lands in Phase 2. 
+
+use base64::Engine as _;
+use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage, SseChainCompleteEvent};
+use mold_core::error::MoldError;
+use mold_core::types::OutputFormat;
+use mold_core::MoldClient;
+use wiremock::matchers::{body_json_schema, method, path};
+use wiremock::{Mock, MockServer, ResponseTemplate};
+
+fn mold_error(err: &anyhow::Error) -> &MoldError {
+    err.downcast_ref::<MoldError>()
+        .unwrap_or_else(|| panic!("not a MoldError: {err}"))
+}
+
+fn sample_request() -> ChainRequest {
+    ChainRequest {
+        model: "ltx-2-19b-distilled:fp8".into(),
+        stages: vec![ChainStage {
+            prompt: "a cat walking".into(),
+            frames: 97,
+            source_image: None,
+            negative_prompt: None,
+            seed_offset: None,
+        }],
+        motion_tail_frames: 4,
+        width: 1216,
+        height: 704,
+        fps: 24,
+        seed: Some(42),
+        steps: 8,
+        guidance: 3.0,
+        strength: 1.0,
+        output_format: OutputFormat::Mp4,
+        placement: None,
+        prompt: None,
+        total_frames: None,
+        clip_frames: None,
+        source_image: None,
+    }
+}
+
+fn minimal_chain_response_json() -> serde_json::Value {
+    serde_json::json!({
+        "video": {
+            "data": [],
+            "format": "mp4",
+            "width": 1216,
+            "height": 704,
+            "frames": 97,
+            "fps": 24,
+            "thumbnail": []
+        },
+        "stage_count": 1
+    })
+}
+
+// ── /api/generate/chain (non-streaming) ────────────────────────────────
+
+#[tokio::test]
+async fn generate_chain_posts_to_correct_endpoint_and_parses_response() {
+    let server = MockServer::start().await;
+    Mock::given(method("POST"))
+        .and(path("/api/generate/chain"))
+        .and(body_json_schema::<ChainRequest>)
+        .respond_with(ResponseTemplate::new(200).set_body_json(minimal_chain_response_json()))
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let client = MoldClient::new(&server.uri());
+    let resp = client
+        .generate_chain(&sample_request())
+        .await
+        .expect("non-streaming chain should succeed on 200");
+    assert_eq!(resp.stage_count, 1);
+    assert_eq!(resp.video.frames, 97);
+    assert_eq!(resp.video.format, OutputFormat::Mp4);
+}
+
+#[tokio::test]
+async fn 
generate_chain_surfaces_422_as_validation_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .respond_with(ResponseTemplate::new(422).set_body_string("frames must be 8k+1")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let err = client + .generate_chain(&sample_request()) + .await + .expect_err("422 must error"); + assert!( + matches!(mold_error(&err), MoldError::Validation(msg) if msg.contains("8k+1")), + "422 must translate to MoldError::Validation carrying the body", + ); +} + +#[tokio::test] +async fn generate_chain_translates_404_with_body_to_model_not_found() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/generate/chain")) + .respond_with(ResponseTemplate::new(404).set_body_string("model 'ltx-2-foo' not found")) + .mount(&server) + .await; + + let client = MoldClient::new(&server.uri()); + let err = client + .generate_chain(&sample_request()) + .await + .expect_err("404 with body must error"); + assert!( + matches!(mold_error(&err), MoldError::ModelNotFound(msg) if msg.contains("ltx-2-foo")), + "404-with-body must translate to MoldError::ModelNotFound", + ); +} + +#[tokio::test] +async fn generate_chain_empty_404_fails_loudly_instead_of_silently() { + // Non-streaming callers have no fallback path — an empty 404 means the + // server predates render-chain v1, which is a hard error (unlike the + // streaming case where Ok(None) signals "try the non-streaming path"). 
+    let server = MockServer::start().await;
+    Mock::given(method("POST"))
+        .and(path("/api/generate/chain"))
+        .respond_with(ResponseTemplate::new(404).set_body_string(""))
+        .mount(&server)
+        .await;
+
+    let client = MoldClient::new(&server.uri());
+    let err = client
+        .generate_chain(&sample_request())
+        .await
+        .expect_err("empty 404 must error on non-streaming path");
+    let msg = format!("{err}");
+    assert!(
+        msg.contains("chain endpoint not found"),
+        "error must name the missing endpoint, got: {msg}",
+    );
+}
+
+// ── /api/generate/chain/stream (SSE) ───────────────────────────────────
+
+#[tokio::test]
+async fn generate_chain_stream_returns_none_on_empty_404() {
+    // An empty 404 on the streaming endpoint means the server doesn't
+    // support chain SSE yet — callers are expected to fall back to the
+    // non-streaming path.
+    let server = MockServer::start().await;
+    Mock::given(method("POST"))
+        .and(path("/api/generate/chain/stream"))
+        .respond_with(ResponseTemplate::new(404).set_body_string(""))
+        .mount(&server)
+        .await;
+
+    let client = MoldClient::new(&server.uri());
+    let (tx, _rx) = tokio::sync::mpsc::unbounded_channel::<ChainProgressEvent>();
+    let out = client
+        .generate_chain_stream(&sample_request(), tx)
+        .await
+        .expect("empty 404 should resolve to Ok(None)");
+    assert!(out.is_none(), "empty 404 must signal unsupported endpoint");
+}
+
+#[tokio::test]
+async fn generate_chain_stream_parses_progress_and_complete_events() {
+    let b64 = base64::engine::general_purpose::STANDARD;
+    let video_bytes = b"FAKE_MP4_BYTES";
+    let thumb_bytes = b"THUMB";
+    let complete = SseChainCompleteEvent {
+        video: b64.encode(video_bytes),
+        format: OutputFormat::Mp4,
+        width: 1216,
+        height: 704,
+        frames: 97,
+        fps: 24,
+        thumbnail: Some(b64.encode(thumb_bytes)),
+        gif_preview: None,
+        has_audio: false,
+        duration_ms: Some(4040),
+        audio_sample_rate: None,
+        audio_channels: None,
+        stage_count: 1,
+        gpu: Some(0),
+        generation_time_ms: Some(45_000),
+    };
+    let progress = 
ChainProgressEvent::DenoiseStep {
+        stage_idx: 0,
+        step: 4,
+        total: 8,
+    };
+    // Build a chunk-encoded SSE body carrying one progress event then
+    // complete. `\n\n` terminates each SSE event.
+    let body = format!(
+        "event: progress\ndata: {}\n\nevent: complete\ndata: {}\n\n",
+        serde_json::to_string(&progress).unwrap(),
+        serde_json::to_string(&complete).unwrap(),
+    );
+
+    let server = MockServer::start().await;
+    Mock::given(method("POST"))
+        .and(path("/api/generate/chain/stream"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("content-type", "text/event-stream")
+                .set_body_string(body),
+        )
+        .mount(&server)
+        .await;
+
+    let client = MoldClient::new(&server.uri());
+    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<ChainProgressEvent>();
+    let resp = client
+        .generate_chain_stream(&sample_request(), tx)
+        .await
+        .expect("SSE stream should succeed")
+        .expect("complete event should yield a response");
+
+    assert_eq!(resp.stage_count, 1);
+    assert_eq!(resp.video.data, video_bytes);
+    assert_eq!(resp.video.thumbnail, thumb_bytes);
+    assert_eq!(resp.gpu, Some(0));
+    let ev = rx.recv().await.expect("progress event should be forwarded");
+    assert_eq!(ev, progress);
+}
diff --git a/crates/mold-inference/src/device.rs b/crates/mold-inference/src/device.rs
index a16ca50f..94f6bd4e 100644
--- a/crates/mold-inference/src/device.rs
+++ b/crates/mold-inference/src/device.rs
@@ -1,6 +1,57 @@
 use crate::engine::LoadStrategy;
 use crate::progress::ProgressReporter;
 use mold_core::types::GpuSelection;
+use std::cell::Cell;
+
+// ── Thread-local GPU ordinal guard ─────────────────────────────────────────
+//
+// Each GPU worker thread is pinned to a single ordinal. 
We stash that ordinal
+// in a thread-local so cross-engine hotpaths (`create_device`, `reclaim_gpu_memory`)
+// can debug-assert the caller isn't drifting onto a sibling GPU's context —
+// the exact footgun that took the process down on killswitch when LTX-2 had
+// `reclaim_gpu_memory(0)` hardcoded and nuked GPU 0's context while SD3.5
+// was still denoising there.
+//
+// Threads without a bound ordinal (tokio blocking pool, tests) see `None`
+// and the assert is skipped.
+
+thread_local! {
+    static THREAD_GPU_ORDINAL: Cell<Option<usize>> = const { Cell::new(None) };
+}
+
+/// Bind the current thread to a GPU ordinal. Call once from each GPU worker
+/// thread's entry point. Any subsequent `create_device` / `reclaim_gpu_memory`
+/// call on this thread must match `ordinal` (debug builds only).
+pub fn init_thread_gpu_ordinal(ordinal: usize) {
+    THREAD_GPU_ORDINAL.with(|c| c.set(Some(ordinal)));
+}
+
+/// Clear the thread's GPU binding. Not strictly needed in production (workers
+/// run for the process lifetime) but useful for tests that reuse threads.
+pub fn clear_thread_gpu_ordinal() {
+    THREAD_GPU_ORDINAL.with(|c| c.set(None));
+}
+
+/// Returns the currently-bound ordinal, if any.
+pub fn thread_gpu_ordinal() -> Option<usize> {
+    THREAD_GPU_ORDINAL.with(|c| c.get())
+}
+
+/// Panic in debug builds if `ordinal` doesn't match the thread's bound GPU.
+/// A mismatch means a call site is ignoring its engine's `gpu_ordinal` and
+/// reaching for another GPU's context — the SD3.5/LTX-2 crash pattern.
+#[inline]
+fn debug_assert_ordinal_matches_thread(ordinal: usize, context: &'static str) {
+    if cfg!(debug_assertions) {
+        if let Some(expected) = thread_gpu_ordinal() {
+            assert_eq!(
+                expected, ordinal,
+                "{context}: ordinal {ordinal} does not match this thread's \
+                 bound GPU {expected} — hardcoded ordinal regression?"
+            );
+        }
+    }
+}
 
 // ── GPU discovery ──────────────────────────────────────────────────────────
 
@@ -107,6 +158,7 @@ pub fn create_device(
         tracing::info!("CPU forced via MOLD_DEVICE=cpu");
         return Ok(Device::Cpu);
     }
+    debug_assert_ordinal_matches_thread(ordinal, "create_device");
     if candle_core::utils::cuda_is_available() {
         progress.info(&format!("Using CUDA device {ordinal}"));
         tracing::info!("Using CUDA device {ordinal}");
@@ -185,13 +237,37 @@ pub fn select_expand_device(
     gpus: &[DiscoveredGpu],
     threshold: u64,
     is_metal: bool,
+) -> ExpandPlacement {
+    select_expand_device_with_preference(gpus, threshold, is_metal, None)
+}
+
+/// Same as [`select_expand_device`], but prefers `preferred_ordinal` when it
+/// is in the allowed GPU set and has enough free VRAM.
+pub fn select_expand_device_with_preference(
+    gpus: &[DiscoveredGpu],
+    threshold: u64,
+    is_metal: bool,
+    preferred_ordinal: Option<usize>,
 ) -> ExpandPlacement {
     if is_metal {
+        if let Some(ordinal) = preferred_ordinal {
+            if let Some(g) = gpus.iter().find(|g| g.ordinal == ordinal) {
+                return ExpandPlacement::Gpu(g.ordinal);
+            }
+        }
         if let Some(g) = gpus.first() {
             return ExpandPlacement::Gpu(g.ordinal);
         }
         return ExpandPlacement::Cpu;
     }
+    if let Some(ordinal) = preferred_ordinal {
+        if let Some(g) = gpus
+            .iter()
+            .find(|g| g.ordinal == ordinal && g.free_vram_bytes > threshold)
+        {
+            return ExpandPlacement::Gpu(g.ordinal);
+        }
+    }
     for g in gpus {
         if g.free_vram_bytes > threshold {
             return ExpandPlacement::Gpu(g.ordinal);
@@ -235,12 +311,14 @@ where
 
 #[cfg(feature = "cuda")]
 fn resolve_gpu_ordinal(ordinal: usize) -> anyhow::Result<Device> {
+    debug_assert_ordinal_matches_thread(ordinal, "resolve_device");
     candle_core::Device::new_cuda(ordinal)
         .map_err(|e| anyhow::anyhow!("failed to open CUDA device {ordinal}: {e}"))
 }
 
 #[cfg(all(not(feature = "cuda"), feature = "metal"))]
 fn resolve_gpu_ordinal(ordinal: usize) -> anyhow::Result<Device> {
+    debug_assert_ordinal_matches_thread(ordinal, "resolve_device"); 
candle_core::Device::new_metal(ordinal)
         .map_err(|e| anyhow::anyhow!("failed to open Metal device {ordinal}: {e}"))
 }
@@ -389,6 +467,8 @@ pub fn available_system_memory_bytes() -> Option<u64> {
 pub fn reclaim_gpu_memory(ordinal: usize) {
     use candle_core::cuda_backend::cudarc::driver::{result, sys};
 
+    debug_assert_ordinal_matches_thread(ordinal, "reclaim_gpu_memory");
+
     // Synchronize to ensure all async GPU work completes before reset.
     let _ = result::ctx::synchronize();
 
@@ -1253,4 +1333,22 @@ mod tests {
             ExpandPlacement::Cpu,
         );
     }
+
+    #[test]
+    fn expand_prefers_requested_gpu_when_it_fits() {
+        let gpus = vec![gpu(0, 20), gpu(1, 20)];
+        assert_eq!(
+            select_expand_device_with_preference(&gpus, 3 * GB, false, Some(1)),
+            ExpandPlacement::Gpu(1),
+        );
+    }
+
+    #[test]
+    fn expand_preference_falls_back_when_requested_gpu_cannot_fit() {
+        let gpus = vec![gpu(0, 20), gpu(1, 1)];
+        assert_eq!(
+            select_expand_device_with_preference(&gpus, 3 * GB, false, Some(1)),
+            ExpandPlacement::Gpu(0),
+        );
+    }
 }
diff --git a/crates/mold-inference/src/encoders/sd3_clip.rs b/crates/mold-inference/src/encoders/sd3_clip.rs
index c396563f..ddb3a89d 100644
--- a/crates/mold-inference/src/encoders/sd3_clip.rs
+++ b/crates/mold-inference/src/encoders/sd3_clip.rs
@@ -83,18 +83,16 @@ impl ClipWithTokenizer {
                 .ok_or_else(|| anyhow::anyhow!("Failed to tokenize CLIP end-of-text"))?,
         };
 
-        let mut tokens = self
+        let raw_tokens = self
             .tokenizer
             .encode(prompt, true)
             .map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))? 
.get_ids() .to_vec(); - let eos_position = tokens.len() - 1; + let (tokens, eos_position) = + prepare_clip_tokens(raw_tokens, self.max_position_embeddings, pad_id); - while tokens.len() < self.max_position_embeddings { - tokens.push(pad_id); - } let tokens = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; let (_text_embeddings, text_embeddings_penultimate) = clip.forward_until_encoder_layer(&tokens, usize::MAX, -2)?; @@ -293,3 +291,101 @@ impl SD3TripleEncoder { self.clip_l.model.is_some() && self.clip_g.model.is_some() && self.t5.model.is_some() } } + +/// Prepare a CLIP token sequence for the fixed position-embedding window. +/// +/// CLIP's position-embedding table holds exactly `max_len` entries, so a token +/// tensor longer than that fails inside candle's `broadcast_add` when the +/// position embeddings are applied. This helper: +/// +/// - Truncates overlong sequences to `max_len`, copying the trailing token +/// (the tokenizer's EOS, assuming `add_special_tokens=true`) into the last +/// slot so the pooled-output path still reads an EOS-position hidden state. +/// - Pads short sequences up to `max_len` with `pad_id`. +/// - Returns the final `tokens` vector and the `eos_position` index the caller +/// uses to slice the pooled output. 
+fn prepare_clip_tokens(mut raw_tokens: Vec<u32>, max_len: usize, pad_id: u32) -> (Vec<u32>, usize) {
+    let original_len = raw_tokens.len();
+
+    if original_len > max_len {
+        let eos_id = *raw_tokens
+            .last()
+            .expect("original_len > max_len implies non-empty");
+        raw_tokens.truncate(max_len);
+        if let Some(last) = raw_tokens.last_mut() {
+            *last = eos_id;
+        }
+        tracing::debug!(
+            "SD3 CLIP prompt exceeded {} tokens ({} raw); truncated with EOS preserved",
+            max_len,
+            original_len,
+        );
+    }
+
+    let eos_position = raw_tokens.len().saturating_sub(1);
+
+    while raw_tokens.len() < max_len {
+        raw_tokens.push(pad_id);
+    }
+
+    (raw_tokens, eos_position)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::prepare_clip_tokens;
+
+    const MAX_LEN: usize = 77;
+    const PAD_ID: u32 = 0;
+    const EOS_ID: u32 = 49407;
+
+    #[test]
+    fn pads_short_prompt_to_max_len() {
+        let raw = vec![49406, 10, 20, 30, EOS_ID]; // 5 tokens, last is EOS
+        let (tokens, eos) = prepare_clip_tokens(raw, MAX_LEN, PAD_ID);
+        assert_eq!(tokens.len(), MAX_LEN, "must pad up to max_len");
+        assert_eq!(eos, 4, "eos_position tracks the raw EOS slot");
+        assert_eq!(tokens[4], EOS_ID, "EOS preserved at original position");
+        assert_eq!(tokens[5], PAD_ID, "pads follow the real tokens");
+        assert_eq!(*tokens.last().unwrap(), PAD_ID);
+    }
+
+    #[test]
+    fn leaves_exact_length_untouched() {
+        let mut raw: Vec<u32> = (1..MAX_LEN as u32).collect();
+        raw.push(EOS_ID);
+        assert_eq!(raw.len(), MAX_LEN);
+        let (tokens, eos) = prepare_clip_tokens(raw.clone(), MAX_LEN, PAD_ID);
+        assert_eq!(tokens.len(), MAX_LEN);
+        assert_eq!(eos, MAX_LEN - 1);
+        assert_eq!(tokens, raw);
+    }
+
+    #[test]
+    fn truncates_overlong_prompt_preserving_eos() {
+        // 132-token sequence — matches the shapes in the original bug report
+        // ([1, 132, 768] vs [1, 77, 768]). 
+        let mut raw: Vec<u32> = (1..=131).collect();
+        raw.push(EOS_ID);
+        assert_eq!(raw.len(), 132);
+
+        let (tokens, eos) = prepare_clip_tokens(raw, MAX_LEN, PAD_ID);
+
+        assert_eq!(tokens.len(), MAX_LEN, "overlong sequence must be truncated");
+        assert_eq!(eos, MAX_LEN - 1, "eos_position must land on the last slot");
+        assert_eq!(
+            tokens[MAX_LEN - 1],
+            EOS_ID,
+            "EOS must be preserved in the final slot so pooled output reads EOS hidden state",
+        );
+    }
+
+    #[test]
+    fn handles_empty_input() {
+        // Degenerate case: tokenizer somehow returns no ids. Shouldn't panic.
+        let (tokens, eos) = prepare_clip_tokens(Vec::new(), MAX_LEN, PAD_ID);
+        assert_eq!(tokens.len(), MAX_LEN);
+        assert_eq!(eos, 0);
+        assert!(tokens.iter().all(|t| *t == PAD_ID));
+    }
+}
diff --git a/crates/mold-inference/src/engine.rs b/crates/mold-inference/src/engine.rs
index 949f5089..cddafc7a 100644
--- a/crates/mold-inference/src/engine.rs
+++ b/crates/mold-inference/src/engine.rs
@@ -35,6 +35,17 @@ pub trait InferenceEngine: Send + Sync {
     fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
         None
     }
+
+    /// Returns a [`ChainStageRenderer`] view of this engine if the family
+    /// supports chained video generation. Default is `None` — only LTX-2
+    /// distilled overrides this in v1.
+    ///
+    /// Callers (the server chain route) invoke this once per stage to drive
+    /// [`crate::ltx2::Ltx2ChainOrchestrator::run`]; engines that don't support
+    /// chaining return `None` and the caller responds with 422.
+    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
+        None
+    }
 }
 
 /// Restores an `Option` slot even if the current scope unwinds.
diff --git a/crates/mold-inference/src/expand.rs b/crates/mold-inference/src/expand.rs
index 1129397d..ede064f0 100644
--- a/crates/mold-inference/src/expand.rs
+++ b/crates/mold-inference/src/expand.rs
@@ -14,16 +14,19 @@ use mold_core::expand::{ExpandConfig, ExpandResult, PromptExpander};
 use mold_core::expand_prompts::{build_batch_messages, build_single_messages, format_chatml};
 
 use crate::device::{
-    discover_gpus, expand_vram_threshold, memory_status_string, preflight_memory_check,
-    select_expand_device, ExpandPlacement,
+    discover_gpus, expand_vram_threshold, filter_gpus, memory_status_string,
+    preflight_memory_check, select_expand_device_with_preference, ExpandPlacement,
 };
 use crate::progress::{ProgressCallback, ProgressReporter};
+use mold_core::types::GpuSelection;
 
 /// Local prompt expander using quantized Qwen3 GGUF.
 pub struct LocalExpander {
     model_path: PathBuf,
     tokenizer_path: PathBuf,
     progress: ProgressReporter,
+    gpu_selection: GpuSelection,
+    preferred_gpu: Option<usize>,
 }
 
 impl LocalExpander {
@@ -33,6 +36,8 @@ impl LocalExpander {
             model_path: model_path.into(),
             tokenizer_path: tokenizer_path.into(),
             progress: ProgressReporter::default(),
+            gpu_selection: GpuSelection::All,
+            preferred_gpu: None,
         }
     }
 
@@ -41,6 +46,18 @@ impl LocalExpander {
         self.progress.set_callback(callback);
     }
 
+    /// Restrict local expansion to the GPU ordinals selected by the caller.
+    pub fn with_gpu_selection(mut self, gpu_selection: GpuSelection) -> Self {
+        self.gpu_selection = gpu_selection;
+        self
+    }
+
+    /// Prefer this GPU ordinal when it is allowed and has enough free VRAM.
+    pub fn with_preferred_gpu(mut self, preferred_gpu: Option<usize>) -> Self {
+        self.preferred_gpu = preferred_gpu;
+        self
+    }
+
     /// Try to create a local expander by finding the model files.
     ///
     /// Searches the standard mold models directory for the expand model's
@@ -124,9 +141,11 @@ impl LocalExpander {
         // Cascade: main GPU → remaining GPUs (ordinal order) → CPU. 
// `discover_gpus()` returns an empty list on CPU-only builds, which // lands us directly on CPU. - let gpus = discover_gpus(); + let discovered = discover_gpus(); + let gpus = filter_gpus(&discovered, &self.gpu_selection); let is_metal = candle_core::utils::metal_is_available(); - let placement = select_expand_device(&gpus, threshold, is_metal); + let placement = + select_expand_device_with_preference(&gpus, threshold, is_metal, self.preferred_gpu); let device = match placement { ExpandPlacement::Gpu(ordinal) => { diff --git a/crates/mold-inference/src/factory.rs b/crates/mold-inference/src/factory.rs index 4a82bfed..7b78df40 100644 --- a/crates/mold-inference/src/factory.rs +++ b/crates/mold-inference/src/factory.rs @@ -181,7 +181,12 @@ pub fn create_engine_with_pool( shared_pool, ))) } - "ltx2" | "ltx-2" => Ok(Box::new(Ltx2Engine::new(model_name, paths, load_strategy))), + "ltx2" | "ltx-2" => Ok(Box::new(Ltx2Engine::new( + model_name, + paths, + load_strategy, + gpu_ordinal, + ))), "wuerstchen" | "wuerstchen-v2" => Ok(Box::new(WuerstchenEngine::new( model_name, paths, diff --git a/crates/mold-inference/src/flux/pipeline.rs b/crates/mold-inference/src/flux/pipeline.rs index 46255056..d225bf95 100644 --- a/crates/mold-inference/src/flux/pipeline.rs +++ b/crates/mold-inference/src/flux/pipeline.rs @@ -317,7 +317,17 @@ fn find_flux_reference_gguf( // Dev candidates satisfy both schnell and dev targets (schnell tensors are a // subset of dev). Schnell candidates only satisfy schnell targets. - let mut candidates: Vec<&str> = vec!["flux-dev:q8", "flux-dev:q6", "flux-dev:q4"]; + // flux-krea is a dev-family fine-tune shipped as complete GGUFs by + // QuantStack, so it carries the full embedding set including guidance_in — + // fall back to it before asking the user to download flux-dev. 
+ let mut candidates: Vec<&str> = vec![ + "flux-dev:q8", + "flux-dev:q6", + "flux-dev:q4", + "flux-krea:q8", + "flux-krea:q6", + "flux-krea:q4", + ]; if !needs_guidance { candidates.extend(["flux-schnell:q8", "flux-schnell:q4"]); } @@ -2589,6 +2599,35 @@ mod tests { std::fs::remove_dir_all(&dir).ok(); } + #[test] + fn find_flux_reference_accepts_krea_when_no_base_dev() { + // flux-krea is a dev-family fine-tune shipped as complete GGUFs — it + // should serve as a reference for city96-format fine-tunes (UltraReal, + // etc.) even when the base flux-dev GGUF isn't downloaded. + let dir = std::env::temp_dir().join(format!( + "mold-ref-krea-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + let models_dir = dir.join("models"); + let krea_dir = models_dir.join("flux-krea-q8"); + std::fs::create_dir_all(&krea_dir).unwrap(); + let krea_path = krea_dir.join("flux1-krea-dev-Q8_0.gguf"); + + let mut complete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec(); + complete.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS); + write_test_gguf(&krea_path, &complete); + + let picked = super::find_flux_reference_gguf(true, Some(&models_dir)) + .expect("complete flux-krea reference must be accepted for dev targets"); + assert_eq!(picked, krea_path); + + std::fs::remove_dir_all(&dir).ok(); + } + #[test] fn embedding_tensor_names_are_exhaustive() { // Verify the const arrays cover all non-diffusion-block tensors that diff --git a/crates/mold-inference/src/ltx2/chain.rs b/crates/mold-inference/src/ltx2/chain.rs new file mode 100644 index 00000000..9657f06b --- /dev/null +++ b/crates/mold-inference/src/ltx2/chain.rs @@ -0,0 +1,830 @@ +//! LTX-2 chain carryover primitives. +//! +//! Server-side chained video generation stitches multiple per-clip renders +//! into a single output. To avoid a VAE decode → RGB → VAE encode round-trip +//! 
between clips (which loses information and doubles VAE cost), the tail of +//! each clip is carried across as latent-space tokens and threaded into the +//! next clip's conditioning directly. +//! +//! This module owns the data types and shape math for that handoff. The +//! orchestrator and the `Ltx2Engine::generate_with_carryover` entry point +//! land in sibling commits. +//! +//! See `tasks/render-chain-v1-plan.md` Phase 1.1 for context. + +use anyhow::{anyhow, bail, Context, Result}; +use candle_core::Tensor; +use image::RgbImage; +use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage}; +use mold_core::{GenerateRequest, OutputFormat}; + +use crate::ltx2::model::shapes::SpatioTemporalScaleFactors; + +/// Opaque carryover payload handed from one chain stage to the next. +/// +/// Holds the last `frames` decoded RGB frames of the emitting stage, not the +/// raw tail latents. The receiving stage re-encodes them fresh through the +/// LTX-2 video VAE so every resulting latent slot has correct causal / +/// continuation semantics in the receiving clip's frame of reference — a +/// direct latent slice from the emitting stage's continuation slots would +/// appear at the receiving stage's position 0/1 with slot-meaning mismatched +/// against the VAE's causal-first-frame convention. +/// +/// The VAE encode cost on the receiving side is negligible (≈tens of ms for +/// 17 frames at 704×1216), and it's paid inside a VAE load that's already +/// needed for the source-image anchor path (see pipeline.rs). +#[derive(Debug, Clone)] +pub struct ChainTail { + /// Number of *pixel* frames this tail represents (not latent frames). + /// Clients of [`ChainTail`] work in pixel-frame units because that's + /// what users think in; the latent-frame count is derived from this + /// plus the LTX-2 VAE's 8× causal temporal ratio. + pub frames: u32, + + /// The last `frames` decoded RGB frames of the emitting stage, in + /// capture order. 
The receiving stage VAE-encodes this contiguous pixel
+    /// window into `tail_latent_frame_count(frames)` latent slots. Each
+    /// resulting latent slot then carries correct causal (slot 0, 1 pixel)
+    /// or continuation (slots 1+, 8 pixels each) semantics for the receiving
+    /// clip's pinned region — monotonic, forward-in-time, no slot meaning
+    /// mismatch with the RoPE positions in the receiving clip.
+    pub tail_rgb_frames: Vec<RgbImage>,
+}
+
+/// Number of latent frames corresponding to `pixel_frames` pixel frames
+/// under the LTX-2 VAE's 8× causal temporal compression. `1` for
+/// `1..=8` pixel frames, `2` for `9..=16`, etc. Matches
+/// `VideoLatentShape::from_pixel_shape`.
+///
+/// Panics if `pixel_frames == 0` — a zero-frame tail is nonsensical and
+/// would under-flow the formula. Callers must validate upstream.
+pub fn tail_latent_frame_count(pixel_frames: u32) -> usize {
+    assert!(
+        pixel_frames > 0,
+        "tail_latent_frame_count: pixel_frames must be > 0",
+    );
+    let scale = SpatioTemporalScaleFactors::default().time;
+    ((pixel_frames as usize - 1) / scale) + 1
+}
+
+/// Slice the last `tail_latent_frame_count(pixel_frames)` frames off the
+/// time axis of a rank-5 video-latents tensor shaped
+/// `[B, C, T, H, W]`.
+///
+/// The returned tensor is a view/narrow on the input (no copy on candle's
+/// current backends) so callers who intend to hand it to a separate engine
+/// invocation — which may drop this engine's state and rebuild it — should
+/// `.contiguous()` or `.copy()` the result before the original owner goes
+/// out of scope.
+///
+/// Errors if the tensor is not rank-5 or the requested tail exceeds the
+/// available time axis — the latter would mean the orchestrator asked for
+/// more tail than the stage produced, which indicates a caller bug.
+
+/// Kept after the v1.1 decoded-pixel-carryover switch because the utility
+/// still reads cleanly from tests and is useful for ad-hoc debugging /
+/// future experiments, but the production chain path no longer calls it.
+#[allow(dead_code)]
+pub fn extract_tail_latents(final_latents: &Tensor, pixel_frames: u32) -> Result<Tensor> {
+    let dims = final_latents.dims();
+    if dims.len() != 5 {
+        return Err(anyhow!(
+            "extract_tail_latents: expected rank-5 tensor [B, C, T, H, W], got shape {:?}",
+            dims,
+        ));
+    }
+    let time = dims[2];
+    let tail = tail_latent_frame_count(pixel_frames);
+    if tail > time {
+        return Err(anyhow!(
+            "extract_tail_latents: tail requests {} latent frames but the stage emitted only {} \
+             (pixel_frames={}, tensor shape={:?})",
+            tail,
+            time,
+            pixel_frames,
+            dims,
+        ));
+    }
+    let start = time - tail;
+    final_latents
+        .narrow(2, start, tail)
+        .with_context(|| format!("narrow last {tail} latent frames off time axis"))
+}
+
+// ── Orchestrator: loops stages, drops motion-tail prefix, accumulates frames
+
+/// Per-stage progress events the orchestrator observes from the renderer.
+/// The renderer emits these synchronously while a stage is denoising; the
+/// orchestrator wraps them with `stage_idx` before forwarding as
+/// [`ChainProgressEvent`]s to the chain-level subscriber.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StageProgressEvent {
+    /// Denoise step `step` of `total` completed for the active stage.
+    DenoiseStep { step: u32, total: u32 },
+}
+
+/// Output of a single stage render: the decoded pixel frames (full clip,
+/// before motion-tail trim), the pre-VAE-decode latent tail the next stage
+/// needs, and the wall-clock elapsed time for the render.
+#[derive(Debug)]
+pub struct StageOutcome {
+    pub frames: Vec<RgbImage>,
+    pub tail: ChainTail,
+    pub generation_time_ms: u64,
+}
+
+/// Abstraction over "render one chain stage". 
Production uses the LTX-2
+/// engine impl (lands in Phase 1d); tests inject a fake implementation
+/// that fabricates deterministic frames and a synthetic [`ChainTail`]
+/// without loading candle weights.
+pub trait ChainStageRenderer {
+    fn render_stage(
+        &mut self,
+        stage_req: &GenerateRequest,
+        carry: Option<&ChainTail>,
+        motion_tail_pixel_frames: u32,
+        stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
+    ) -> Result<StageOutcome>;
+}
+
+/// Output of an end-to-end chain run: accumulated RGB frames with motion-
+/// tail prefix already trimmed on continuations, the number of stages
+/// that ran, and the total elapsed render time.
+///
+/// The orchestrator does *not* trim to a target total frame count or
+/// encode the frames into an output video — those are the caller's job
+/// (server / CLI). Keeps the orchestrator single-purpose: produce a
+/// coherent frame stream from a stages list.
+#[derive(Debug)]
+pub struct ChainRunOutput {
+    pub frames: Vec<RgbImage>,
+    pub stage_count: u32,
+    pub generation_time_ms: u64,
+}
+
+/// Drives the per-stage render loop for a chained generation. Borrows its
+/// renderer mutably so the loop can re-enter the engine on the same GPU
+/// context across stages.
+pub struct Ltx2ChainOrchestrator<'a, R: ChainStageRenderer + ?Sized> {
+    renderer: &'a mut R,
+}
+
+impl<'a, R: ChainStageRenderer + ?Sized> Ltx2ChainOrchestrator<'a, R> {
+    pub fn new(renderer: &'a mut R) -> Self {
+        Self { renderer }
+    }
+
+    /// Run every stage in `req.stages` and return the accumulated frames.
+    ///
+    /// Behaviour invariants (from the 2026-04-20 sign-off, amended 2026-04-21):
+    /// - Per-stage seeds default to the shared `base_seed` so the continuation
+    ///   denoise starts from matching noise. Stages can opt in to variation by
+    ///   setting `seed_offset`, which XORs into the base seed. 
+    /// - Stage 0's output is kept whole; continuations drop their leading
+    ///   `req.motion_tail_frames` pixel frames because those duplicate the
+    ///   prior stage's tail that was threaded back as latent conditioning.
+    /// - Mid-chain failure returns the error immediately; partial frames are
+    ///   discarded (no partial stitch is ever produced in v1).
+    pub fn run(
+        &mut self,
+        req: &ChainRequest,
+        mut chain_progress: Option<&mut dyn FnMut(ChainProgressEvent)>,
+    ) -> Result<ChainRunOutput> {
+        if req.stages.is_empty() {
+            bail!("Ltx2ChainOrchestrator::run: chain request has no stages");
+        }
+        validate_motion_tail(req)?;
+
+        let stage_count = req.stages.len() as u32;
+        let estimated_total_frames = estimate_stitched_frames(req);
+        if let Some(cb) = chain_progress.as_deref_mut() {
+            cb(ChainProgressEvent::ChainStart {
+                stage_count,
+                estimated_total_frames,
+            });
+        }
+
+        let base_seed = req.seed.unwrap_or(0);
+        let motion_tail_drop = req.motion_tail_frames as usize;
+        let mut accumulated_frames: Vec<RgbImage> = Vec::new();
+        let mut total_generation_ms: u64 = 0;
+        let mut carry: Option<ChainTail> = None;
+
+        for (idx, stage) in req.stages.iter().enumerate() {
+            let stage_idx = idx as u32;
+            if let Some(cb) = chain_progress.as_deref_mut() {
+                cb(ChainProgressEvent::StageStart { stage_idx });
+            }
+
+            let stage_seed = derive_stage_seed(base_seed, idx, stage);
+            let stage_req = build_stage_generate_request(stage, req, stage_seed, idx);
+
+            // Wrap the chain progress subscriber so per-stage denoise
+            // events land on it with `stage_idx` tagged in. The wrapping
+            // closure holds a mutable reborrow of the outer callback for
+            // just the duration of this call — `render_stage` is
+            // synchronous so the reborrow ends before the next iteration. 
+ let outcome = match chain_progress.as_deref_mut() { + Some(chain_cb) => { + let mut wrapping = |event: StageProgressEvent| match event { + StageProgressEvent::DenoiseStep { step, total } => { + chain_cb(ChainProgressEvent::DenoiseStep { + stage_idx, + step, + total, + }); + } + }; + self.renderer.render_stage( + &stage_req, + carry.as_ref(), + req.motion_tail_frames, + Some(&mut wrapping), + )? + } + None => self.renderer.render_stage( + &stage_req, + carry.as_ref(), + req.motion_tail_frames, + None, + )?, + }; + + let mut frames = outcome.frames; + if idx > 0 && motion_tail_drop > 0 { + if motion_tail_drop >= frames.len() { + bail!( + "stage {stage_idx}: emitted {} frames but motion_tail_drop={motion_tail_drop} — tail would consume the whole clip", + frames.len(), + ); + } + frames.drain(..motion_tail_drop); + } + let frames_emitted = frames.len() as u32; + accumulated_frames.extend(frames); + total_generation_ms = total_generation_ms.saturating_add(outcome.generation_time_ms); + carry = Some(outcome.tail); + + if let Some(cb) = chain_progress.as_deref_mut() { + cb(ChainProgressEvent::StageDone { + stage_idx, + frames_emitted, + }); + } + } + + if let Some(cb) = chain_progress.as_mut() { + cb(ChainProgressEvent::Stitching { + total_frames: accumulated_frames.len() as u32, + }); + } + + Ok(ChainRunOutput { + frames: accumulated_frames, + stage_count, + generation_time_ms: total_generation_ms, + }) + } +} + +fn validate_motion_tail(req: &ChainRequest) -> Result<()> { + for (idx, stage) in req.stages.iter().enumerate() { + if req.motion_tail_frames >= stage.frames { + bail!( + "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({}) \ + so every continuation emits at least one new frame", + req.motion_tail_frames, + stage.frames, + ); + } + } + Ok(()) +} + +fn estimate_stitched_frames(req: &ChainRequest) -> u32 { + // delivered = stages[0].frames + Σ (stages[i].frames - motion_tail) for i >= 1 + let tail = req.motion_tail_frames; + req.stages + 
.iter() + .enumerate() + .map(|(idx, stage)| { + if idx == 0 { + stage.frames + } else { + stage.frames.saturating_sub(tail) + } + }) + .sum() +} + +fn derive_stage_seed(base_seed: u64, _idx: usize, stage: &ChainStage) -> u64 { + // Keep the seed stable across stages by default. An earlier revision + // XORed `(idx as u64) << 32` into each stage's seed so the initial + // noise tensor differed per clip; with the motion tail now re-encoded + // from the emitting stage's trailing RGB frames (see `ChainTail` + + // `StagedLatent`) the pinned region is frozen by `video_denoise_mask` + // anyway, so same-seed noise in the pinned tokens is a no-op, and + // same-seed noise in the free region lets the continuation settle on a + // consistent motion profile. Callers who want per-stage variation + // supply `stage.seed_offset` explicitly. + if let Some(offset) = stage.seed_offset { + base_seed ^ offset + } else { + base_seed + } +} + +fn build_stage_generate_request( + stage: &ChainStage, + chain: &ChainRequest, + stage_seed: u64, + idx: usize, +) -> GenerateRequest { + GenerateRequest { + prompt: stage.prompt.clone(), + negative_prompt: stage.negative_prompt.clone(), + model: chain.model.clone(), + width: chain.width, + height: chain.height, + steps: chain.steps, + guidance: chain.guidance, + seed: Some(stage_seed), + batch_size: 1, + // Continuation stages never use the per-chain output_format + // downstream — the orchestrator decodes to frames regardless — + // but MP4 is the canonical intermediate for LTX-2. + output_format: OutputFormat::Mp4, + embed_metadata: None, + scheduler: None, + // Every stage carries the starting image. 
Stage 0 uses it as the + // i2v replacement at frame 0; continuation stages have their + // frame-0 slot pinned by the motion-tail carryover latent, so + // `render_chain_stage` re-routes the staged image into the append + // path at a non-zero frame with soft strength — turning it into a + // durable identity anchor rather than a frame-0 replacement. + source_image: stage.source_image.clone(), + edit_images: None, + // Replacement strength from the chain request is only meaningful + // for stage 0's frame-0 i2v pin. Continuations override this at + // `render_chain_stage` time (the anchor uses a lower soft- + // strength constant there), so the value we plant here is inert + // on continuations. + strength: if idx == 0 { chain.strength } else { 1.0 }, + mask_image: None, + control_image: None, + control_model: None, + control_scale: 1.0, + expand: None, + original_prompt: None, + lora: None, + frames: Some(stage.frames), + fps: Some(chain.fps), + upscale_model: None, + gif_preview: false, + enable_audio: Some(false), // v1 chain: no audio plumbing yet + audio_file: None, + source_video: None, + keyframes: None, + pipeline: None, + loras: None, + retake_range: None, + spatial_upscale: None, + temporal_upscale: None, + placement: chain.placement.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::{DType, Device}; + + #[test] + fn tail_latent_frame_count_matches_vae_formula() { + // Single-frame tail and up to 8 pixel frames fit in 1 latent frame + // (LTX-2 VAE uses causal first frame + 8× temporal compression). + for px in [1u32, 2, 4, 8] { + assert_eq!(tail_latent_frame_count(px), 1, "{px} pixel frames"); + } + // 9..=16 span 2 latent frames, 17..=24 span 3, etc. 
+        assert_eq!(tail_latent_frame_count(9), 2);
+        assert_eq!(tail_latent_frame_count(16), 2);
+        assert_eq!(tail_latent_frame_count(17), 3);
+        assert_eq!(tail_latent_frame_count(24), 3);
+        // Full-clip tail (97 frames) → 13 latent frames, matching
+        // VideoLatentShape::from_pixel_shape under the same VAE ratio.
+        assert_eq!(tail_latent_frame_count(97), 13);
+    }
+
+    #[test]
+    #[should_panic(expected = "pixel_frames must be > 0")]
+    fn tail_latent_frame_count_rejects_zero() {
+        tail_latent_frame_count(0);
+    }
+
+    #[test]
+    fn extract_tail_narrows_last_latent_frame_for_4_pixel_frame_tail() {
+        // Build a synthetic [1, 2, 3, 1, 1] where channel 0 is the latent-
+        // frame index and channel 1 is a sentinel (42, 43, 44) so we can
+        // see which frames the narrow returns.
+        let data = vec![
+            // frame 0
+            0.0f32, 42.0, // frame 1
+            1.0, 43.0, // frame 2
+            2.0, 44.0,
+        ];
+        // Arrange [B=1, C=2, T=3, H=1, W=1]. `Tensor::from_vec` fills in
+        // row-major order — the permute below puts channels on axis 1.
+        let raw = Tensor::from_vec(data, (1, 3, 2, 1, 1), &Device::Cpu).expect("build raw tensor");
+        // Reshape [1, T, C, H, W] → [1, C, T, H, W]
+        let latents = raw
+            .permute([0, 2, 1, 3, 4])
+            .expect("permute to [B, C, T, H, W]");
+        assert_eq!(latents.dims(), &[1, 2, 3, 1, 1]);
+
+        // tail_latent_frame_count(4) = 1 → take the last latent frame only.
+        let tail = extract_tail_latents(&latents, 4).expect("extract");
+        assert_eq!(tail.dims(), &[1, 2, 1, 1, 1]);
+        let values = tail.flatten_all().unwrap().to_vec1::<f32>().unwrap();
+        assert_eq!(
+            values,
+            vec![2.0, 44.0],
+            "tail must be the last latent frame (index 2) across all channels",
+        );
+    }
+
+    #[test]
+    fn extract_tail_narrows_two_frames_for_9_pixel_frame_tail() {
+        // Simple rank-5 zero tensor with T=3; narrowing the last 2 frames
+        // out of 3 is enough to verify the shape without wrestling with
+        // permutations again. 
+        let latents = Tensor::zeros((1, 1, 3, 2, 2), DType::F32, &Device::Cpu).unwrap();
+        let tail = extract_tail_latents(&latents, 9).expect("extract");
+        assert_eq!(tail.dims(), &[1, 1, 2, 2, 2]);
+    }
+
+    #[test]
+    fn extract_tail_rejects_rank_4_tensor() {
+        let bad = Tensor::zeros((1, 128, 3, 4), DType::F32, &Device::Cpu).unwrap();
+        let err = extract_tail_latents(&bad, 4).expect_err("rank 4 must fail");
+        let msg = format!("{err}");
+        assert!(
+            msg.contains("rank-5") && msg.contains("T, H, W"),
+            "error must identify the rank mismatch, got: {msg}",
+        );
+    }
+
+    #[test]
+    fn extract_tail_rejects_oversize_request() {
+        // Tensor has 1 latent frame; asking for a 9-pixel-frame tail needs 2.
+        let latents = Tensor::zeros((1, 128, 1, 4, 4), DType::F32, &Device::Cpu).unwrap();
+        let err = extract_tail_latents(&latents, 9).expect_err("oversize tail must fail");
+        let msg = format!("{err}");
+        assert!(
+            msg.contains("requests 2") && msg.contains("only 1"),
+            "error must name the latent-frame mismatch, got: {msg}",
+        );
+    }
+
+    // ── Orchestrator tests (fake renderer, weight-free) ───────────────
+
+    use image::Rgb;
+    use mold_core::chain::ChainStage;
+
+    /// Deterministic fake renderer for orchestrator tests. Records every
+    /// call so assertions can inspect the per-stage request shape, emits
+    /// a solid-color frame block plus a zero-valued latent tail, and
+    /// optionally returns errors on pre-configured stage indices.
+    struct FakeRenderer {
+        calls: Vec<CallRecord>,
+        /// If set, fail on the listed stage indices with the given message.
+        fail_on: Vec<(usize, String)>,
+        /// Per-call override of frame count (default: use stage_req.frames).
+        frame_count_override: Option<u32>,
+        /// If true, emit one DenoiseStep event per stage so tests can
+        /// verify progress forwarding. 
+        emit_progress: bool,
+    }
+
+    #[derive(Debug, Clone)]
+    struct CallRecord {
+        seed: Option<u64>,
+        has_source_image: bool,
+        has_carry: bool,
+    }
+
+    impl FakeRenderer {
+        fn new() -> Self {
+            Self {
+                calls: Vec::new(),
+                fail_on: Vec::new(),
+                frame_count_override: None,
+                emit_progress: false,
+            }
+        }
+    }
+
+    impl ChainStageRenderer for FakeRenderer {
+        fn render_stage(
+            &mut self,
+            stage_req: &GenerateRequest,
+            carry: Option<&ChainTail>,
+            _motion_tail_pixel_frames: u32,
+            mut stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
+        ) -> Result<StageOutcome> {
+            let idx = self.calls.len();
+            self.calls.push(CallRecord {
+                seed: stage_req.seed,
+                has_source_image: stage_req.source_image.is_some(),
+                has_carry: carry.is_some(),
+            });
+            if let Some((_, msg)) = self.fail_on.iter().find(|(stage_idx, _)| *stage_idx == idx) {
+                bail!("{msg}");
+            }
+            if self.emit_progress {
+                if let Some(cb) = stage_progress.as_mut() {
+                    cb(StageProgressEvent::DenoiseStep { step: 1, total: 1 });
+                }
+            }
+
+            let frame_count = self
+                .frame_count_override
+                .unwrap_or_else(|| stage_req.frames.expect("fake renderer: stage_req.frames"));
+            let width = stage_req.width;
+            let height = stage_req.height;
+            // Colour the frames with the stage index so assertions can
+            // verify which stage a frame came from.
+            let mut frames = Vec::with_capacity(frame_count as usize);
+            for frame_num in 0..frame_count {
+                let channel = (idx as u8).wrapping_mul(37).wrapping_add(frame_num as u8);
+                frames.push(RgbImage::from_pixel(width, height, Rgb([channel, 0, 0])));
+            }
+
+            // Synthesize a 4-pixel-frame tail from the trailing RGB frames
+            // so orchestrator tests can assert on the count/shape without
+            // loading a real VAE. 
+            let tail_pixel_frames: u32 = 4;
+            let take_from = frames
+                .len()
+                .saturating_sub(tail_pixel_frames as usize)
+                .min(frames.len());
+            let tail_rgb_frames = frames[take_from..].to_vec();
+
+            Ok(StageOutcome {
+                frames,
+                tail: ChainTail {
+                    frames: tail_pixel_frames,
+                    tail_rgb_frames,
+                },
+                generation_time_ms: 100,
+            })
+        }
+    }
+
+    fn stage(prompt: &str, frames: u32) -> ChainStage {
+        ChainStage {
+            prompt: prompt.into(),
+            frames,
+            source_image: None,
+            negative_prompt: None,
+            seed_offset: None,
+        }
+    }
+
+    fn chain_req(stages: Vec<ChainStage>, motion_tail_frames: u32) -> ChainRequest {
+        ChainRequest {
+            model: "ltx-2-19b-distilled:fp8".into(),
+            stages,
+            motion_tail_frames,
+            width: 1216,
+            height: 704,
+            fps: 24,
+            seed: Some(42),
+            steps: 8,
+            guidance: 3.0,
+            strength: 1.0,
+            output_format: OutputFormat::Mp4,
+            placement: None,
+            prompt: None,
+            total_frames: None,
+            clip_frames: None,
+            source_image: None,
+        }
+    }
+
+    #[test]
+    fn chain_runs_all_stages_and_drops_tail_prefix_from_continuations() {
+        let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)];
+        let req = chain_req(stages, 4);
+        let mut renderer = FakeRenderer::new();
+        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
+        let out = orch.run(&req, None).expect("chain runs");
+        // Stage 0 keeps all 97 frames; each continuation drops the
+        // leading 4 frames, so delivered = 97 + 2 * (97 - 4) = 97 + 186 = 283.
+        assert_eq!(out.frames.len(), 97 + 93 * 2);
+        assert_eq!(out.stage_count, 3);
+        assert_eq!(renderer.calls.len(), 3);
+        // Stage 0 has no carry; later stages do. 
+ assert!(!renderer.calls[0].has_carry); + assert!(renderer.calls[1].has_carry); + assert!(renderer.calls[2].has_carry); + } + + #[test] + fn chain_with_zero_tail_concats_full_clips_without_drop() { + let stages = vec![stage("a", 97), stage("a", 97)]; + let req = chain_req(stages, 0); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let out = orch.run(&req, None).expect("chain runs"); + assert_eq!( + out.frames.len(), + 97 * 2, + "zero motion tail must keep every frame on continuations", + ); + } + + #[test] + fn chain_empty_stages_errors_without_calling_renderer() { + let req = chain_req(vec![], 4); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch.run(&req, None).expect_err("empty stages must fail"); + assert!( + format!("{err}").contains("has no stages"), + "error must name the missing stages, got: {err}", + ); + assert!(renderer.calls.is_empty()); + } + + #[test] + fn chain_fails_closed_mid_chain_discarding_accumulated_frames() { + // Signed-off decision 2026-04-20: mid-chain failure returns the + // error immediately and throws away any frames already produced. + // No partial stitch is ever written to the gallery. + let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)]; + let req = chain_req(stages, 4); + let mut renderer = FakeRenderer::new(); + renderer.fail_on = vec![(1, "simulated GPU OOM on stage 1".into())]; + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch + .run(&req, None) + .expect_err("mid-chain failure must bubble up"); + assert!( + format!("{err}").contains("simulated GPU OOM"), + "error must carry the renderer's message, got: {err}", + ); + // Stage 0 ran (recorded), stage 1 failed (recorded before bail), + // stage 2 never ran. 
+ assert_eq!(renderer.calls.len(), 2); + } + + #[test] + fn chain_holds_seed_stable_across_stages_by_default() { + let stages = vec![stage("a", 9), stage("a", 9), stage("a", 9)]; + let mut req = chain_req(stages, 0); + req.seed = Some(42); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + orch.run(&req, None).expect("chain runs"); + // The orchestrator used to XOR `(idx as u64) << 32` into each + // stage's seed so initial noise differed per clip. With the + // motion-tail pin grounded on a proper causal-first latent, + // per-stage noise diversity just amplifies drift at the stitch + // point — same-seed noise stays frozen in the pinned region and + // produces a more consistent motion profile in the free region. + assert_eq!(renderer.calls[0].seed, Some(42)); + assert_eq!(renderer.calls[1].seed, Some(42)); + assert_eq!(renderer.calls[2].seed, Some(42)); + } + + #[test] + fn chain_propagates_source_image_to_every_stage() { + // Every stage must receive the starting image in its GenerateRequest. + // Stage 0 uses it as the frame-0 i2v replacement; continuations use + // it at the engine level as a soft identity anchor (routed through + // the append path by `Ltx2Engine::render_chain_stage`). Identity + // drift past the first clip was traced to the prior behaviour of + // dropping the image on continuations — no long-range identity + // anchor meant each continuation was anchored only to the drifted + // last frame of the prior clip, compounding errors stage-over-stage. 
+        let mut stages = vec![stage("a", 9), stage("a", 9)];
+        stages[0].source_image = Some(vec![0x89, 0x50, 0x4e, 0x47]); // PNG magic
+        stages[1].source_image = Some(vec![0x89, 0x50, 0x4e, 0x47]);
+        let req = chain_req(stages, 0);
+        let mut renderer = FakeRenderer::new();
+        renderer.frame_count_override = Some(9);
+        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
+        orch.run(&req, None).expect("chain runs");
+        assert!(
+            renderer.calls[0].has_source_image,
+            "stage 0 must carry source_image (frame-0 i2v replacement)",
+        );
+        assert!(
+            renderer.calls[1].has_source_image,
+            "continuation stage must also carry source_image (soft identity anchor)",
+        );
+    }
+
+    #[test]
+    fn chain_forwards_engine_events_with_stage_idx_wrapping() {
+        let stages = vec![stage("a", 9), stage("a", 9)];
+        let req = chain_req(stages, 0);
+        let mut renderer = FakeRenderer::new();
+        renderer.frame_count_override = Some(9);
+        renderer.emit_progress = true;
+
+        let mut events: Vec<ChainProgressEvent> = Vec::new();
+        {
+            let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
+            let mut cb = |e: ChainProgressEvent| events.push(e);
+            orch.run(&req, Some(&mut cb)).expect("chain runs");
+        }
+
+        // Expected order:
+        //   ChainStart, StageStart(0), DenoiseStep(0), StageDone(0),
+        //   StageStart(1), DenoiseStep(1), StageDone(1), Stitching
+        assert!(matches!(
+            events[0],
+            ChainProgressEvent::ChainStart { stage_count: 2, .. 
} + )); + assert!(matches!( + events[1], + ChainProgressEvent::StageStart { stage_idx: 0 } + )); + assert!(matches!( + events[2], + ChainProgressEvent::DenoiseStep { + stage_idx: 0, + step: 1, + total: 1 + } + )); + assert!(matches!( + events[3], + ChainProgressEvent::StageDone { + stage_idx: 0, + frames_emitted: 9 + } + )); + assert!(matches!( + events[4], + ChainProgressEvent::StageStart { stage_idx: 1 } + )); + assert!(matches!( + events[5], + ChainProgressEvent::DenoiseStep { + stage_idx: 1, + step: 1, + total: 1 + } + )); + assert!(matches!( + events[6], + ChainProgressEvent::StageDone { + stage_idx: 1, + frames_emitted: 9 + } + )); + assert!(matches!( + events[7], + ChainProgressEvent::Stitching { total_frames: 18 } + )); + assert_eq!(events.len(), 8); + } + + #[test] + fn chain_rejects_motion_tail_ge_stage_frames_before_running() { + let stages = vec![stage("a", 9), stage("a", 9)]; + // tail=9 equals stage frames — no net-new content on continuation. + let req = chain_req(stages, 9); + let mut renderer = FakeRenderer::new(); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + let err = orch.run(&req, None).expect_err("must fail"); + assert!( + format!("{err}").contains("motion_tail_frames"), + "error must name motion_tail_frames, got: {err}", + ); + // Renderer never gets called because validation runs up-front. 
+ assert!(renderer.calls.is_empty()); + } + + #[test] + fn chain_respects_seed_offset_override_when_stage_provides_one() { + let mut stages = vec![stage("a", 9), stage("a", 9)]; + stages[1].seed_offset = Some(0xDEADBEEF); + let mut req = chain_req(stages, 0); + req.seed = Some(100); + let mut renderer = FakeRenderer::new(); + renderer.frame_count_override = Some(9); + let mut orch = Ltx2ChainOrchestrator::new(&mut renderer); + orch.run(&req, None).expect("runs"); + assert_eq!(renderer.calls[0].seed, Some(100)); + assert_eq!( + renderer.calls[1].seed, + Some(100 ^ 0xDEADBEEFu64), + "seed_offset must XOR into the stable base seed when a stage opts in to variation", + ); + } +} diff --git a/crates/mold-inference/src/ltx2/conditioning.rs b/crates/mold-inference/src/ltx2/conditioning.rs index b0b8d0e7..d40c3a47 100644 --- a/crates/mold-inference/src/ltx2/conditioning.rs +++ b/crates/mold-inference/src/ltx2/conditioning.rs @@ -1,4 +1,5 @@ use anyhow::{bail, Result}; +use image::RgbImage; use mold_core::{GenerateRequest, TimeRange}; use std::fs; use std::ops::RangeInclusive; @@ -11,9 +12,45 @@ pub(crate) struct StagedImage { pub(crate) strength: f32, } -#[derive(Debug, Clone, PartialEq)] +/// Pre-decoded RGB frame window that the runtime re-encodes through the +/// video VAE into conditioning tokens. Populated by the render-chain +/// orchestrator with the trailing frames of the emitting stage; empty for +/// every non-chain caller today. +/// +/// Re-encoding on the receiving side (rather than narrowing the emitting +/// stage's final latent tensor) is what keeps slot semantics correct: the +/// first latent produced from `tail_rgb_frames` is a proper causal 1-pixel +/// encoding and subsequent latents are proper 8-pixel continuation +/// encodings at the receiving clip's own time axis, with no ambiguity +/// about which latent slot corresponds to which pixel-frame range. 
+#[derive(Debug, Clone)]
+pub(crate) struct StagedLatent {
+    /// Contiguous, in-capture-order RGB frames from the end of the emitting
+    /// stage. Must be non-empty; the receiving runtime VAE-encodes them
+    /// into `tail_latent_frame_count(tail_rgb_frames.len())` latent slots.
+    pub(crate) tail_rgb_frames: Vec<RgbImage>,
+    /// Starting pixel frame for this latent block. `0` routes the tokens
+    /// through `StageVideoConditioning::replacements`; non-zero values
+    /// build a `VideoTokenAppendCondition` like the keyframe image path.
+    pub(crate) frame: u32,
+    /// Replacement/append strength. `1.0` for chain motion-tail carryover
+    /// (hard-overwrite), matching the keyframe image strength convention.
+    pub(crate) strength: f32,
+}
+
+/// Conditioning inputs staged for a single run. Carries both disk-backed
+/// files (images, audio, reference video — existing single-clip flow) and
+/// in-memory RGB frame blocks (chain carryover — new, empty for non-chain
+/// callers).
+///
+/// Not `PartialEq` because `StagedLatent` wraps `image::RgbImage` which
+/// doesn't implement meaningful structural equality beyond raw byte
+/// comparison. Existing tests only compare individual fields so this is
+/// safe to drop.
+#[derive(Debug, Clone)]
 pub(crate) struct StagedConditioning {
     pub(crate) images: Vec<StagedImage>,
+    pub(crate) latents: Vec<StagedLatent>,
     pub(crate) audio_path: Option<PathBuf>,
     pub(crate) video_path: Option<PathBuf>,
 }
@@ -99,6 +136,7 @@ pub(crate) fn stage_conditioning(
 
     Ok(StagedConditioning {
         images,
+        latents: Vec::new(),
         audio_path,
         video_path,
     })
@@ -224,6 +262,29 @@ mod tests {
         assert!(mask[18..].iter().all(|value| *value == 0.0));
     }
 
+    #[test]
+    fn stage_conditioning_leaves_latents_empty_for_non_chain_callers() {
+        // Single-clip callers build `StagedConditioning` via this function;
+        // the `latents` field (used by the render-chain orchestrator to inject
+        // pre-encoded motion-tail tokens) must stay empty so existing runs
+        // keep routing conditioning through the image path with VAE encode.
+ let work_dir = tempfile::tempdir().unwrap(); + let mut req = req(); + req.source_image = Some(fake_png_bytes()); + req.keyframes = Some(vec![KeyframeCondition { + frame: 8, + image: fake_png_bytes(), + }]); + req.source_video = Some(fake_mp4_bytes()); + req.audio_file = Some(fake_wav_bytes()); + + let staged = stage_conditioning(&req, work_dir.path()).unwrap(); + assert!( + staged.latents.is_empty(), + "non-chain callers must leave latents empty", + ); + } + #[test] fn stage_conditioning_stages_source_image_as_frame_zero_replacement() { let work_dir = tempfile::tempdir().unwrap(); diff --git a/crates/mold-inference/src/ltx2/execution.rs b/crates/mold-inference/src/ltx2/execution.rs index b0624a2a..a76318ac 100644 --- a/crates/mold-inference/src/ltx2/execution.rs +++ b/crates/mold-inference/src/ltx2/execution.rs @@ -268,7 +268,7 @@ mod tests { } fn engine(model_name: &str, paths: ModelPaths) -> Ltx2Engine { - Ltx2Engine::new(model_name.to_string(), paths, LoadStrategy::Sequential) + Ltx2Engine::new(model_name.to_string(), paths, LoadStrategy::Sequential, 0) } #[test] diff --git a/crates/mold-inference/src/ltx2/mod.rs b/crates/mold-inference/src/ltx2/mod.rs index d2101d33..ac0c5b6e 100644 --- a/crates/mold-inference/src/ltx2/mod.rs +++ b/crates/mold-inference/src/ltx2/mod.rs @@ -1,5 +1,6 @@ mod assets; mod backend; +pub mod chain; mod conditioning; mod execution; mod guidance; @@ -13,4 +14,8 @@ mod runtime; mod sampler; mod text; +pub use chain::{ + extract_tail_latents, tail_latent_frame_count, ChainRunOutput, ChainStageRenderer, ChainTail, + Ltx2ChainOrchestrator, StageOutcome, StageProgressEvent, +}; pub use pipeline::Ltx2Engine; diff --git a/crates/mold-inference/src/ltx2/pipeline.rs b/crates/mold-inference/src/ltx2/pipeline.rs index 1f584d14..1dc7644a 100644 --- a/crates/mold-inference/src/ltx2/pipeline.rs +++ b/crates/mold-inference/src/ltx2/pipeline.rs @@ -11,7 +11,8 @@ use std::time::Instant; use super::assets; use super::backend::Ltx2Backend; -use 
super::conditioning; +use super::chain::{ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent}; +use super::conditioning::{self, StagedLatent}; use super::execution; use super::lora; use super::media::{self, ProbeMetadata}; @@ -24,6 +25,14 @@ use crate::engine::{gpu_dtype, rand_seed, InferenceEngine, LoadStrategy}; use crate::ltx_video::video_enc; use crate::progress::ProgressCallback; +/// Soft-conditioning strength for the cross-stage identity anchor on chain +/// continuations. The denoise mask at the anchor token becomes +/// `1 - strength = 0.6`, so the denoiser blends ~60% generated / ~40% +/// reference on every step — a gentle pull toward the source image rather +/// than a hard pin (hard-pinning a single pixel frame past the motion tail +/// would make continuations feel like cuts back to the starting shot). +const CHAIN_SOFT_ANCHOR_STRENGTH: f32 = 0.4; + pub struct Ltx2Engine { model_name: String, paths: ModelPaths, @@ -31,6 +40,11 @@ pub struct Ltx2Engine { native_runtime: Option, on_progress: Option, pending_placement: Option, + /// GPU ordinal this engine is pinned to. Every `Device::new_cuda` and + /// `reclaim_gpu_memory` call must use this ordinal — hardcoding `0` here + /// is what took down the process on killswitch when LTX-2 ran alongside + /// SD3.5 on a multi-GPU host. 
+ gpu_ordinal: usize, } impl Ltx2Engine { @@ -51,7 +65,12 @@ impl Ltx2Engine { } } - pub fn new(model_name: String, paths: ModelPaths, _load_strategy: LoadStrategy) -> Self { + pub fn new( + model_name: String, + paths: ModelPaths, + _load_strategy: LoadStrategy, + gpu_ordinal: usize, + ) -> Self { Self { model_name, paths, @@ -59,6 +78,7 @@ impl Ltx2Engine { native_runtime: None, on_progress: None, pending_placement: None, + gpu_ordinal, } } @@ -75,6 +95,7 @@ impl Ltx2Engine { native_runtime: Some(runtime), on_progress: None, pending_placement: None, + gpu_ordinal: 0, } } @@ -217,7 +238,7 @@ impl Ltx2Engine { match backend { Ltx2Backend::Cuda => { self.info("CUDA detected, using native LTX-2 GPU path"); - let device = Device::new_cuda(0)?; + let device = Device::new_cuda(self.gpu_ordinal)?; configure_native_ltx2_cuda_device(&device)?; Ok(device) } @@ -258,9 +279,16 @@ impl Ltx2Engine { )?; Self::log_timing("pipeline.create_runtime.load_prompt_encoder", load_start); if prompt_device.is_cuda() { - Ok(Ltx2RuntimeSession::new_deferred_cuda(prompt_encoder)) + Ok(Ltx2RuntimeSession::new_deferred_cuda( + prompt_encoder, + self.gpu_ordinal, + )) } else { - Ok(Ltx2RuntimeSession::new(device, prompt_encoder)) + Ok(Ltx2RuntimeSession::new( + device, + prompt_encoder, + self.gpu_ordinal, + )) } } @@ -291,7 +319,7 @@ impl Ltx2Engine { self.info( "Native LTX-2 prompt path ran out of CUDA memory; retrying with CPU fallback", ); - crate::device::reclaim_gpu_memory(0); + crate::device::reclaim_gpu_memory(self.gpu_ordinal); self.load_runtime_session_on_device(plan, Device::Cpu) } Err(err) => Err(err), @@ -436,9 +464,16 @@ impl Ltx2Engine { plan.prompt_tokens.unconditional.valid_len() )); let create_runtime_start = Instant::now(); + // Reuse a persisted runtime only if it can serve this plan. 
An LTX-2 + // session consumes its prompt encoder on first `prepare()` (see + // runtime.rs `prepare()` — the take+drop frees VRAM for the + // transformer); a stale session left behind by a prior chain run + // survives intact for same-prompt continuations via the session- + // level encoding cache, but we must rebuild from scratch when the + // prompt changes so `prepare()` doesn't error on a consumed encoder. let mut runtime = match self.native_runtime.take() { - Some(runtime) => runtime, - None => self.create_runtime_session(&plan)?, + Some(runtime) if runtime.can_reuse_for(&plan) => runtime, + _ => self.create_runtime_session(&plan)?, }; Self::log_timing("pipeline.create_runtime", create_runtime_start); @@ -522,6 +557,160 @@ impl Ltx2Engine { gpu: None, }) } + + /// Render a single chain stage, optionally conditioning on a carryover + /// tail from the prior stage. + /// + /// `motion_tail_pixel_frames` is the number of pixel frames to narrow + /// off the emitted latents for the *next* stage's carryover. `0` + /// returns an error (nonsensical — use the regular single-clip path + /// if no tail is wanted). + /// + /// Scope: distilled LTX-2 pipeline only. Other pipeline families + /// return an error up-front so the chain orchestrator fails fast. 
+    pub(crate) fn render_chain_stage(
+        &mut self,
+        req: &GenerateRequest,
+        carry: Option<&ChainTail>,
+        motion_tail_pixel_frames: u32,
+    ) -> Result<StageOutcome> {
+        if motion_tail_pixel_frames == 0 {
+            bail!("render_chain_stage: motion_tail_pixel_frames must be > 0");
+        }
+        if !self.loaded {
+            self.load()?;
+        }
+        let start = Instant::now();
+        self.emit("Preparing native LTX-2 chain stage");
+
+        let pipeline = self.select_pipeline(req)?;
+        if !matches!(pipeline, PipelineKind::Distilled) {
+            bail!(
+                "render-chain v1 only supports the distilled LTX-2 pipeline, got {:?}",
+                pipeline,
+            );
+        }
+
+        let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?;
+        let native_output = work_dir.path().join("ltx2-native-output.mp4");
+        let mut plan = self.materialize_request(req, work_dir.path(), &native_output)?;
+
+        // Inject carryover RGB frames as a StagedLatent at frame 0. The
+        // runtime VAE-encodes them fresh on the receiving side so every
+        // resulting latent slot has correct causal/continuation semantics
+        // in this clip's own time axis (see conditioning.rs StagedLatent
+        // docstring + runtime.rs maybe_load_stage_video_conditioning).
+        //
+        // When the chain request carries a starting image (i2v flow), the
+        // orchestrator passes it through on every stage. Stage 0 uses it
+        // as the frame-0 i2v replacement — great. On continuations the
+        // motion-tail pin owns frame 0, so we re-route any frame-0 staged
+        // image to a non-zero frame with reduced "soft anchor" strength:
+        // the image becomes a durable identity reference appended to the
+        // token sequence (via the `VideoTokenAppendCondition` path in
+        // `maybe_load_stage_video_conditioning`), giving the free-region
+        // denoise a persistent cross-attention anchor for subject / scene
+        // appearance without freezing any tokens. Without this anchor,
+        // identity drift compounds stage-over-stage because each clip's
+        // only long-range reference is its own drifted last-frame carry.
+ if let Some(tail) = carry { + if tail.tail_rgb_frames.is_empty() { + bail!( + "render_chain_stage: carry.tail_rgb_frames is empty; caller must provide at least one frame" + ); + } + + // Re-route any frame-0 staged image into the soft-anchor + // append slot. The anchor frame is the first pixel past the + // motion-tail pin, so the reference token's RoPE sits exactly + // where new content starts — cross-attention propagates + // identity into the free region most directly from there. + // `CHAIN_SOFT_ANCHOR_STRENGTH = 0.4` gives the denoise mask a + // value of `1 - 0.4 = 0.6` at the anchor token, so the + // denoiser blends ~60% generated / ~40% reference every step. + let anchor_frame = motion_tail_pixel_frames; + for image in plan.conditioning.images.iter_mut() { + if image.frame == 0 { + image.frame = anchor_frame; + image.strength = CHAIN_SOFT_ANCHOR_STRENGTH; + } + } + + plan.conditioning.latents.push(StagedLatent { + tail_rgb_frames: tail.tail_rgb_frames.clone(), + frame: 0, + strength: 1.0, + }); + } + + // Reuse an existing runtime session if we have one AND it can + // serve this plan. Between stages of a same-prompt chain the + // session-level encoding cache handles the consumed-encoder + // invariant; if the prompt shifts (or a stale session leaked in + // from a prior run) we drop the runtime and rebuild so + // `prepare()` doesn't error on a missing encoder. 
+ let mut runtime = match self.native_runtime.take() { + Some(runtime) if runtime.can_reuse_for(&plan) => runtime, + _ => self.create_runtime_session(&plan)?, + }; + + self.emit("Executing native LTX-2 chain stage runtime"); + let prepared = match runtime.prepare(&plan) { + Ok(prepared) => prepared, + Err(err) => { + self.native_runtime = Some(runtime); + return Err(err); + } + }; + let render_result = + runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref()); + self.native_runtime = Some(runtime); + let rendered = render_result?; + + let frames = rendered.frames; + let tail_pixel_frames = motion_tail_pixel_frames as usize; + if frames.len() < tail_pixel_frames { + bail!( + "distilled render returned {} pixel frames but the chain caller requested a {}-frame tail; \ + this is a pipeline wiring bug", + frames.len(), + motion_tail_pixel_frames, + ); + } + let tail_start = frames.len() - tail_pixel_frames; + let tail_rgb_frames = frames[tail_start..].to_vec(); + + let generation_time_ms = start.elapsed().as_millis() as u64; + Self::log_timing("pipeline.render_chain_stage", start); + + Ok(StageOutcome { + frames, + tail: ChainTail { + frames: motion_tail_pixel_frames, + tail_rgb_frames, + }, + generation_time_ms, + }) + } +} + +impl ChainStageRenderer for Ltx2Engine { + fn render_stage( + &mut self, + stage_req: &GenerateRequest, + carry: Option<&ChainTail>, + motion_tail_pixel_frames: u32, + _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>, + ) -> Result { + // `_stage_progress` is intentionally unused in v1: per-stage + // denoise events flow through `self.on_progress` already. Phase 2's + // server route will install an on_progress callback that forwards + // those events onto the chain SSE stream with `stage_idx` tagged + // in. If the orchestrator later needs denoise-step events routed + // through its own channel, we can plumb `stage_progress` into a + // temporary ProgressCallback wrapper here. 
+ self.render_chain_stage(stage_req, carry, motion_tail_pixel_frames) + } } impl InferenceEngine for Ltx2Engine { @@ -576,6 +765,10 @@ impl InferenceEngine for Ltx2Engine { fn model_paths(&self) -> Option<&ModelPaths> { Some(&self.paths) } + + fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> { + Some(self) + } } #[cfg(test)] @@ -855,7 +1048,7 @@ mod tests { .unwrap(), PaddingSide::Left, ); - Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder) + Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder, 0) } fn request(output_format: OutputFormat, enable_audio: Option) -> GenerateRequest { @@ -905,6 +1098,7 @@ mod tests { "ltx-2.3-22b-distilled:fp8".to_string(), dummy_paths(), LoadStrategy::Sequential, + 0, ); let req = GenerateRequest { prompt: "test".to_string(), @@ -966,6 +1160,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), dummy_paths(), LoadStrategy::Sequential, + 0, ); assert_eq!(engine.request_quantization(), Some("fp8-cast".to_string())); } @@ -985,6 +1180,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), dummy_paths_with_gemma_root(gemma_dir.path()), LoadStrategy::Sequential, + 0, ); let req = GenerateRequest { prompt: "test".to_string(), @@ -1052,6 +1248,7 @@ mod tests { "ltx-2-19b-distilled:fp8".to_string(), paths, LoadStrategy::Sequential, + 0, ); engine.load().unwrap(); @@ -1087,4 +1284,47 @@ mod tests { assert!(!video.has_audio); assert!(engine.native_runtime.is_none()); } + + #[test] + fn render_chain_stage_rejects_non_distilled_pipeline() { + // A model name without "distilled" in it selects `PipelineKind::TwoStage` + // via `select_pipeline`, which must be rejected up-front by the chain + // entry point before any runtime work happens. 
+ let mut engine = Ltx2Engine::with_runtime_session( + "ltx-2-19b:fp8".to_string(), + dummy_paths(), + runtime_session(), + ); + engine.loaded = true; + let req = request(OutputFormat::Mp4, Some(false)); + let err = engine + .render_chain_stage(&req, None, 4) + .expect_err("must fail on non-distilled pipeline"); + let msg = format!("{err}"); + assert!( + msg.contains("distilled"), + "error must name the pipeline constraint, got: {msg}", + ); + } + + #[test] + fn render_chain_stage_rejects_zero_motion_tail() { + // Zero-frame motion tail is nonsensical — it would narrow nothing off + // for the next stage. Fast-fail before any allocation. + let mut engine = Ltx2Engine::with_runtime_session( + "ltx-2-19b-distilled:fp8".to_string(), + dummy_paths(), + runtime_session(), + ); + engine.loaded = true; + let req = request(OutputFormat::Mp4, Some(false)); + let err = engine + .render_chain_stage(&req, None, 0) + .expect_err("must fail on zero motion tail"); + let msg = format!("{err}"); + assert!( + msg.contains("motion_tail_pixel_frames"), + "error must name the motion_tail constraint, got: {msg}", + ); + } } diff --git a/crates/mold-inference/src/ltx2/runtime.rs b/crates/mold-inference/src/ltx2/runtime.rs index 781ab3c4..8b1b6c55 100644 --- a/crates/mold-inference/src/ltx2/runtime.rs +++ b/crates/mold-inference/src/ltx2/runtime.rs @@ -291,21 +291,104 @@ impl Ltx2VaeLatentStats { pub struct Ltx2RuntimeSession { device: Option, prompt_encoder: Option, + /// Cached output of the last successful `encode_prompt_pair_with_unconditional` + /// call. The prompt encoder is intentionally consumed during the first + /// `prepare()` so its VRAM can be freed for the transformer (see the + /// `take()` + drop pattern below); that leaves subsequent `prepare()` + /// calls on the same session with no encoder. 
For the render-chain + /// path every stage shares the same prompt tokens, so we cache the + /// encoding after the first encode and reuse it on follow-up stages — + /// no re-encode, no encoder re-load, no VRAM re-hit. + cached_prompt_encoding: Option, + /// Optional slot wired into `render_real_distilled_av` so + /// `Ltx2Engine::render_chain_stage` can snapshot the pre-VAE-decode + /// final latents and forward them to the next chain stage as a + /// [`super::chain::ChainTail`]. `None` outside chain flow. + pub(crate) tail_capture: Option>>>, + /// GPU ordinal inherited from `Ltx2Engine`. Used for the deferred CUDA + /// device creation in `prepare()` and for post-OOM context reset. + gpu_ordinal: usize, +} + +/// Remembers the last `encode_prompt_pair_with_unconditional` call so +/// successive `prepare()` calls with the same prompt can skip the encoder +/// entirely — used by the render-chain path where stages share a prompt. +struct CachedPromptEncoding { + token_pair: super::text::gemma::EncodedPromptPair, + encode_unconditional: bool, + encoding: NativePromptEncoding, + prompt_device_is_cuda: bool, + prepared_device: candle_core::Device, } impl Ltx2RuntimeSession { - pub fn new(device: candle_core::Device, prompt_encoder: NativePromptEncoder) -> Self { + pub fn new( + device: candle_core::Device, + prompt_encoder: NativePromptEncoder, + gpu_ordinal: usize, + ) -> Self { Self { device: Some(device), prompt_encoder: Some(prompt_encoder), + cached_prompt_encoding: None, + tail_capture: None, + gpu_ordinal, } } - pub fn new_deferred_cuda(prompt_encoder: NativePromptEncoder) -> Self { + pub fn new_deferred_cuda(prompt_encoder: NativePromptEncoder, gpu_ordinal: usize) -> Self { Self { device: None, prompt_encoder: Some(prompt_encoder), + cached_prompt_encoding: None, + tail_capture: None, + gpu_ordinal, + } + } + + /// Arm the pre-VAE-decode latent capture slot. 
The distilled render + /// path writes its `final_video_latents` into the returned slot when + /// this is set, letting a caller drain the raw latents after a render + /// completes. Kept after the v1.1 decoded-pixel-carryover switch in + /// case future work (e.g. quality-diagnostic tooling) wants access + /// to the pre-decode tensor; the production chain path no longer + /// arms it. + #[allow(dead_code)] + pub(crate) fn arm_tail_capture(&mut self) -> std::sync::Arc>> { + let slot = std::sync::Arc::new(std::sync::Mutex::new(None)); + self.tail_capture = Some(std::sync::Arc::clone(&slot)); + slot + } + + /// Disarm the latent capture slot. See [`arm_tail_capture`]. + #[allow(dead_code)] + pub(crate) fn clear_tail_capture(&mut self) { + self.tail_capture = None; + } + + /// Whether this session can serve `plan` without a rebuild. Returns + /// `true` if the encoder is still available OR the cached encoding + /// matches the plan's prompt tokens. Callers use this to decide + /// whether to reuse a persisted runtime (fast path — keeps transformer + /// and VAE warm) or drop it and build a fresh one (the only way to + /// recover when the encoder has been consumed on a prior `prepare()` + /// and a different prompt arrives). + pub fn can_reuse_for(&self, plan: &Ltx2GeneratePlan) -> bool { + if self.prompt_encoder.is_some() { + return true; + } + let Ok(encode_unconditional) = prompt_requires_unconditional_context(plan) else { + return false; + }; + // Alt-prompt debug mode requires the live encoder; cache alone + // isn't sufficient. 
+ if ltx_debug_alt_prompt().is_some() { + return false; } + self.cached_prompt_encoding.as_ref().is_some_and(|cached| { + cached.encode_unconditional == encode_unconditional + && cached.token_pair == plan.prompt_tokens + }) } pub fn prepare(&mut self, plan: &Ltx2GeneratePlan) -> Result { @@ -334,7 +417,31 @@ impl Ltx2RuntimeSession { stage1_shape.width = implicit_x2_shape.width; stage1_shape.height = implicit_x2_shape.height; } - let (prompt_device_is_cuda, prepared_device, prompt, debug_alt_prompt) = { + let encode_unconditional_prompt = prompt_requires_unconditional_context(plan)?; + let alt_prompt_env = ltx_debug_alt_prompt(); + // Chain path fast-path: if a previous `prepare()` already encoded + // the exact same prompt+unconditional combo, reuse those embeddings + // instead of demanding the encoder back. Disabled when the + // `MOLD_LTX_DEBUG_ALT_PROMPT` debug hook is active because that branch + // still needs the live encoder. + let cache_hit = alt_prompt_env.is_none() + && self.cached_prompt_encoding.as_ref().is_some_and(|cached| { + cached.encode_unconditional == encode_unconditional_prompt + && cached.token_pair == plan.prompt_tokens + }); + let (prompt_device_is_cuda, prepared_device, prompt, debug_alt_prompt) = if cache_hit { + let cached = self + .cached_prompt_encoding + .as_ref() + .expect("cache_hit implies cached_prompt_encoding is Some"); + log_timing("prepare.prompt_pair", Instant::now()); + ( + cached.prompt_device_is_cuda, + cached.prepared_device.clone(), + cached.encoding.clone(), + None, + ) + } else { let mut prompt_encoder = self .prompt_encoder .take() @@ -346,7 +453,6 @@ impl Ltx2RuntimeSession { prompt_encoder.device().clone() }; let prompt_encode_start = Instant::now(); - let encode_unconditional_prompt = prompt_requires_unconditional_context(plan)?; let prompt = move_prompt_encoding_to_device( prompt_encoder.encode_prompt_pair_with_unconditional( &plan.prompt_tokens, @@ -356,7 +462,7 @@ impl Ltx2RuntimeSession { )?; 
log_timing("prepare.prompt_pair", prompt_encode_start); let alt_prompt_start = Instant::now(); - let debug_alt_prompt = match ltx_debug_alt_prompt() { + let debug_alt_prompt = match alt_prompt_env.clone() { Some(alt_prompt) => { let assets = super::text::gemma::GemmaAssets::discover(Path::new(&plan.gemma_root)) @@ -387,6 +493,18 @@ impl Ltx2RuntimeSession { } } log_timing("prepare.prompt_debug", prompt_debug_start); + // Cache the encoding for the next chain stage. Dropping the + // encoder here (end of the else branch) still happens — we're + // only holding on to the `NativePromptEncoding` output, not the + // encoder itself, so the VRAM-free property of the original + // take() pattern is preserved. + self.cached_prompt_encoding = Some(CachedPromptEncoding { + token_pair: plan.prompt_tokens.clone(), + encode_unconditional: encode_unconditional_prompt, + encoding: prompt.clone(), + prompt_device_is_cuda, + prepared_device: prepared_device.clone(), + }); ( prompt_device_is_cuda, prepared_device, @@ -397,8 +515,8 @@ impl Ltx2RuntimeSession { let device_handoff_start = Instant::now(); if prompt_device_is_cuda { if self.device.is_none() { - crate::device::reclaim_gpu_memory(0); - self.device = Some(new_native_cuda_device()?); + crate::device::reclaim_gpu_memory(self.gpu_ordinal); + self.device = Some(new_native_cuda_device(self.gpu_ordinal)?); } else if let Some(device) = self.device.as_ref() { if device.is_cuda() { device.synchronize()?; @@ -597,7 +715,13 @@ impl Ltx2RuntimeSession { return Ok(None); } let render = match plan.pipeline { - PipelineKind::Distilled => render_real_distilled_av(plan, prepared, device, progress), + PipelineKind::Distilled => render_real_distilled_av( + plan, + prepared, + device, + progress, + self.tail_capture.as_ref(), + ), PipelineKind::OneStage => render_real_one_stage_av(plan, prepared, device, progress), PipelineKind::TwoStage | PipelineKind::TwoStageHq @@ -841,8 +965,8 @@ fn overlay_alpha(overlay: &ConditioningOverlay, frame_idx: 
u32, total_frames: u3 } #[cfg(feature = "cuda")] -fn new_native_cuda_device() -> Result { - let device = candle_core::Device::new_cuda(0)?; +fn new_native_cuda_device(ordinal: usize) -> Result { + let device = candle_core::Device::new_cuda(ordinal)?; let cuda = device.as_cuda_device()?; if cuda.is_event_tracking() { unsafe { @@ -853,7 +977,7 @@ fn new_native_cuda_device() -> Result { } #[cfg(not(feature = "cuda"))] -fn new_native_cuda_device() -> Result { +fn new_native_cuda_device(_ordinal: usize) -> Result { anyhow::bail!("CUDA backend is unavailable in this build") } @@ -1219,17 +1343,36 @@ fn maybe_load_stage_video_conditioning( dtype: DType, include_reference_video: bool, ) -> Result { - if plan.conditioning.images.is_empty() && !include_reference_video { + if plan.conditioning.images.is_empty() + && plan.conditioning.latents.is_empty() + && !include_reference_video + { return Ok(StageVideoConditioning::default()); } - let mut vae = load_ltx2_video_vae(plan, device, dtype)?; - vae.use_tiling = false; - vae.use_framewise_decoding = false; + // The VAE is needed for staged images, reference video ingest, and — + // on chain continuations — re-encoding the emitting stage's trailing + // RGB frames into a proper-slot-semantics conditioning latent. Every + // StagedLatent now carries RGB frames, so any non-empty + // plan.conditioning.latents implies a VAE load. 
+ let need_vae = !plan.conditioning.images.is_empty() + || include_reference_video + || !plan.conditioning.latents.is_empty(); + let mut vae = if need_vae { + let mut loaded = load_ltx2_video_vae(plan, device, dtype)?; + loaded.use_tiling = false; + loaded.use_framewise_decoding = false; + Some(loaded) + } else { + None + }; let patchifier = VideoLatentPatchifier::new(1); let mut conditioning = StageVideoConditioning::default(); for image in &plan.conditioning.images { + let vae = vae.as_mut().expect( + "need_vae guarantees the VAE is loaded whenever plan.conditioning.images is non-empty", + ); let bytes = std::fs::read(&image.path).with_context(|| { format!( "failed to read staged LTX-2 conditioning image '{}'", @@ -1271,7 +1414,52 @@ fn maybe_load_stage_video_conditioning( )?); } } + // Chain carryover: every StagedLatent is a contiguous RGB window from + // the end of the emitting stage. Re-encoding on the receiving side + // (rather than slicing the emitting stage's final latent tensor) keeps + // slot semantics aligned with the receiving clip's time axis — slot 0 + // is a proper causal 1-pixel encoding, slot 1+ are proper 8-pixel + // continuation encodings, with no ambiguity about which latent slot + // corresponds to which pixel-frame range. + for staged in &plan.conditioning.latents { + if staged.tail_rgb_frames.is_empty() { + anyhow::bail!( + "StagedLatent has an empty tail_rgb_frames; at least one frame is required" + ); + } + let vae = vae.as_mut().expect( + "need_vae guarantees the VAE is loaded whenever plan.conditioning.latents is non-empty", + ); + let video = video_tensor_from_frames(&staged.tail_rgb_frames, device, dtype) + .context("encode chain tail RGB frames into pixel tensor for carryover")?; + let latents = vae + .encode(&video) + .context("failed to encode chain tail RGB frames through the LTX-2 video VAE")? 
+ .to_dtype(DType::F32)?; + let use_guiding_latent = matches!(plan.pipeline, PipelineKind::Keyframe); + if staged.frame == 0 && !use_guiding_latent { + let tokens = patchifier.patchify(&latents)?; + conditioning.replacements.push(VideoTokenReplacement { + start_token: 0, + tokens, + strength: staged.strength as f64, + }); + } else { + conditioning + .appended + .push(append_condition_from_video_latents( + &latents, + pixel_shape, + staged.frame, + 1, + staged.strength as f64, + )?); + } + } if include_reference_video { + let vae = vae.as_mut().expect( + "need_vae guarantees the VAE is loaded whenever include_reference_video is true", + ); let video_path = plan.conditioning.video_path.as_ref().with_context(|| { format!( "native {:?} stage requested reference video conditioning without a staged source_video", @@ -1379,6 +1567,36 @@ fn apply_video_token_replacements( Ok(patched) } +/// Build the "clean reference" tensor used by the denoise mask blend at every +/// step. For replacement-based conditioning (e.g. i2v source image) with +/// `strength < 1.0`, `video_latents` already holds `noise*(1-s) + source*s` at +/// the replacement positions. If we reuse that as the clean target, the +/// denoise-mask blend pulls those tokens toward a noisy ghost of the image at +/// every step — the first latent frame never converges to the pure source. +/// +/// Re-applying the replacements with strength 1.0 overwrites those positions +/// with the pure source tokens, leaving appended keyframe tokens (already +/// full-strength in `apply_appended_video_conditioning`) and pure-noise +/// regions untouched. 
+fn clean_latents_for_conditioning( + video_latents: &Tensor, + conditioning: &StageVideoConditioning, +) -> Result { + if conditioning.replacements.is_empty() { + return Ok(video_latents.clone()); + } + let hard_replacements: Vec = conditioning + .replacements + .iter() + .map(|replacement| VideoTokenReplacement { + start_token: replacement.start_token, + tokens: replacement.tokens.clone(), + strength: 1.0, + }) + .collect(); + apply_video_token_replacements(video_latents, &hard_replacements) +} + fn apply_appended_video_conditioning( video_latents: &Tensor, video_positions: &Tensor, @@ -1454,9 +1672,10 @@ fn reapply_stage_video_conditioning( let mut parts = vec![base]; for condition in &conditioning.appended { - if condition.strength < 1.0 { - continue; - } + // Appended conditioning tokens must remain present for the whole + // denoise loop. Their strength is expressed via the denoise mask; + // dropping "soft" appended tokens here desynchronizes the token + // count from the cached clean latents and mask tensors. parts.push( condition .tokens @@ -1650,6 +1869,7 @@ fn render_real_distilled_av( prepared: &NativePreparedRun, device: &candle_core::Device, progress: Option<&ProgressCallback>, + tail_capture: Option<&std::sync::Arc>>>, ) -> Result { let debug_enabled = ltx_debug_enabled(); let prompt_inputs = prepare_render_prompt_inputs( @@ -1934,6 +2154,16 @@ fn render_real_distilled_av( vae.use_tiling = false; vae.use_framewise_decoding = false; let decode_start = Instant::now(); + // Chain-stage hook: capture the pre-decode F32 latents so + // `Ltx2Engine::render_chain_stage` can narrow the tail off for the next + // stage's conditioning. Cheap shallow clone (candle tensors are + // Arc-backed). A poisoned mutex is ignored here — the outer caller + // detects an empty slot and emits a clear error. 
+ if let Some(slot) = tail_capture { + if let Ok(mut guard) = slot.lock() { + *guard = Some(latents.clone()); + } + } let (_dec_output, video) = vae.decode(&latents.to_dtype(dtype)?, None, false, false)?; if debug_enabled { log_tensor_stats("decoded_video", &video)?; @@ -2699,7 +2929,7 @@ fn run_real_distilled_stage( )?; let clean_video_latents = match video_clean_latents { Some(latents) => video_patchifier.patchify(latents)?, - None => video_latents.clone(), + None => clean_latents_for_conditioning(&video_latents, video_conditioning)?, }; let video_denoise_mask = match video_denoise_mask { Some(mask) => mask.to_device(&device)?.to_dtype(DType::F32)?, @@ -4622,14 +4852,14 @@ mod tests { use super::{ apply_stage_video_conditioning, apply_video_token_replacements, - build_video_conditioning_self_attention_mask, convert_velocity_to_x0, - convert_x0_to_velocity, decoded_video_to_frames, effective_native_guidance_scale, - emit_denoise_progress, guided_velocity_from_cfg, keyframe_only_conditioning, - ltx2_video_transformer_config, reapply_stage_video_conditioning, - should_inspect_step_velocity, source_image_only_conditioning, - strip_appended_video_conditioning, Ltx2RuntimeSession, StageVideoConditioning, - VideoTokenAppendCondition, VideoTokenReplacement, LTX2_AUDIO_LATENT_CHANNELS, - LTX2_VIDEO_LATENT_CHANNELS, + build_video_conditioning_self_attention_mask, clean_latents_for_conditioning, + convert_velocity_to_x0, convert_x0_to_velocity, decoded_video_to_frames, + effective_native_guidance_scale, emit_denoise_progress, guided_velocity_from_cfg, + keyframe_only_conditioning, ltx2_video_transformer_config, + reapply_stage_video_conditioning, should_inspect_step_velocity, + source_image_only_conditioning, strip_appended_video_conditioning, Ltx2RuntimeSession, + StageVideoConditioning, VideoTokenAppendCondition, VideoTokenReplacement, + LTX2_AUDIO_LATENT_CHANNELS, LTX2_VIDEO_LATENT_CHANNELS, }; use crate::ltx2::conditioning::{self, StagedConditioning}; use 
crate::ltx2::model::VideoPixelShape; @@ -4882,7 +5112,7 @@ mod tests { .unwrap(), PaddingSide::Left, ); - Ltx2RuntimeSession::new(candle_core::Device::Cpu, prompt_encoder) + Ltx2RuntimeSession::new(candle_core::Device::Cpu, prompt_encoder, 0) } fn build_plan( @@ -5387,6 +5617,33 @@ mod tests { #[test] fn runtime_session_prepare_consumes_prompt_encoder() { + // The encoder is still consumed on first prepare() — the encoder + // slot moves out to free VRAM for the transformer. But same-prompt + // follow-up calls now short-circuit through `cached_prompt_encoding` + // so chain stages that replicate the prompt can reuse the session + // instead of erroring on a consumed encoder. + let req = req("ltx-2.3-22b-distilled:fp8", OutputFormat::Mp4, Some(false)); + let temp_dir = tempfile::tempdir().unwrap(); + let conditioning = conditioning::stage_conditioning(&req, temp_dir.path()).unwrap(); + let preset = preset_for_model(&req.model).unwrap(); + let plan = build_plan(&req, preset, conditioning); + + let mut session = runtime_session(); + session.prepare(&plan).unwrap(); + + // Encoder slot is empty post-take. + assert!(session.prompt_encoder.is_none()); + // But `can_reuse_for` reports true because the cached encoding + // matches the incoming plan's prompt tokens. + assert!(session.can_reuse_for(&plan)); + // Same-prompt re-prepare succeeds from the cache. + session + .prepare(&plan) + .expect("same-prompt cache hit must succeed"); + } + + #[test] + fn runtime_session_prepare_rejects_encoder_reuse_with_different_prompt() { let req = req("ltx-2.3-22b-distilled:fp8", OutputFormat::Mp4, Some(false)); let temp_dir = tempfile::tempdir().unwrap(); let conditioning = conditioning::stage_conditioning(&req, temp_dir.path()).unwrap(); @@ -5396,7 +5653,17 @@ mod tests { let mut session = runtime_session(); session.prepare(&plan).unwrap(); - assert!(session.prepare(&plan).is_err()); + // Mutate the plan's prompt tokens so the cache key misses. 
+ let mut plan_alt = plan.clone(); + plan_alt.prompt_tokens.conditional.input_ids[0] = + plan_alt.prompt_tokens.conditional.input_ids[0].wrapping_add(1); + + // can_reuse_for must report false for a fresh prompt because the + // encoder has already been consumed. + assert!(!session.can_reuse_for(&plan_alt)); + // And prepare() with the new plan fails explicitly so the caller + // knows to drop the session and rebuild. + assert!(session.prepare(&plan_alt).is_err()); } #[test] @@ -5693,6 +5960,109 @@ mod tests { ); } + #[test] + fn reapply_stage_video_conditioning_keeps_soft_appended_tokens() { + let latents = + Tensor::from_vec(vec![0.0f32, 0.0, 1.0, 1.0], (1, 2, 2), &Device::Cpu).unwrap(); + let conditioning = StageVideoConditioning { + replacements: vec![], + appended: vec![VideoTokenAppendCondition { + tokens: Tensor::from_vec(vec![9.0f32, 10.0], (1, 1, 2), &Device::Cpu).unwrap(), + positions: Tensor::from_vec(vec![30.0f32, 40.0, 50.0], (1, 3, 1, 1), &Device::Cpu) + .unwrap(), + strength: 0.4, + }], + }; + + let reapplied = reapply_stage_video_conditioning(&latents, 2, &conditioning).unwrap(); + assert_eq!(reapplied.dims3().unwrap(), (1, 3, 2)); + assert_eq!( + reapplied.flatten_all().unwrap().to_vec1::().unwrap(), + vec![0.0, 0.0, 1.0, 1.0, 9.0, 10.0] + ); + } + + #[test] + fn clean_latents_replace_soft_blended_positions_with_pure_source() { + // Simulate the state after `apply_stage_video_conditioning` with + // strength 0.75: at the replacement positions, `video_latents` already + // holds `noise*0.25 + source*0.75`. The denoise-mask blend uses + // `clean_latents` as the target it pulls those positions toward at + // every step — so the clean target must be pure source, not the + // pre-blended mix. 
+ let noise = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0]; + let source = [10.0f32, 10.0]; + let strength = 0.75f32; + let blended_first = [ + noise[0] * (1.0 - strength) + source[0] * strength, + noise[1] * (1.0 - strength) + source[1] * strength, + ]; + let soft_blended = Tensor::from_vec( + vec![ + blended_first[0], + blended_first[1], + noise[2], + noise[3], + noise[4], + noise[5], + ], + (1, 3, 2), + &Device::Cpu, + ) + .unwrap(); + let conditioning = StageVideoConditioning { + replacements: vec![VideoTokenReplacement { + start_token: 0, + tokens: Tensor::from_vec(source.to_vec(), (1, 1, 2), &Device::Cpu).unwrap(), + strength: strength as f64, + }], + appended: vec![], + }; + + let clean = clean_latents_for_conditioning(&soft_blended, &conditioning).unwrap(); + let values = clean.flatten_all().unwrap().to_vec1::().unwrap(); + + assert_eq!( + values, + vec![source[0], source[1], noise[2], noise[3], noise[4], noise[5]], + "soft-blended replacement positions must be overwritten with the pure \ + source tokens; other positions must be preserved unchanged" + ); + } + + #[test] + fn clean_latents_passthrough_when_no_replacements() { + let latents = + Tensor::from_vec(vec![0.0f32, 1.0, 2.0, 3.0], (1, 2, 2), &Device::Cpu).unwrap(); + let conditioning = StageVideoConditioning::default(); + + let clean = clean_latents_for_conditioning(&latents, &conditioning).unwrap(); + assert_eq!( + clean.flatten_all().unwrap().to_vec1::().unwrap(), + vec![0.0, 1.0, 2.0, 3.0] + ); + } + + #[test] + fn staged_latent_patchifies_to_same_token_shape_as_image_at_single_latent_frame() { + // A 4-pixel-frame motion tail at 1216×704 output lands on a latent + // block of shape [1, 128, 1, 22, 38]. The render-chain orchestrator + // produces this block from the prior stage's denoise result; the + // image-conditioning path produces the same shape after VAE encode. 
+ // Both must patchify to [1, T*H*W, C] = [1, 1*22*38, 128] tokens so + // the downstream replacement pass sees them identically regardless + // of which path produced them. + let latents = Tensor::zeros( + (1, LTX2_VIDEO_LATENT_CHANNELS, 1, 22, 38), + DType::F32, + &Device::Cpu, + ) + .unwrap(); + let patchifier = super::VideoLatentPatchifier::new(1); + let tokens = patchifier.patchify(&latents).expect("patchify"); + assert_eq!(tokens.dims(), &[1, 22 * 38, LTX2_VIDEO_LATENT_CHANNELS]); + } + #[test] fn video_conditioning_self_attention_mask_blocks_cross_keyframe_attention() { let conditioning = StageVideoConditioning { diff --git a/crates/mold-inference/src/ltx_video/mod.rs b/crates/mold-inference/src/ltx_video/mod.rs index 4e37627f..aa01b614 100644 --- a/crates/mold-inference/src/ltx_video/mod.rs +++ b/crates/mold-inference/src/ltx_video/mod.rs @@ -1,5 +1,8 @@ pub(crate) mod latent_upsampler; mod pipeline; -pub(crate) mod video_enc; +// Video encoding helpers (GIF/APNG/WebP/MP4 + thumbnail) are used by +// chain stitching in `mold-server`, so the module is public rather than +// crate-private. 
+pub mod video_enc; pub use pipeline::LtxVideoEngine; diff --git a/crates/mold-inference/src/qwen_image/offload.rs b/crates/mold-inference/src/qwen_image/offload.rs index 774afdd0..0b0733d4 100644 --- a/crates/mold-inference/src/qwen_image/offload.rs +++ b/crates/mold-inference/src/qwen_image/offload.rs @@ -495,6 +495,7 @@ impl OffloadedQwenImageTransformer { cpu_vb: VarBuilder, cfg: &QwenImageConfig, gpu_device: &Device, + gpu_ordinal: usize, progress: &ProgressReporter, ) -> Result { progress.info("Loading transformer with dynamic GPU/CPU placement…"); @@ -521,7 +522,7 @@ impl OffloadedQwenImageTransformer { // Measure free VRAM after stem layers and decide how many blocks fit gpu_device.synchronize()?; - let free_vram = crate::device::free_vram_bytes(0).unwrap_or(0); + let free_vram = crate::device::free_vram_bytes(gpu_ordinal).unwrap_or(0); const VRAM_HEADROOM: u64 = 4_500_000_000; // 4.5GB for attention + activations + CUDA overhead let vram_budget = free_vram.saturating_sub(VRAM_HEADROOM); diff --git a/crates/mold-inference/src/qwen_image/pipeline.rs b/crates/mold-inference/src/qwen_image/pipeline.rs index 7128afa8..711a840b 100644 --- a/crates/mold-inference/src/qwen_image/pipeline.rs +++ b/crates/mold-inference/src/qwen_image/pipeline.rs @@ -1003,6 +1003,7 @@ impl QwenImageEngine { cpu_vb, cfg, device, + self.base.gpu_ordinal, &self.base.progress, )?, )) diff --git a/crates/mold-inference/src/upscaler/engine.rs b/crates/mold-inference/src/upscaler/engine.rs index 89590da3..eb741180 100644 --- a/crates/mold-inference/src/upscaler/engine.rs +++ b/crates/mold-inference/src/upscaler/engine.rs @@ -75,16 +75,26 @@ pub struct UpscalerEngine { loaded: Option, progress: ProgressReporter, load_strategy: LoadStrategy, + /// GPU ordinal this engine is pinned to. Same multi-GPU footgun as + /// `Ltx2Engine::gpu_ordinal` — hardcoding `0` would corrupt a sibling + /// GPU's CUDA context on unload. 
+ gpu_ordinal: usize, } impl UpscalerEngine { - pub fn new(name: String, weights_path: PathBuf, load_strategy: LoadStrategy) -> Self { + pub fn new( + name: String, + weights_path: PathBuf, + load_strategy: LoadStrategy, + gpu_ordinal: usize, + ) -> Self { Self { name, weights_path, loaded: None, progress: ProgressReporter::default(), load_strategy, + gpu_ordinal, } } @@ -240,7 +250,7 @@ impl UpscaleEngine for UpscalerEngine { let load_start = Instant::now(); self.progress.stage_start("Loading upscaler model"); - let device = create_device(0, &self.progress)?; + let device = create_device(self.gpu_ordinal, &self.progress)?; // Determine dtype: prefer F16 on GPU, F32 on CPU let dtype = if matches!(device, Device::Cpu) { @@ -317,7 +327,7 @@ impl UpscaleEngine for UpscalerEngine { fn unload(&mut self) { if self.loaded.is_some() { self.loaded = None; - crate::reclaim_gpu_memory(0); + crate::reclaim_gpu_memory(self.gpu_ordinal); tracing::info!("Upscaler model unloaded: {}", self.name); } } @@ -340,6 +350,7 @@ pub fn create_upscale_engine( model_name: String, weights_path: PathBuf, load_strategy: LoadStrategy, + gpu_ordinal: usize, ) -> Result> { if !weights_path.exists() { bail!("upscaler weights not found: {}", weights_path.display()); @@ -348,5 +359,6 @@ pub fn create_upscale_engine( model_name, weights_path, load_strategy, + gpu_ordinal, ))) } diff --git a/crates/mold-server/Cargo.toml b/crates/mold-server/Cargo.toml index 28e76c4b..34e6645f 100644 --- a/crates/mold-server/Cargo.toml +++ b/crates/mold-server/Cargo.toml @@ -25,6 +25,7 @@ default = [] cuda = ["mold-inference/cuda"] metal = ["mold-inference/metal"] expand = ["mold-inference/expand"] +mp4 = ["mold-inference/mp4"] metrics = ["dep:metrics", "dep:metrics-exporter-prometheus"] nvml = ["dep:nvml-wrapper"] @@ -72,3 +73,7 @@ async-stream = "0.3" [dev-dependencies] tempfile = "3" tokio = { version = "1", features = ["full", "test-util"] } +# Chain route tests build a synthetic motion-tail Tensor via the same +# 
candle APIs the inference crate uses — keep this in lockstep with +# mold-inference's pinned candle-core-mold version. +candle-core = { package = "candle-core-mold", version = "0.9.10" } diff --git a/crates/mold-server/src/downloads.rs b/crates/mold-server/src/downloads.rs index 28dcb754..dcfb8618 100644 --- a/crates/mold-server/src/downloads.rs +++ b/crates/mold-server/src/downloads.rs @@ -236,10 +236,6 @@ impl DownloadQueue { } } -#[cfg(test)] -#[path = "downloads_test.rs"] -mod tests; - // ── PullDriver trait + real & test implementations ────────────────────────── /// Trait that hides the HuggingFace pull behind something the tests can fake. @@ -729,3 +725,7 @@ fn now_ms() -> i64 { .map(|d| d.as_millis() as i64) .unwrap_or(0) } + +#[cfg(test)] +#[path = "downloads_test.rs"] +mod tests; diff --git a/crates/mold-server/src/downloads_test.rs b/crates/mold-server/src/downloads_test.rs index d487d269..22dac0e8 100644 --- a/crates/mold-server/src/downloads_test.rs +++ b/crates/mold-server/src/downloads_test.rs @@ -4,6 +4,11 @@ //! These tests never touch HuggingFace — they inject a fake `PullDriver` so //! the queue logic can be exercised in isolation. +// The tests use `std::sync::Mutex<()>` to serialize process-global env-var +// mutations; holding the guard across `.await` is intentional under the +// current-thread tokio test runtime. 
+#![allow(clippy::await_holding_lock)] + use crate::downloads::DownloadQueue; #[tokio::test] diff --git a/crates/mold-server/src/gpu_pool.rs b/crates/mold-server/src/gpu_pool.rs index 89ced462..ce919ec4 100644 --- a/crates/mold-server/src/gpu_pool.rs +++ b/crates/mold-server/src/gpu_pool.rs @@ -1,8 +1,9 @@ use crate::model_cache::{ModelCache, ModelResidency}; -use mold_core::types::{GpuWorkerState, GpuWorkerStatus}; +use mold_core::types::{DevicePlacement, DeviceRef, GpuWorkerState, GpuWorkerStatus}; use mold_db::MetadataDb; use mold_inference::device::DiscoveredGpu; use mold_inference::shared_pool::SharedPool; +use std::collections::BTreeSet; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::time::Instant; @@ -24,6 +25,8 @@ pub struct GpuWorker { #[derive(Debug)] pub struct ActiveGeneration { pub model: String, + pub prompt_sha256: String, + pub started_at_unix_ms: u64, pub started_at: Instant, } @@ -63,6 +66,7 @@ impl GpuWorker { /// Build a status snapshot for this worker. pub fn status(&self) -> GpuWorkerStatus { let active_gen = self.active_generation.read().unwrap(); + let in_flight = self.in_flight.load(Ordering::SeqCst); // Prefer the active-generation model name — during inflight generation // the cache entry is taken out of the cache (take-and-restore pattern), // so `cache.active_model()` returns None. Falling back to the cache @@ -74,7 +78,7 @@ impl GpuWorker { let state = if self.is_degraded() { GpuWorkerState::Degraded - } else if active_gen.is_some() { + } else if active_gen.is_some() || in_flight > 0 { GpuWorkerState::Generating } else { GpuWorkerState::Idle @@ -92,6 +96,61 @@ impl GpuWorker { } impl GpuPool { + /// Return the worker bound to `ordinal`, if present in this pool. + pub fn worker_by_ordinal(&self, ordinal: usize) -> Option> { + self.workers + .iter() + .find(|w| w.gpu.ordinal == ordinal) + .cloned() + } + + /// Validate a request/config placement against the active worker pool. 
+ /// + /// In multi-GPU worker mode a request may explicitly pin components to at + /// most one GPU ordinal. Cross-GPU component placement would bypass the + /// worker-affinity model entirely, so reject it here instead of letting the + /// engines silently allocate on a sibling GPU. + pub fn resolve_explicit_placement_gpu( + &self, + placement: Option<&DevicePlacement>, + ) -> Result, String> { + if self.workers.is_empty() { + return Ok(None); + } + let Some(placement) = placement else { + return Ok(None); + }; + + let ordinals = placement_gpu_ordinals(placement); + if ordinals.is_empty() { + return Ok(None); + } + if ordinals.len() > 1 { + let rendered = ordinals + .iter() + .map(|o| format!("gpu:{o}")) + .collect::>() + .join(", "); + return Err(format!( + "multi-GPU worker mode only supports placement on one GPU ordinal per request; got {rendered}" + )); + } + + let ordinal = *ordinals.iter().next().expect("checked non-empty"); + if self.worker_by_ordinal(ordinal).is_none() { + let available = self + .workers + .iter() + .map(|w| w.gpu.ordinal.to_string()) + .collect::>() + .join(", "); + return Err(format!( + "gpu:{ordinal} is not available in this server's worker pool [{available}]" + )); + } + Ok(Some(ordinal)) + } + /// Find a non-degraded worker that already has this model loaded on GPU. /// If multiple workers have it, prefer the one with fewer in-flight requests. pub fn find_loaded(&self, model_name: &str) -> Option> { @@ -102,6 +161,10 @@ impl GpuPool { if w.is_degraded() { return false; } + let active_gen = w.active_generation.read().unwrap(); + if active_gen.as_ref().is_some_and(|g| g.model == model_name) { + return true; + } let cache = w.model_cache.lock().unwrap(); cache .get(model_name) @@ -117,8 +180,8 @@ impl GpuPool { /// Select the best worker for a model, using the placement strategy /// (checked in order): /// 1. Loaded and idle (model on GPU, no in-flight requests). - /// 2. Idle GPU with no model (spreads hot models across free GPUs). 
- /// 3. Loaded but busy — whichever loaded copy has the fewest in-flight. + /// 2. Loaded but busy — queue behind the warm copy instead of reloading. + /// 3. Idle GPU with no model (spreads cold loads across free GPUs). /// 4. Non-degraded worker with the most headroom (will evict LRU). pub fn select_worker(&self, model_name: &str, estimated_vram: u64) -> Option> { self.select_worker_excluding(model_name, estimated_vram, &[]) @@ -149,13 +212,19 @@ impl GpuPool { let mut other: Vec<&Arc> = Vec::new(); for w in &eligible { + let active_gen = w.active_generation.read().unwrap(); + let active_model = active_gen.as_ref().map(|g| g.model.as_str()); let (has_model, has_any_loaded) = { let cache = w.model_cache.lock().unwrap(); - let has_model = cache - .get(model_name) - .map(|e| e.residency == ModelResidency::Gpu) - .unwrap_or(false); - (has_model, cache.active_model().is_some()) + let has_model = active_model == Some(model_name) + || cache + .get(model_name) + .map(|e| e.residency == ModelResidency::Gpu) + .unwrap_or(false); + ( + has_model, + active_model.is_some() || cache.active_model().is_some(), + ) }; let in_flight = w.in_flight.load(Ordering::SeqCst); // During an in-flight generation the worker thread calls @@ -169,7 +238,7 @@ impl GpuPool { // and `active_generation.is_some()` (set by the worker around // the take-and-restore window) together cover every moment // between "about to pick up a job" and "just finished". - let is_busy = in_flight > 0 || w.active_generation.read().unwrap().is_some(); + let is_busy = in_flight > 0 || active_model.is_some(); if has_model && !is_busy { loaded_idle.push(w); @@ -188,7 +257,13 @@ impl GpuPool { return loaded_idle.first().map(|w| (*w).clone()); } - // 2. Idle GPU with no model — spread! Prefer smallest GPU that fits. + // 2. Loaded but busy — least in-flight wins. 
+ if !loaded_busy.is_empty() { + loaded_busy.sort_by_key(|w| w.in_flight.load(Ordering::SeqCst)); + return loaded_busy.first().map(|w| (*w).clone()); + } + + // 3. Idle GPU with no model — spread! Prefer smallest GPU that fits. if !idle_empty.is_empty() { idle_empty.sort_by_key(|w| w.gpu.total_vram_bytes); if let Some(w) = idle_empty @@ -201,12 +276,6 @@ impl GpuPool { return idle_empty.last().map(|w| (*w).clone()); } - // 3. Loaded but busy — least in-flight wins. - if !loaded_busy.is_empty() { - loaded_busy.sort_by_key(|w| w.in_flight.load(Ordering::SeqCst)); - return loaded_busy.first().map(|w| (*w).clone()); - } - // 4. All GPUs busy with other models — most headroom first (evict LRU there). let mut busy = other; busy.sort_by(|a, b| { @@ -228,10 +297,39 @@ impl GpuPool { } } +fn placement_gpu_ordinals(placement: &DevicePlacement) -> BTreeSet { + let mut ordinals = BTreeSet::new(); + collect_gpu_ordinal(placement.text_encoders, &mut ordinals); + if let Some(adv) = placement.advanced.as_ref() { + collect_gpu_ordinal(adv.transformer, &mut ordinals); + collect_gpu_ordinal(adv.vae, &mut ordinals); + if let Some(device) = adv.clip_l { + collect_gpu_ordinal(device, &mut ordinals); + } + if let Some(device) = adv.clip_g { + collect_gpu_ordinal(device, &mut ordinals); + } + if let Some(device) = adv.t5 { + collect_gpu_ordinal(device, &mut ordinals); + } + if let Some(device) = adv.qwen { + collect_gpu_ordinal(device, &mut ordinals); + } + } + ordinals +} + +fn collect_gpu_ordinal(device: DeviceRef, out: &mut BTreeSet) { + if let DeviceRef::Gpu { ordinal } = device { + out.insert(ordinal); + } +} + #[cfg(test)] mod tests { use super::*; use crate::model_cache::ModelCache; + use mold_core::types::AdvancedPlacement; use mold_inference::shared_pool::SharedPool; /// Build a test GpuWorker with a scratch job channel and everything else @@ -301,6 +399,8 @@ mod tests { *busy.active_generation.write().unwrap() = Some(ActiveGeneration { model: "big-model".to_string(), + 
prompt_sha256: String::new(), + started_at_unix_ms: 0, started_at: Instant::now(), }); @@ -346,4 +446,90 @@ mod tests { // Both busy → "most headroom" — the larger GPU wins. assert_eq!(picked.gpu.ordinal, 0); } + + #[test] + fn select_worker_keeps_queueing_behind_busy_warm_worker() { + let (warm_busy, _warm_busy_rx) = test_worker(0, 24_000_000_000); + let (cold_idle, _cold_idle_rx) = test_worker(1, 24_000_000_000); + + warm_busy.in_flight.store(1, Ordering::SeqCst); + *warm_busy.active_generation.write().unwrap() = Some(ActiveGeneration { + model: "flux-dev:q4".to_string(), + prompt_sha256: String::new(), + started_at_unix_ms: 0, + started_at: Instant::now(), + }); + + let pool = GpuPool { + workers: vec![warm_busy.clone(), cold_idle.clone()], + }; + + let picked = pool + .select_worker("flux-dev:q4", 6_000_000_000) + .expect("warm worker should be preferred"); + assert_eq!(picked.gpu.ordinal, 0); + } + + #[test] + fn resolve_explicit_placement_gpu_accepts_single_worker_ordinal() { + let (worker, _rx) = test_worker(1, 24_000_000_000); + let pool = GpuPool { + workers: vec![worker], + }; + let placement = DevicePlacement { + text_encoders: DeviceRef::Auto, + advanced: Some(AdvancedPlacement { + transformer: DeviceRef::gpu(1), + ..AdvancedPlacement::default() + }), + }; + + assert_eq!( + pool.resolve_explicit_placement_gpu(Some(&placement)) + .unwrap(), + Some(1) + ); + } + + #[test] + fn resolve_explicit_placement_gpu_rejects_cross_gpu_requests() { + let (worker0, _rx0) = test_worker(0, 24_000_000_000); + let (worker1, _rx1) = test_worker(1, 24_000_000_000); + let pool = GpuPool { + workers: vec![worker0, worker1], + }; + let placement = DevicePlacement { + text_encoders: DeviceRef::gpu(0), + advanced: Some(AdvancedPlacement { + transformer: DeviceRef::gpu(1), + ..AdvancedPlacement::default() + }), + }; + + let err = pool + .resolve_explicit_placement_gpu(Some(&placement)) + .unwrap_err(); + assert!(err.contains("one GPU ordinal per request"), "{err}"); + } + + 
#[test] + fn resolve_explicit_placement_gpu_rejects_ordinals_outside_pool() { + let (worker1, _rx1) = test_worker(1, 24_000_000_000); + let pool = GpuPool { + workers: vec![worker1], + }; + let placement = DevicePlacement { + text_encoders: DeviceRef::Auto, + advanced: Some(AdvancedPlacement { + transformer: DeviceRef::gpu(0), + ..AdvancedPlacement::default() + }), + }; + + let err = pool + .resolve_explicit_placement_gpu(Some(&placement)) + .unwrap_err(); + assert!(err.contains("gpu:0"), "{err}"); + assert!(err.contains("[1]"), "{err}"); + } } diff --git a/crates/mold-server/src/gpu_worker.rs b/crates/mold-server/src/gpu_worker.rs index a2a5fe72..b813c09c 100644 --- a/crates/mold-server/src/gpu_worker.rs +++ b/crates/mold-server/src/gpu_worker.rs @@ -8,9 +8,10 @@ use mold_core::{ Config, ImageData, ModelPaths, OutputFormat, OutputMetadata, SseErrorEvent, SseProgressEvent, }; use mold_inference::device; +use sha2::{Digest, Sha256}; use std::sync::atomic::Ordering; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; /// Spawn the dedicated OS thread for a GPU worker. /// Returns the JoinHandle (caller should keep it alive). @@ -21,6 +22,10 @@ pub fn spawn_gpu_thread( std::thread::Builder::new() .name(format!("gpu-worker-{}", worker.gpu.ordinal)) .spawn(move || { + // Bind this thread to its GPU ordinal so `create_device` / + // `reclaim_gpu_memory` can debug-assert callers don't drift onto + // a sibling GPU's context. See device::init_thread_gpu_ordinal. 
+ mold_inference::device::init_thread_gpu_ordinal(worker.gpu.ordinal); tracing::info!( gpu = worker.gpu.ordinal, name = %worker.gpu.name, @@ -54,6 +59,12 @@ fn process_job(worker: &GpuWorker, job: GpuJob) { } let _slot = QueueSlot(job.queue.clone()); + if job.result_tx.is_closed() { + tracing::debug!(gpu = ordinal, model = %model_name, "skipping dispatched job — client disconnected"); + worker.in_flight.fetch_sub(1, Ordering::SeqCst); + return; + } + tracing::info!(gpu = ordinal, model = %model_name, "dispatched job"); // Acquire per-GPU load lock — ensures only one model load at a time per GPU. @@ -80,10 +91,26 @@ fn process_job(worker: &GpuWorker, job: GpuJob) { let mut gen = worker.active_generation.write().unwrap(); *gen = Some(ActiveGeneration { model: model_name.clone(), + prompt_sha256: format!("{:x}", Sha256::digest(job.request.prompt.as_bytes())), + started_at_unix_ms: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64, started_at: Instant::now(), }); } + if job.result_tx.is_closed() { + tracing::debug!( + gpu = ordinal, + model = %model_name, + "skipping generation after model readiness — client disconnected" + ); + worker.in_flight.fetch_sub(1, Ordering::SeqCst); + clear_active_generation(worker); + return; + } + // Take-and-restore: remove engine from cache, release lock during inference. 
let taken = { let mut cache = worker.model_cache.lock().unwrap(); diff --git a/crates/mold-server/src/lib.rs b/crates/mold-server/src/lib.rs index 97ea6f42..b6caea40 100644 --- a/crates/mold-server/src/lib.rs +++ b/crates/mold-server/src/lib.rs @@ -13,6 +13,7 @@ pub mod rate_limit; pub mod request_id; pub mod resources; pub mod routes; +pub mod routes_chain; pub mod state; pub mod web_ui; @@ -115,12 +116,12 @@ pub async fn run_server( } // ── Create generation queue ──────────────────────────────────────────── - let (job_tx, job_rx) = tokio::sync::mpsc::channel(16); + let (job_tx, job_rx) = tokio::sync::mpsc::channel(queue_size.max(1)); let queue_handle = QueueHandle::new(job_tx); // ── Create AppState ──────────────────────────────────────────────────── - let mut state = match ModelPaths::resolve(&model_name, &config) { - Some(paths) => { + let mut state = if gpu_pool.worker_count() > 0 { + if let Some(paths) = ModelPaths::resolve(&model_name, &config) { info!(model = %model_name, "configured model"); info!(transformer = %paths.transformer.display()); info!(vae = %paths.vae.display()); @@ -151,25 +152,71 @@ pub async fn run_server( if let Some(text_tok) = &paths.text_tokenizer { info!(text_tok = %text_tok.display()); } - - let offload = std::env::var("MOLD_OFFLOAD").is_ok_and(|v| v == "1"); - let engine = mold_inference::create_engine_with_pool( - model_name, - paths, - &config, - mold_inference::LoadStrategy::Eager, - 0, - offload, - Some(shared_pool.clone()), - )?; - let mut state = - state::AppState::new(engine, config, queue_handle, gpu_pool.clone(), queue_size); - state.shared_pool = shared_pool; - state - } - None => { + info!("multi-GPU mode defers model loading to per-GPU workers"); + } else { info!("no default model configured — models will be pulled on first request"); - state::AppState::empty(config, queue_handle, gpu_pool.clone(), queue_size) + } + let mut state = state::AppState::empty(config, queue_handle, gpu_pool.clone(), queue_size); + 
state.shared_pool = shared_pool; + state + } else { + match ModelPaths::resolve(&model_name, &config) { + Some(paths) => { + info!(model = %model_name, "configured model"); + info!(transformer = %paths.transformer.display()); + info!(vae = %paths.vae.display()); + if let Some(spatial_upscaler) = &paths.spatial_upscaler { + info!(spatial_upscaler = %spatial_upscaler.display()); + } + if let Some(t5) = &paths.t5_encoder { + info!(t5 = %t5.display()); + } + if let Some(clip) = &paths.clip_encoder { + info!(clip = %clip.display()); + } + if let Some(t5_tok) = &paths.t5_tokenizer { + info!(t5_tok = %t5_tok.display()); + } + if let Some(clip_tok) = &paths.clip_tokenizer { + info!(clip_tok = %clip_tok.display()); + } + if let Some(clip2) = &paths.clip_encoder_2 { + info!(clip2 = %clip2.display()); + } + if let Some(clip2_tok) = &paths.clip_tokenizer_2 { + info!(clip2_tok = %clip2_tok.display()); + } + for (i, te) in paths.text_encoder_files.iter().enumerate() { + info!(text_encoder_shard = i, path = %te.display()); + } + if let Some(text_tok) = &paths.text_tokenizer { + info!(text_tok = %text_tok.display()); + } + + let offload = std::env::var("MOLD_OFFLOAD").is_ok_and(|v| v == "1"); + let engine = mold_inference::create_engine_with_pool( + model_name, + paths, + &config, + mold_inference::LoadStrategy::Eager, + 0, + offload, + Some(shared_pool.clone()), + )?; + let mut state = state::AppState::new( + engine, + config, + queue_handle, + gpu_pool.clone(), + queue_size, + ); + state.shared_pool = shared_pool; + state + } + None => { + info!("no default model configured — models will be pulled on first request"); + state::AppState::empty(config, queue_handle, gpu_pool.clone(), queue_size) + } } }; diff --git a/crates/mold-server/src/queue.rs b/crates/mold-server/src/queue.rs index 56655114..01971a3d 100644 --- a/crates/mold-server/src/queue.rs +++ b/crates/mold-server/src/queue.rs @@ -574,6 +574,33 @@ pub async fn run_queue_dispatcher( let model_name = 
job.request.model.clone(); let estimated_vram = estimate_model_vram(&model_name); + let preferred_gpu = match state + .gpu_pool + .resolve_explicit_placement_gpu(job.request.placement.as_ref()) + { + Ok(ordinal) => ordinal, + Err(err_msg) => { + tracing::warn!(model = %model_name, "{err_msg}"); + if let Some(tx) = job.progress_tx { + let _ = tx.send(SseMessage::Error(SseErrorEvent { + message: err_msg.clone(), + })); + } + let _ = job.result_tx.send(Err(err_msg)); + state.queue.decrement(); + #[cfg(feature = "metrics")] + crate::metrics::record_queue_depth(state.queue.pending()); + continue; + } + }; + + if job.result_tx.is_closed() { + tracing::debug!(model = %model_name, "skipping queued multi-GPU job — client disconnected"); + state.queue.decrement(); + #[cfg(feature = "metrics")] + crate::metrics::record_queue_depth(state.queue.pending()); + continue; + } // Build the GpuJob once; the retry loop moves it between attempts. let mut gpu_job = Some(GpuJob { @@ -588,18 +615,50 @@ pub async fn run_queue_dispatcher( }); let mut skip: Vec = Vec::new(); - let max_attempts = state.gpu_pool.worker_count().max(1); let mut dispatched = false; - for _ in 0..max_attempts { - let worker = - match state + while !dispatched { + if gpu_job + .as_ref() + .is_some_and(|pending| pending.result_tx.is_closed()) + { + tracing::debug!( + model = %model_name, + "dropping queued multi-GPU job before dispatch — client disconnected" + ); + state.queue.decrement(); + break; + } + + let worker = if let Some(ordinal) = preferred_gpu { + state.gpu_pool.worker_by_ordinal(ordinal) + } else { + state .gpu_pool .select_worker_excluding(&model_name, estimated_vram, &skip) - { - Some(w) => w, - None => break, + }; + + let Some(worker) = worker else { + let rejected = gpu_job + .take() + .expect("gpu_job retained after failed dispatch"); + let err_msg = if state.gpu_pool.worker_count() == 0 { + format!("no GPU available for model {model_name}") + } else if let Some(ordinal) = preferred_gpu { + 
format!("gpu:{ordinal} is not available for model {model_name}") + } else { + format!("no GPU worker available for model {model_name}") }; + tracing::error!(model = %model_name, "{err_msg}"); + if let Some(tx) = rejected.progress_tx { + let _ = tx.send(SseMessage::Error(SseErrorEvent { + message: err_msg.clone(), + })); + } + let _ = rejected.result_tx.send(Err(err_msg)); + state.queue.decrement(); + break; + }; // Increment in-flight BEFORE sending to reserve the slot. worker.in_flight.fetch_add(1, Ordering::SeqCst); @@ -607,41 +666,47 @@ pub async fn run_queue_dispatcher( match worker.job_tx.try_send(pending) { Ok(()) => { dispatched = true; - break; } - Err(std::sync::mpsc::TrySendError::Full(j)) - | Err(std::sync::mpsc::TrySendError::Disconnected(j)) => { + Err(std::sync::mpsc::TrySendError::Full(j)) => { + worker.in_flight.fetch_sub(1, Ordering::SeqCst); + gpu_job = Some(j); + if preferred_gpu.is_none() { + skip.push(worker.gpu.ordinal); + if skip.len() >= state.gpu_pool.worker_count().max(1) { + skip.clear(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } else { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + Err(std::sync::mpsc::TrySendError::Disconnected(j)) => { worker.in_flight.fetch_sub(1, Ordering::SeqCst); tracing::warn!( gpu = worker.gpu.ordinal, - "GPU worker channel full — retrying on another worker" + "GPU worker disconnected — retrying dispatch" ); - skip.push(worker.gpu.ordinal); gpu_job = Some(j); + if preferred_gpu.is_none() { + skip.push(worker.gpu.ordinal); + } else { + let rejected = gpu_job.take().expect("gpu_job retained after disconnect"); + let err_msg = format!( + "gpu:{} disconnected while dispatching model {model_name}", + worker.gpu.ordinal + ); + if let Some(tx) = rejected.progress_tx { + let _ = tx.send(SseMessage::Error(SseErrorEvent { + message: err_msg.clone(), + })); + } + let _ = rejected.result_tx.send(Err(err_msg)); + state.queue.decrement(); + break; + } } } } - - if 
!dispatched { - // Either no workers are eligible or every candidate's channel is full. - let rejected = gpu_job.expect("gpu_job retained after failed dispatch"); - let err_msg = if state.gpu_pool.worker_count() == 0 { - format!("no GPU available for model {model_name}") - } else { - format!("all GPU workers are busy for model {model_name} — queue is full") - }; - tracing::error!(model = %model_name, "{err_msg}"); - if let Some(tx) = rejected.progress_tx { - let _ = tx.send(SseMessage::Error(SseErrorEvent { - message: err_msg.clone(), - })); - } - let _ = rejected.result_tx.send(Err(err_msg)); - // Job was rejected before the worker could observe it, so we must - // release the global queue slot here — the worker-side decrement - // won't run. - state.queue.decrement(); - } #[cfg(feature = "metrics")] crate::metrics::record_queue_depth(state.queue.pending()); } @@ -670,8 +735,15 @@ pub fn estimate_model_vram(model_name: &str) -> u64 { #[cfg(test)] mod tests { use super::*; + use crate::gpu_pool::{GpuPool, GpuWorker}; + use crate::model_cache::ModelCache; + use crate::state::QueueHandle; use mold_core::{GenerateRequest, ImageData, OutputFormat}; use mold_db::MetadataDb; + use mold_inference::device::DiscoveredGpu; + use mold_inference::shared_pool::SharedPool; + use std::sync::atomic::AtomicUsize; + use std::sync::{Arc, Mutex, RwLock}; use tempfile::TempDir; /// A `GenerateRequest` with the bare minimum fields populated — enough to @@ -729,6 +801,33 @@ mod tests { } } + fn test_worker( + ordinal: usize, + channel_size: usize, + ) -> ( + Arc, + std::sync::mpsc::Receiver, + ) { + let (job_tx, job_rx) = std::sync::mpsc::sync_channel(channel_size); + let worker = Arc::new(GpuWorker { + gpu: DiscoveredGpu { + ordinal, + name: format!("gpu{ordinal}"), + total_vram_bytes: 24_000_000_000, + free_vram_bytes: 24_000_000_000, + }, + model_cache: Arc::new(Mutex::new(ModelCache::new(3))), + active_generation: Arc::new(RwLock::new(None)), + model_load_lock: 
Arc::new(Mutex::new(())), + shared_pool: Arc::new(Mutex::new(SharedPool::new())), + in_flight: AtomicUsize::new(0), + consecutive_failures: AtomicUsize::new(0), + degraded_until: RwLock::new(None), + job_tx, + }); + (worker, job_rx) + } + #[test] fn save_image_to_dir_writes_file_and_creates_missing_dir() { let tmp = TempDir::new().unwrap(); @@ -1050,4 +1149,105 @@ mod tests { assert!(!event.video_has_audio); assert!(event.video_duration_ms.is_none()); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn queue_dispatcher_waits_for_worker_capacity_instead_of_rejecting() { + let (worker, worker_rx) = test_worker(0, 1); + let (job_tx, job_rx) = tokio::sync::mpsc::channel(4); + let queue = QueueHandle::new(job_tx.clone()); + let state = crate::state::AppState::empty( + mold_core::Config::default(), + queue.clone(), + Arc::new(GpuPool { + workers: vec![worker.clone()], + }), + 8, + ); + + let (filler_result_tx, _filler_result_rx) = tokio::sync::oneshot::channel(); + let filler_job = crate::gpu_pool::GpuJob { + model: "busy-model".to_string(), + request: fake_request("busy-model"), + progress_tx: None, + result_tx: filler_result_tx, + output_dir: None, + config: state.config.clone(), + metadata_db: state.metadata_db.clone(), + queue: state.queue.clone(), + }; + worker.job_tx.send(filler_job).unwrap(); + + let dispatcher = tokio::spawn(run_queue_dispatcher(job_rx, state.clone())); + + let (result_tx, mut result_rx) = tokio::sync::oneshot::channel(); + let job = crate::state::GenerationJob { + request: fake_request("flux-dev:q4"), + progress_tx: None, + result_tx, + output_dir: None, + }; + let _position = queue.submit(job, 8).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + assert!( + result_rx.try_recv().is_err(), + "dispatcher should keep the job pending while all worker channels are full" + ); + + let _filler = worker_rx + .recv() + .expect("filler job should occupy the local channel"); + let dispatched = 
worker_rx + .recv_timeout(std::time::Duration::from_secs(1)) + .expect("queued job should dispatch once capacity is available"); + assert_eq!(dispatched.model, "flux-dev:q4"); + + drop(job_tx); + dispatcher.abort(); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn queue_dispatcher_honors_explicit_placement_gpu() { + let (worker0, rx0) = test_worker(0, 1); + let (worker1, rx1) = test_worker(1, 1); + let (job_tx, job_rx) = tokio::sync::mpsc::channel(4); + let queue = QueueHandle::new(job_tx.clone()); + let state = crate::state::AppState::empty( + mold_core::Config::default(), + queue.clone(), + Arc::new(GpuPool { + workers: vec![worker0, worker1], + }), + 8, + ); + + let dispatcher = tokio::spawn(run_queue_dispatcher(job_rx, state)); + + let mut request = fake_request("flux-dev:q4"); + request.placement = Some(mold_core::types::DevicePlacement { + text_encoders: mold_core::types::DeviceRef::Auto, + advanced: Some(mold_core::types::AdvancedPlacement { + transformer: mold_core::types::DeviceRef::gpu(1), + ..mold_core::types::AdvancedPlacement::default() + }), + }); + + let (result_tx, _result_rx) = tokio::sync::oneshot::channel(); + let job = crate::state::GenerationJob { + request, + progress_tx: None, + result_tx, + output_dir: None, + }; + let _position = queue.submit(job, 8).await.unwrap(); + + let dispatched = rx1 + .recv_timeout(std::time::Duration::from_secs(1)) + .expect("explicit placement should route to gpu 1"); + assert_eq!(dispatched.model, "flux-dev:q4"); + assert!(rx0.try_recv().is_err(), "gpu 0 should not receive the job"); + + drop(job_tx); + dispatcher.abort(); + } } diff --git a/crates/mold-server/src/resources.rs b/crates/mold-server/src/resources.rs index d764e59b..50e33390 100644 --- a/crates/mold-server/src/resources.rs +++ b/crates/mold-server/src/resources.rs @@ -113,6 +113,9 @@ pub(crate) mod nvml_source { .sum::() }); let used_by_other = used_by_mold.map(|m| mem.used.saturating_sub(m)); + // NVML's GPU-core 
utilization over the last sample period. + // Cheap — this is just a driver query, not a counter reset. + let gpu_util = dev.utilization_rates().ok().map(|u| u.gpu.min(100) as u8); out.push(GpuSnapshot { ordinal: ordinal as usize, name, @@ -121,6 +124,7 @@ pub(crate) mod nvml_source { vram_used: mem.used, vram_used_by_mold: used_by_mold, vram_used_by_other: used_by_other, + gpu_utilization: gpu_util, }); } out @@ -160,8 +164,8 @@ pub fn parse_nvidia_smi_line(line: &str) -> Option<(usize, String, u64, u64)> { Some((ordinal, name, total_mb * 1_000_000, used_mb * 1_000_000)) } -use mold_core::RamSnapshot; -use sysinfo::{Pid, ProcessRefreshKind, RefreshKind, System}; +use mold_core::{CpuSnapshot, RamSnapshot}; +use sysinfo::{CpuRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System}; /// Metal unified-memory snapshot — macOS only. Off-Darwin returns an empty /// Vec so callers on Linux/CUDA hosts can unconditionally call this. @@ -188,6 +192,7 @@ pub fn metal_snapshot() -> Vec { vram_used: used, vram_used_by_mold: None, vram_used_by_other: None, + gpu_utilization: None, }] } #[cfg(not(target_os = "macos"))] @@ -265,6 +270,7 @@ impl SmiSource { vram_used: used, vram_used_by_mold: None, vram_used_by_other: None, + gpu_utilization: None, }) }) .collect() @@ -276,7 +282,15 @@ impl SmiSource { /// /// Source priority on CUDA: NVML (if linked) → `nvidia-smi` subprocess → empty. /// On macOS: `metal_snapshot()`. +/// +/// CPU utilization is `None` — call `build_snapshot_with_cpu` with a +/// persistent `System` to populate it (sysinfo computes CPU usage from +/// deltas between refreshes, so the aggregator needs to hold state). 
pub fn build_snapshot() -> ResourceSnapshot { + build_snapshot_inner(None) +} + +fn build_snapshot_inner(cpu: Option) -> ResourceSnapshot { let hostname = hostname::get() .ok() .and_then(|h| h.into_string().ok()) @@ -294,6 +308,40 @@ pub fn build_snapshot() -> ResourceSnapshot { timestamp, gpus, system_ram, + cpu, + } +} + +/// Holds the persistent `System` sysinfo needs for CPU delta computation. +pub struct CpuSampler { + sys: System, + cores: u16, +} + +impl CpuSampler { + pub fn new() -> Self { + let mut sys = System::new_with_specifics( + RefreshKind::nothing().with_cpu(CpuRefreshKind::everything().with_cpu_usage()), + ); + // Prime the sampler. The first `global_cpu_usage()` read always + // returns 0 — the real number shows up on the second refresh. + sys.refresh_cpu_usage(); + let cores = sys.cpus().len().min(u16::MAX as usize) as u16; + Self { sys, cores } + } + + pub fn sample(&mut self) -> CpuSnapshot { + self.sys.refresh_cpu_usage(); + CpuSnapshot { + cores: self.cores, + usage_percent: self.sys.global_cpu_usage().clamp(0.0, 100.0), + } + } +} + +impl Default for CpuSampler { + fn default() -> Self { + Self::new() } } @@ -326,27 +374,45 @@ fn collect_gpus() -> Vec { pub fn spawn_aggregator(bcast: Arc) -> JoinHandle<()> { tokio::spawn(async move { // Immediate first tick so `latest()` is populated before any HTTP - // request arrives. - bcast.publish(build_snapshot()); + // request arrives. CPU usage is None on this first sample (no delta + // to compute against yet). + bcast.publish(build_snapshot_inner(None)); let mut interval = tokio::time::interval(Duration::from_secs(1)); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); // Consume the first tick (it fires immediately) so we don't double-emit. interval.tick().await; + + // The sampler lives on the blocking thread across ticks — sysinfo + // computes CPU usage from deltas, so we can't rebuild it every tick. 
+ let mut sampler: Option = None; loop { interval.tick().await; - let snap = tokio::task::spawn_blocking(build_snapshot) - .await - .unwrap_or_else(|_| ResourceSnapshot { - hostname: "unknown".to_string(), - timestamp: 0, - gpus: Vec::new(), - system_ram: mold_core::RamSnapshot { - total: 0, - used: 0, - used_by_mold: 0, - used_by_other: 0, + let taken = sampler.take(); + let (snap, returned) = tokio::task::spawn_blocking(move || { + let mut s = taken.unwrap_or_default(); + let cpu = s.sample(); + let snap = build_snapshot_inner(Some(cpu)); + (snap, s) + }) + .await + .unwrap_or_else(|_| { + ( + ResourceSnapshot { + hostname: "unknown".to_string(), + timestamp: 0, + gpus: Vec::new(), + system_ram: mold_core::RamSnapshot { + total: 0, + used: 0, + used_by_mold: 0, + used_by_other: 0, + }, + cpu: None, }, - }); + CpuSampler::new(), + ) + }); + sampler = Some(returned); bcast.publish(snap); } }) diff --git a/crates/mold-server/src/resources_test.rs b/crates/mold-server/src/resources_test.rs index d6bff40a..2e34ba22 100644 --- a/crates/mold-server/src/resources_test.rs +++ b/crates/mold-server/src/resources_test.rs @@ -15,6 +15,7 @@ fn fake_snapshot() -> ResourceSnapshot { vram_used: 0, vram_used_by_mold: Some(0), vram_used_by_other: Some(0), + gpu_utilization: None, }], system_ram: RamSnapshot { total: 64_000_000_000, @@ -22,6 +23,7 @@ fn fake_snapshot() -> ResourceSnapshot { used_by_mold: 0, used_by_other: 0, }, + cpu: None, } } diff --git a/crates/mold-server/src/routes.rs b/crates/mold-server/src/routes.rs index 4ea11cd8..cfaec861 100644 --- a/crates/mold-server/src/routes.rs +++ b/crates/mold-server/src/routes.rs @@ -10,11 +10,13 @@ use axum::{ }; use base64::Engine as _; use mold_core::{ - ActiveGenerationStatus, GpuInfo, GpuWorkerState, ModelInfoExtended, ResourceSnapshot, - ServerStatus, SseErrorEvent, SseProgressEvent, + types::GpuSelection, ActiveGenerationStatus, GpuInfo, GpuWorkerState, ModelInfoExtended, + ResourceSnapshot, ServerStatus, SseErrorEvent, 
SseProgressEvent, }; use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; use std::convert::Infallible; +use std::sync::atomic::Ordering; use tokio_stream::StreamExt as _; use utoipa::OpenApi; @@ -133,7 +135,19 @@ use crate::queue::clean_error_message; #[derive(OpenApi)] #[openapi( - paths(generate, generate_stream, expand_prompt, list_models, load_model, pull_model_endpoint, unload_model, server_status, health), + paths( + generate, + generate_stream, + expand_prompt, + list_models, + load_model, + pull_model_endpoint, + unload_model, + server_status, + health, + crate::routes_chain::generate_chain, + crate::routes_chain::generate_chain_stream, + ), components(schemas( mold_core::GenerateRequest, mold_core::GenerateResponse, @@ -148,6 +162,11 @@ use crate::queue::clean_error_message; mold_core::SseProgressEvent, mold_core::SseCompleteEvent, mold_core::SseErrorEvent, + mold_core::ChainRequest, + mold_core::ChainResponse, + mold_core::ChainStage, + mold_core::ChainProgressEvent, + mold_core::SseChainCompleteEvent, ModelInfoExtended, LoadModelBody, UnloadRequest, @@ -171,6 +190,14 @@ pub fn create_router(state: AppState) -> Router { Router::new() .route("/api/generate", post(generate)) .route("/api/generate/stream", post(generate_stream)) + .route( + "/api/generate/chain", + post(crate::routes_chain::generate_chain), + ) + .route( + "/api/generate/chain/stream", + post(crate::routes_chain::generate_chain_stream), + ) .route("/api/expand", post(expand_prompt)) .route("/api/models", get(list_models)) .route("/api/models/load", post(load_model)) @@ -275,8 +302,10 @@ async fn prepare_generation( // return `SubmitError::Full`, which is mapped to `ApiError::queue_full()`. 
apply_default_metadata_setting(state, request).await; + let preferred_gpu = validate_multi_gpu_placement(state, request.placement.as_ref())?; + // Expand prompt if requested (before validation, so the expanded prompt gets validated) - maybe_expand_prompt(state, request).await?; + maybe_expand_prompt(state, request, preferred_gpu).await?; if let Err(e) = validate_generate_request(request) { return Err(ApiError::validation(e)); @@ -301,6 +330,61 @@ async fn prepare_generation( Ok((output_dir, dim_warning)) } +fn active_gpu_selection(state: &AppState) -> GpuSelection { + let ordinals: Vec = state + .gpu_pool + .workers + .iter() + .map(|w| w.gpu.ordinal) + .collect(); + if ordinals.is_empty() { + GpuSelection::All + } else { + GpuSelection::Specific(ordinals) + } +} + +fn validate_multi_gpu_placement( + state: &AppState, + placement: Option<&mold_core::types::DevicePlacement>, +) -> Result, ApiError> { + state + .gpu_pool + .resolve_explicit_placement_gpu(placement) + .map_err(ApiError::validation) +} + +fn select_aux_worker( + state: &AppState, +) -> Result, ApiError> { + let mut workers: Vec<_> = state + .gpu_pool + .workers + .iter() + .filter(|w| !w.is_degraded()) + .cloned() + .collect(); + workers.sort_by_key(|w| { + ( + w.in_flight.load(Ordering::SeqCst), + Reverse(w.gpu.total_vram_bytes), + ) + }); + workers + .into_iter() + .next() + .ok_or_else(|| ApiError::internal("no GPU worker available for auxiliary workload")) +} + +fn clear_global_upscaler_cache(state: &AppState) { + if let Ok(mut cache) = state.upscaler_cache.try_lock() { + if cache.is_some() { + *cache = None; + tracing::info!("upscaler cache cleared"); + } + } +} + // ── /api/generate ───────────────────────────────────────────────────────────── #[utoipa::path( @@ -460,12 +544,14 @@ async fn apply_default_metadata_setting(state: &AppState, req: &mut mold_core::G async fn maybe_expand_prompt( state: &AppState, req: &mut mold_core::GenerateRequest, + preferred_gpu: Option, ) -> Result<(), ApiError> { 
if req.expand != Some(true) { return Ok(()); } let config = state.config.read().await; + let config_snapshot = config.clone(); let expand_settings = config.expand.clone().with_env_overrides(); // Resolve model family for prompt style @@ -487,7 +573,12 @@ async fn maybe_expand_prompt( // Drop config lock before blocking drop(config); - let expander = create_server_expander(&expand_settings)?; + let expander = create_server_expander( + &config_snapshot, + &expand_settings, + active_gpu_selection(state), + preferred_gpu, + )?; let result = tokio::task::spawn_blocking(move || expander.expand(&original_prompt, &expand_config)) .await @@ -504,7 +595,10 @@ async fn maybe_expand_prompt( /// Create the appropriate expander for server-side use. fn create_server_expander( + _config: &mold_core::Config, settings: &mold_core::ExpandSettings, + _gpu_selection: GpuSelection, + _preferred_gpu: Option, ) -> Result, ApiError> { if let Some(api_expander) = settings.create_api_expander() { return Ok(Box::new(api_expander)); @@ -512,11 +606,14 @@ fn create_server_expander( #[cfg(feature = "expand")] { - let config = mold_core::Config::load_or_default(); if let Some(local) = - mold_inference::expand::LocalExpander::from_config(&config, Some(&settings.model)) + mold_inference::expand::LocalExpander::from_config(_config, Some(&settings.model)) { - return Ok(Box::new(local)); + return Ok(Box::new( + local + .with_gpu_selection(_gpu_selection) + .with_preferred_gpu(_preferred_gpu), + )); } return Err(ApiError::validation( "local expand model not found — run: mold pull qwen3-expand".to_string(), @@ -561,9 +658,15 @@ async fn expand_prompt( let expand_settings = config.expand.clone().with_env_overrides(); let expand_config = expand_settings.to_expand_config(&req.model_family, req.variations); let prompt = req.prompt.clone(); + let config_snapshot = config.clone(); drop(config); - let expander = create_server_expander(&expand_settings)?; + let expander = create_server_expander( + 
&config_snapshot, + &expand_settings, + active_gpu_selection(&state), + None, + )?; let result = tokio::task::spawn_blocking(move || expander.expand(&prompt, &expand_config)) .await .map_err(|e| ApiError::internal(format!("expand task failed: {e}")))? @@ -621,12 +724,40 @@ async fn upscale( let model_name_owned = model_name.clone(); drop(config); - let upscaler_cache = state.upscaler_cache.clone(); - let resp = + let resp = if state.gpu_pool.worker_count() > 0 { + let worker = select_aux_worker(&state)?; + worker.in_flight.fetch_add(1, Ordering::SeqCst); + let worker_clone = worker.clone(); + let result = + tokio::task::spawn_blocking(move || -> anyhow::Result { + struct ThreadGpuGuard; + impl Drop for ThreadGpuGuard { + fn drop(&mut self) { + mold_inference::device::clear_thread_gpu_ordinal(); + } + } + + mold_inference::device::init_thread_gpu_ordinal(worker_clone.gpu.ordinal); + let _thread_gpu = ThreadGpuGuard; + let _load_lock = worker_clone.model_load_lock.lock().unwrap(); + let mut engine = mold_inference::create_upscale_engine( + model_name_owned, + weights_path, + mold_inference::LoadStrategy::Eager, + worker_clone.gpu.ordinal, + )?; + engine.upscale(&req) + }) + .await + .map_err(|e| ApiError::internal(format!("upscale task panicked: {e}"))); + worker.in_flight.fetch_sub(1, Ordering::SeqCst); + result?.map_err(|e| ApiError::internal(format!("upscale failed: {e}")))? + } else { + let upscaler_cache = state.upscaler_cache.clone(); tokio::task::spawn_blocking(move || -> anyhow::Result { let mut cache = upscaler_cache.lock().unwrap_or_else(|e| e.into_inner()); - // Reuse cached engine if same model + // Reuse cached engine if same model. 
let needs_new = cache .as_ref() .is_none_or(|e| e.model_name() != model_name_owned); @@ -635,6 +766,7 @@ async fn upscale( model_name_owned, weights_path, mold_inference::LoadStrategy::Eager, + 0, )?; *cache = Some(new_engine); } @@ -643,7 +775,8 @@ async fn upscale( }) .await .map_err(|e| ApiError::internal(format!("upscale task panicked: {e}")))? - .map_err(|e| ApiError::internal(format!("upscale failed: {e}")))?; + .map_err(|e| ApiError::internal(format!("upscale failed: {e}")))? + }; Ok(Json(resp)) } @@ -783,70 +916,163 @@ async fn upscale_stream( return; }; - let result = tokio::task::spawn_blocking(move || { - let mut cache = upscaler_cache.lock().unwrap(); + let result = if state_clone.gpu_pool.worker_count() > 0 { + match select_aux_worker(&state_clone) { + Ok(worker) => { + worker.in_flight.fetch_add(1, Ordering::SeqCst); + let worker_clone = worker.clone(); + let tx_for_worker = tx.clone(); + let model_name_for_worker = model_name_owned.clone(); + let weights_path_for_worker = weights_path.clone(); + let req_for_worker = req.clone(); + let result = tokio::task::spawn_blocking(move || { + struct ThreadGpuGuard; + impl Drop for ThreadGpuGuard { + fn drop(&mut self) { + mold_inference::device::clear_thread_gpu_ordinal(); + } + } - let needs_new = cache - .as_ref() - .is_none_or(|e| e.model_name() != model_name_owned); - if needs_new { - let _ = tx.send(SseMessage::Progress( - mold_core::SseProgressEvent::StageStart { - name: "Loading upscaler model".to_string(), - }, - )); - match mold_inference::create_upscale_engine( - model_name_owned, - weights_path, - mold_inference::LoadStrategy::Eager, - ) { - Ok(new_engine) => { - *cache = Some(new_engine); - } - Err(e) => { - let _ = tx.send(SseMessage::Error(mold_core::SseErrorEvent { - message: format!("failed to load upscaler: {e}"), + mold_inference::device::init_thread_gpu_ordinal(worker_clone.gpu.ordinal); + let _thread_gpu = ThreadGpuGuard; + let _load_lock = worker_clone.model_load_lock.lock().unwrap(); + 
let _ = tx_for_worker.send(SseMessage::Progress( + mold_core::SseProgressEvent::StageStart { + name: format!( + "Loading upscaler model on GPU {}", + worker_clone.gpu.ordinal + ), + }, + )); + let mut engine = match mold_inference::create_upscale_engine( + model_name_for_worker, + weights_path_for_worker, + mold_inference::LoadStrategy::Eager, + worker_clone.gpu.ordinal, + ) { + Ok(engine) => engine, + Err(e) => { + let _ = tx_for_worker.send(SseMessage::Error( + mold_core::SseErrorEvent { + message: format!("failed to load upscaler: {e}"), + }, + )); + return; + } + }; + + let tx_progress = tx_for_worker.clone(); + engine.set_on_progress(Box::new(move |event| { + let sse_event: mold_core::SseProgressEvent = event.into(); + let _ = tx_progress.send(SseMessage::Progress(sse_event)); })); - return; - } - } - } - let engine = cache.as_mut().unwrap(); - - // Install progress callback for tile-by-tile progress - let tx_progress = tx.clone(); - engine.set_on_progress(Box::new(move |event| { - let sse_event: mold_core::SseProgressEvent = event.into(); - let _ = tx_progress.send(SseMessage::Progress(sse_event)); - })); + match engine.upscale(&req_for_worker) { + Ok(resp) => { + let image_b64 = base64::engine::general_purpose::STANDARD + .encode(&resp.image.data); + let _ = tx_for_worker.send(SseMessage::UpscaleComplete( + mold_core::SseUpscaleCompleteEvent { + image: image_b64, + format: resp.image.format, + model: resp.model, + scale_factor: resp.scale_factor, + original_width: resp.original_width, + original_height: resp.original_height, + upscale_time_ms: resp.upscale_time_ms, + }, + )); + } + Err(e) => { + let _ = tx_for_worker.send(SseMessage::Error( + mold_core::SseErrorEvent { + message: format!("upscale failed: {e}"), + }, + )); + } + } - match engine.upscale(&req) { - Ok(resp) => { - let image_b64 = - base64::engine::general_purpose::STANDARD.encode(&resp.image.data); - let _ = tx.send(SseMessage::UpscaleComplete( - mold_core::SseUpscaleCompleteEvent { - image: 
image_b64, - format: resp.image.format, - model: resp.model, - scale_factor: resp.scale_factor, - original_width: resp.original_width, - original_height: resp.original_height, - upscale_time_ms: resp.upscale_time_ms, - }, - )); + engine.clear_on_progress(); + }) + .await; + worker.in_flight.fetch_sub(1, Ordering::SeqCst); + result } Err(e) => { let _ = tx.send(SseMessage::Error(mold_core::SseErrorEvent { - message: format!("upscale failed: {e}"), + message: e.error, })); + return; } } + } else { + let model_name_for_cache = model_name_owned.clone(); + let weights_path_for_cache = weights_path.clone(); + let req_for_cache = req.clone(); + tokio::task::spawn_blocking(move || { + let mut cache = upscaler_cache.lock().unwrap(); + + let needs_new = cache + .as_ref() + .is_none_or(|e| e.model_name() != model_name_for_cache); + if needs_new { + let _ = tx.send(SseMessage::Progress( + mold_core::SseProgressEvent::StageStart { + name: "Loading upscaler model".to_string(), + }, + )); + match mold_inference::create_upscale_engine( + model_name_for_cache, + weights_path_for_cache, + mold_inference::LoadStrategy::Eager, + 0, + ) { + Ok(new_engine) => { + *cache = Some(new_engine); + } + Err(e) => { + let _ = tx.send(SseMessage::Error(mold_core::SseErrorEvent { + message: format!("failed to load upscaler: {e}"), + })); + return; + } + } + } - engine.clear_on_progress(); - }) - .await; + let engine = cache.as_mut().unwrap(); + let tx_progress = tx.clone(); + engine.set_on_progress(Box::new(move |event| { + let sse_event: mold_core::SseProgressEvent = event.into(); + let _ = tx_progress.send(SseMessage::Progress(sse_event)); + })); + + match engine.upscale(&req_for_cache) { + Ok(resp) => { + let image_b64 = + base64::engine::general_purpose::STANDARD.encode(&resp.image.data); + let _ = tx.send(SseMessage::UpscaleComplete( + mold_core::SseUpscaleCompleteEvent { + image: image_b64, + format: resp.image.format, + model: resp.model, + scale_factor: resp.scale_factor, + original_width: 
resp.original_width, + original_height: resp.original_height, + upscale_time_ms: resp.upscale_time_ms, + }, + )); + } + Err(e) => { + let _ = tx.send(SseMessage::Error(mold_core::SseErrorEvent { + message: format!("upscale failed: {e}"), + })); + } + } + + engine.clear_on_progress(); + }) + .await + }; if let Err(e) = result { tracing::error!("upscale task panicked: {e}"); @@ -1217,6 +1443,7 @@ async fn unload_model( ) -> Result { let req = body.map(|b| b.0).unwrap_or_default(); tracing::debug!(model = ?req.model, gpu = ?req.gpu, "unload request"); + clear_global_upscaler_cache(&state); // Multi-GPU path: target specific GPU or model across the pool. if state.gpu_pool.worker_count() > 0 { @@ -1309,11 +1536,8 @@ async fn server_status(State(state): State) -> Json { let gen = w.active_generation.read().ok()?; gen.as_ref().map(|g| ActiveGenerationStatus { model: g.model.clone(), - // The per-worker ActiveGeneration doesn't carry the prompt hash, - // so expose the model-only summary. Callers that need the hash - // can subscribe to SSE progress events. - prompt_sha256: String::new(), - started_at_unix_ms: 0, + prompt_sha256: g.prompt_sha256.clone(), + started_at_unix_ms: g.started_at_unix_ms, elapsed_ms: g.started_at.elapsed().as_millis() as u64, }) }) @@ -2409,6 +2633,7 @@ async fn put_model_placement( axum::extract::Path(name): axum::extract::Path, Json(placement): Json, ) -> Result, ApiError> { + validate_multi_gpu_placement(&state, Some(&placement))?; { let mut cfg = state.config.write().await; cfg.set_model_placement(&name, Some(placement.clone())); diff --git a/crates/mold-server/src/routes_chain.rs b/crates/mold-server/src/routes_chain.rs new file mode 100644 index 00000000..d2240cf4 --- /dev/null +++ b/crates/mold-server/src/routes_chain.rs @@ -0,0 +1,797 @@ +//! Server-side chained video generation endpoints. +//! +//! Exposes `POST /api/generate/chain` (synchronous) and +//! `POST /api/generate/chain/stream` (SSE). Both drive +//! 
[`mold_inference::ltx2::Ltx2ChainOrchestrator`] through an engine's +//! [`mold_inference::ltx2::ChainStageRenderer`] view. +//! +//! Unlike the single-shot generate path (which queues through +//! [`crate::state::QueueHandle`] to keep small GPU jobs FIFO-fair), chains +//! are multi-minute compound jobs — the handler take/restores the engine +//! out of the model cache and runs the full sequence in a +//! [`tokio::task::spawn_blocking`] so the sync orchestrator never blocks +//! the async runtime. While the chain is running the engine is removed +//! from the cache, so concurrent generate/chain requests for the same +//! model cannot race. + +use std::convert::Infallible; + +use axum::{ + extract::State, + response::sse::{Event as SseEvent, KeepAlive, Sse}, + Json, +}; +use base64::Engine as _; +use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainResponse, SseChainCompleteEvent}; +use mold_core::{OutputFormat, OutputMetadata, VideoData}; +use tokio_stream::StreamExt as _; + +use crate::model_cache::CachedEngine; +use crate::model_manager; +use crate::queue::save_video_to_dir; +use crate::routes::ApiError; +use crate::state::AppState; + +/// Internal wire event used by the chain SSE stream before per-event +/// serialization. Separate from [`crate::state::SseMessage`] because chain +/// complete events carry a different payload (`SseChainCompleteEvent`) and +/// progress events are chain-shaped (`ChainProgressEvent`) rather than the +/// single-stage `SseProgressEvent`. 
+pub(crate) enum ChainSseMessage { + Progress(ChainProgressEvent), + Complete(SseChainCompleteEvent), + Error(String), +} + +fn chain_sse_event(msg: ChainSseMessage) -> SseEvent { + match msg { + ChainSseMessage::Progress(ev) => match serde_json::to_string(&ev) { + Ok(data) => SseEvent::default().event("progress").data(data), + Err(e) => SseEvent::default() + .event("error") + .data(format!(r#"{{"message":"serialize progress: {e}"}}"#)), + }, + ChainSseMessage::Complete(ev) => match serde_json::to_string(&ev) { + Ok(data) => SseEvent::default().event("complete").data(data), + Err(e) => SseEvent::default() + .event("error") + .data(format!(r#"{{"message":"serialize complete: {e}"}}"#)), + }, + ChainSseMessage::Error(message) => SseEvent::default() + .event("error") + .data(serde_json::json!({ "message": message }).to_string()), + } +} + +/// Encode chain frames into bytes for the requested output format. Returns +/// the encoded payload plus a best-effort animated-GIF preview for the +/// gallery. +/// +/// MP4 is gated behind the `mp4` feature flag; when the flag is disabled, +/// the handler falls back to APNG so the endpoint still produces a usable +/// animation on every build. +fn encode_chain_output( + frames: &[image::RgbImage], + fps: u32, + format: OutputFormat, +) -> anyhow::Result<(Vec, OutputFormat, Vec)> { + use mold_inference::ltx_video::video_enc; + + // Always produce a GIF preview for the gallery UI. Non-fatal. 
+ let gif_preview = match video_enc::encode_gif(frames, fps) { + Ok(b) => b, + Err(e) => { + tracing::warn!("chain gif preview encode failed: {e:#}"); + Vec::new() + } + }; + + let (bytes, actual_format) = match format { + OutputFormat::Mp4 => { + #[cfg(feature = "mp4")] + { + (video_enc::encode_mp4(frames, fps)?, OutputFormat::Mp4) + } + #[cfg(not(feature = "mp4"))] + { + tracing::warn!( + "chain requested MP4 but server was built without the `mp4` feature — \ + falling back to APNG" + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + } + OutputFormat::Apng => ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ), + OutputFormat::Gif => (video_enc::encode_gif(frames, fps)?, OutputFormat::Gif), + // WebP is always available here because mold-inference's webp + // feature would need to gate at the transitive-dep level; for the + // chain route v1 we fall back to APNG when WebP is requested so + // we don't bind the server crate to another optional dep. + OutputFormat::Webp => { + tracing::warn!( + "chain WebP output is not supported on the server yet — falling back to APNG" + ); + ( + video_enc::encode_apng(frames, fps, None)?, + OutputFormat::Apng, + ) + } + other => anyhow::bail!("{other:?} is not a video output format for chain generation"), + }; + + Ok((bytes, actual_format, gif_preview)) +} + +/// Build the `OutputMetadata` for a stitched chain output. Pulls chain- +/// level parameters (dimensions, seed, steps) from `req` and the prompt / +/// negative prompt from `stages[0]`. 
+fn chain_output_metadata(req: &ChainRequest, frame_count: u32) -> OutputMetadata { + let first_stage = req.stages.first(); + OutputMetadata { + prompt: first_stage.map(|s| s.prompt.clone()).unwrap_or_default(), + negative_prompt: first_stage.and_then(|s| s.negative_prompt.clone()), + original_prompt: None, + model: req.model.clone(), + seed: req.seed.unwrap_or(0), + steps: req.steps, + guidance: req.guidance, + width: req.width, + height: req.height, + strength: Some(req.strength), + scheduler: None, + lora: None, + lora_scale: None, + frames: Some(frame_count), + fps: Some(req.fps), + version: mold_core::build_info::version_string().to_string(), + } +} + +/// Trim a frame buffer to the caller's requested total frame count, per +/// the signed-off "trim from tail" decision (2026-04-20). The orchestrator +/// always over-produces to hit or exceed `total_frames`; trimming here +/// keeps the output length deterministic without altering per-stage +/// denoise behaviour. +fn trim_to_total_frames(frames: &mut Vec, total_frames: Option) { + if let Some(target) = total_frames { + let target = target as usize; + if frames.len() > target { + frames.truncate(target); + } + } +} + +/// Produce a PNG thumbnail for the chain output — best-effort, returns +/// an empty `Vec` on failure so the save/response paths still succeed. +fn chain_thumbnail(frames: &[image::RgbImage]) -> Vec { + match mold_inference::ltx_video::video_enc::first_frame_png(frames) { + Ok(b) => b, + Err(e) => { + tracing::warn!("chain thumbnail encode failed: {e:#}"); + Vec::new() + } + } +} + +/// Build a `VideoData` for the `ChainResponse` body. 
+fn build_video_data(
+    bytes: Vec<u8>,
+    format: OutputFormat,
+    req: &ChainRequest,
+    frame_count: u32,
+    thumbnail: Vec<u8>,
+    gif_preview: Vec<u8>,
+) -> VideoData {
+    let duration_ms = if req.fps == 0 {
+        None
+    } else {
+        Some((frame_count as u64 * 1000) / req.fps as u64)
+    };
+    VideoData {
+        data: bytes,
+        format,
+        width: req.width,
+        height: req.height,
+        frames: frame_count,
+        fps: req.fps,
+        thumbnail,
+        gif_preview,
+        has_audio: false,
+        duration_ms,
+        audio_sample_rate: None,
+        audio_channels: None,
+    }
+}
+
+/// Build the SSE `complete` payload for a finished chain run. Sibling of
+/// [`crate::queue::build_sse_complete_event`] — kept in this module so the
+/// chain-specific payload can evolve independently from the single-shot
+/// one.
+fn build_sse_chain_complete_event(
+    resp: &ChainResponse,
+    generation_time_ms: u64,
+) -> SseChainCompleteEvent {
+    let b64 = base64::engine::general_purpose::STANDARD;
+    let video = &resp.video;
+    SseChainCompleteEvent {
+        video: b64.encode(&video.data),
+        format: video.format,
+        width: video.width,
+        height: video.height,
+        frames: video.frames,
+        fps: video.fps,
+        thumbnail: if video.thumbnail.is_empty() {
+            None
+        } else {
+            Some(b64.encode(&video.thumbnail))
+        },
+        gif_preview: if video.gif_preview.is_empty() {
+            None
+        } else {
+            Some(b64.encode(&video.gif_preview))
+        },
+        has_audio: video.has_audio,
+        duration_ms: video.duration_ms,
+        audio_sample_rate: video.audio_sample_rate,
+        audio_channels: video.audio_channels,
+        stage_count: resp.stage_count,
+        gpu: resp.gpu,
+        generation_time_ms: Some(generation_time_ms),
+    }
+}
+
+/// Errors surfaced from the chain-run helper. Mapped to appropriate HTTP
+/// status codes by the route handlers.
+#[derive(Debug)]
+enum ChainRunError {
+    /// Model family doesn't support chain rendering (422).
+    UnsupportedModel(String),
+    /// Engine missing from cache after `ensure_model_ready` (500).
+    CacheMiss(String),
+    /// Orchestrator returned an error mid-chain (502).
+    Inference(String),
+    /// Output encoding / stitch failure (500).
+    Encode(String),
+    /// Task panic or join error (500).
+    Internal(String),
+}
+
+impl From<ChainRunError> for ApiError {
+    fn from(err: ChainRunError) -> Self {
+        match err {
+            ChainRunError::UnsupportedModel(msg) => ApiError::validation(msg),
+            ChainRunError::CacheMiss(msg) => ApiError::internal(msg),
+            ChainRunError::Inference(msg) => {
+                ApiError::internal_with_status(msg, axum::http::StatusCode::BAD_GATEWAY)
+            }
+            ChainRunError::Encode(msg) => ApiError::internal(msg),
+            ChainRunError::Internal(msg) => ApiError::internal(msg),
+        }
+    }
+}
+
+/// Drive the chain to completion. Shared between the non-streaming and SSE
+/// paths — the only caller-provided variable is `progress_cb`, which is
+/// `None` for the plain JSON endpoint and `Some` for the SSE endpoint.
+async fn run_chain(
+    state: &AppState,
+    req: ChainRequest,
+    progress_cb: Option<Box<dyn FnMut(ChainProgressEvent) + Send>>,
+) -> Result<(ChainResponse, u64), ChainRunError> {
+    // Serialize concurrent chain requests. The chain handler deliberately
+    // takes the engine out of `model_cache` for the full multi-minute run
+    // (see below) — without this lock a second chain request arriving
+    // mid-run calls `ensure_model_ready`, sees an empty cache, tries to
+    // load a second copy of the model, and the subsequent `cache.take()`
+    // reports "engine vanished from cache after ensure_model_ready".
+    // Holding for the whole chain is intentional: single-clip requests
+    // keep flowing through the normal generation queue; only chains wait
+    // on each other.
+    let _chain_guard = state.chain_lock.lock().await;
+
+    // Ensure the model is loaded. Progress forwarding is not plumbed yet —
+    // load-time events go through the model manager's own tracing. Chain
+    // stage events (StageStart/DenoiseStep/StageDone/Stitching) come from
+    // the orchestrator during the blocking task below.
+ model_manager::ensure_model_ready(state, &req.model, None) + .await + .map_err(|e| ChainRunError::CacheMiss(e.error))?; + + // Take the engine out of the cache so the blocking orchestrator run + // owns it for the full multi-minute chain without holding the async + // mutex guard across an await. Restore when we're done (or on error). + let mut cache = state.model_cache.lock().await; + let cached: CachedEngine = cache.take(&req.model).ok_or_else(|| { + ChainRunError::CacheMiss(format!( + "engine '{}' vanished from cache after ensure_model_ready", + req.model + )) + })?; + drop(cache); + + let req_for_task = req.clone(); + let join_handle = tokio::task::spawn_blocking(move || { + let mut cached = cached; + let mut progress_cb = progress_cb; + let outcome = { + let engine = &mut cached.engine; + match engine.as_chain_renderer() { + Some(renderer) => { + let mut orch = mold_inference::ltx2::Ltx2ChainOrchestrator::new(renderer); + // The orchestrator expects `Option<&mut dyn FnMut(...)>` + // — synthesise that from the optional boxed callback we + // moved into this task. + let result = if let Some(cb) = progress_cb.as_deref_mut() { + orch.run(&req_for_task, Some(cb)) + } else { + orch.run(&req_for_task, None) + }; + result.map_err(|e| ChainRunError::Inference(format!("{e:#}"))) + } + None => Err(ChainRunError::UnsupportedModel(format!( + "model '{}' does not support chained video generation", + req_for_task.model + ))), + } + }; + (cached, outcome) + }); + + let (cached, outcome) = match join_handle.await { + Ok(pair) => pair, + Err(join_err) => { + return Err(ChainRunError::Internal(format!( + "chain orchestrator task failed: {join_err}" + ))); + } + }; + + // Restore the engine to the cache regardless of success/failure so the + // next request can reuse it. 
+ { + let mut cache = state.model_cache.lock().await; + cache.restore(cached); + } + + let chain_output = outcome?; + let stage_count = chain_output.stage_count; + let generation_time_ms = chain_output.generation_time_ms; + let mut frames = chain_output.frames; + trim_to_total_frames(&mut frames, req.total_frames); + + if frames.is_empty() { + return Err(ChainRunError::Encode( + "chain run emitted zero frames after trim".to_string(), + )); + } + + let (bytes, output_format, gif_preview) = + encode_chain_output(&frames, req.fps, req.output_format) + .map_err(|e| ChainRunError::Encode(format!("encode chain output: {e:#}")))?; + let thumbnail = chain_thumbnail(&frames); + let frame_count = frames.len() as u32; + + // Save to the gallery directory (best-effort, non-blocking). + let output_dir = { + let config = state.config.read().await; + if config.is_output_disabled() { + None + } else { + Some(config.effective_output_dir()) + } + }; + if let Some(dir) = output_dir { + let metadata = chain_output_metadata(&req, frame_count); + let bytes_clone = bytes.clone(); + let gif_clone = gif_preview.clone(); + let model = req.model.clone(); + let db = state.metadata_db.clone(); + tokio::task::spawn_blocking(move || { + save_video_to_dir( + &dir, + &bytes_clone, + &gif_clone, + output_format, + &model, + &metadata, + Some(generation_time_ms as i64), + db.as_ref().as_ref(), + ); + }); + } + + let video = build_video_data( + bytes, + output_format, + &req, + frame_count, + thumbnail, + gif_preview, + ); + let response = ChainResponse { + video, + stage_count, + gpu: None, + }; + Ok((response, generation_time_ms)) +} + +/// `POST /api/generate/chain` — synchronous chained video generation. 
+#[utoipa::path(
+    post,
+    path = "/api/generate/chain",
+    tag = "generation",
+    request_body = mold_core::ChainRequest,
+    responses(
+        (status = 200, description = "Stitched chain video", body = mold_core::ChainResponse),
+        (status = 422, description = "Invalid request or unsupported model"),
+        (status = 500, description = "Chain render failed"),
+        (status = 502, description = "Chain render failed mid-stage"),
+    )
+)]
+pub async fn generate_chain(
+    State(state): State<AppState>,
+    Json(req): Json<ChainRequest>,
+) -> Result<Json<ChainResponse>, ApiError> {
+    let req = req
+        .normalise()
+        .map_err(|e| ApiError::validation(e.to_string()))?;
+
+    tracing::info!(
+        model = %req.model,
+        stages = req.stages.len(),
+        width = req.width,
+        height = req.height,
+        fps = req.fps,
+        "generate/chain request"
+    );
+
+    let (response, _elapsed_ms) = run_chain(&state, req, None).await?;
+    Ok(Json(response))
+}
+
+/// `POST /api/generate/chain/stream` — SSE-streamed chain generation. Emits
+/// [`ChainProgressEvent`]s as `event: progress` frames while the chain
+/// runs, and a single `event: complete` frame with a [`SseChainCompleteEvent`]
+/// payload when the stitched output is ready. Mid-chain failure closes the
+/// stream with an `event: error` frame carrying the orchestrator message.
+#[utoipa::path(
+    post,
+    path = "/api/generate/chain/stream",
+    tag = "generation",
+    request_body = mold_core::ChainRequest,
+    responses(
+        (status = 200, description = "SSE event stream with chain progress and completion"),
+        (status = 422, description = "Invalid request or unsupported model"),
+        (status = 500, description = "Chain render failed"),
+    )
+)]
+pub async fn generate_chain_stream(
+    State(state): State<AppState>,
+    Json(req): Json<ChainRequest>,
+) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, ApiError> {
+    let req = req
+        .normalise()
+        .map_err(|e| ApiError::validation(e.to_string()))?;
+
+    tracing::info!(
+        model = %req.model,
+        stages = req.stages.len(),
+        width = req.width,
+        height = req.height,
+        fps = req.fps,
+        "generate/chain/stream request"
+    );
+
+    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<ChainSseMessage>();
+    let state_clone = state.clone();
+    let tx_for_task = tx.clone();
+
+    tokio::spawn(async move {
+        let tx_for_cb = tx_for_task.clone();
+        let cb: Box<dyn FnMut(ChainProgressEvent) + Send> = Box::new(move |event| {
+            let _ = tx_for_cb.send(ChainSseMessage::Progress(event));
+        });
+        match run_chain(&state_clone, req, Some(cb)).await {
+            Ok((response, elapsed_ms)) => {
+                let complete = build_sse_chain_complete_event(&response, elapsed_ms);
+                let _ = tx_for_task.send(ChainSseMessage::Complete(complete));
+            }
+            Err(err) => {
+                let api_err: ApiError = err.into();
+                let _ = tx_for_task.send(ChainSseMessage::Error(api_err.error));
+            }
+        }
+        // `tx_for_task` is dropped here, closing the channel and finalizing
+        // the SSE stream after the last complete/error frame.
+    });
+    drop(tx); // ensure only the task holds the sender
+
+    let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx)
+        .map(|msg| Ok::<_, Infallible>(chain_sse_event(msg)));
+
+    Ok(Sse::new(stream).keep_alive(
+        KeepAlive::new()
+            .interval(std::time::Duration::from_secs(15))
+            .text("ping"),
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use anyhow::Result;
+    use image::{Rgb, RgbImage};
+    use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage};
+    use mold_core::{GenerateRequest, GenerateResponse};
+    use mold_inference::ltx2::{ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent};
+    use mold_inference::InferenceEngine;
+    use std::sync::{Arc, Mutex};
+
+    /// Mock engine that delegates to a simple chain renderer producing
+    /// deterministic solid-color frames + a zero-valued latent tail. The
+    /// chain renderer is owned by the engine so `as_chain_renderer` can
+    /// hand out a `&mut dyn ChainStageRenderer` over it.
+    struct ChainMockEngine {
+        loaded: bool,
+        fail_on_stage: Option<usize>,
+        renderer_calls: Arc<Mutex<usize>>,
+    }
+
+    impl ChainMockEngine {
+        fn ready() -> Self {
+            Self {
+                loaded: true,
+                fail_on_stage: None,
+                renderer_calls: Arc::new(Mutex::new(0)),
+            }
+        }
+        fn failing_at(idx: usize) -> Self {
+            Self {
+                loaded: true,
+                fail_on_stage: Some(idx),
+                renderer_calls: Arc::new(Mutex::new(0)),
+            }
+        }
+    }
+
+    impl ChainStageRenderer for ChainMockEngine {
+        fn render_stage(
+            &mut self,
+            stage_req: &GenerateRequest,
+            _carry: Option<&ChainTail>,
+            _motion_tail_pixel_frames: u32,
+            _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
+        ) -> Result<StageOutcome> {
+            let idx = {
+                let mut calls = self.renderer_calls.lock().unwrap();
+                let idx = *calls;
+                *calls += 1;
+                idx
+            };
+            if self.fail_on_stage == Some(idx) {
+                anyhow::bail!("simulated chain failure at stage {idx}");
+            }
+            let frame_count = stage_req.frames.expect("chain stage missing frame count") as usize;
+            let width = stage_req.width;
+            let height = stage_req.height;
+            let
mut frames = Vec::with_capacity(frame_count);
+            for f in 0..frame_count {
+                let shade = (idx as u8).wrapping_mul(17).wrapping_add(f as u8);
+                frames.push(RgbImage::from_pixel(width, height, Rgb([shade, 0, 0])));
+            }
+            let tail_pixel_frames = 4usize;
+            let take_from = frames
+                .len()
+                .saturating_sub(tail_pixel_frames)
+                .min(frames.len());
+            let tail_rgb_frames = frames[take_from..].to_vec();
+            Ok(StageOutcome {
+                frames,
+                tail: ChainTail {
+                    frames: tail_pixel_frames as u32,
+                    tail_rgb_frames,
+                },
+                generation_time_ms: 10,
+            })
+        }
+    }
+
+    impl InferenceEngine for ChainMockEngine {
+        fn generate(&mut self, _req: &GenerateRequest) -> Result<GenerateResponse> {
+            anyhow::bail!("chain mock engine does not support single-shot generate")
+        }
+        fn model_name(&self) -> &str {
+            "ltx-2-19b-distilled:mock"
+        }
+        fn is_loaded(&self) -> bool {
+            self.loaded
+        }
+        fn load(&mut self) -> Result<()> {
+            self.loaded = true;
+            Ok(())
+        }
+        fn as_chain_renderer(
+            &mut self,
+        ) -> Option<&mut dyn mold_inference::ltx2::ChainStageRenderer> {
+            Some(self)
+        }
+    }
+
+    /// Build an AppState whose model cache already contains a chain-capable
+    /// mock engine under the model name the tests pass in their requests.
+ fn state_with_chain_engine(engine: ChainMockEngine) -> AppState { + AppState::with_engine(engine) + } + + fn chain_req_for_mock(model: &str, stages: u32) -> ChainRequest { + ChainRequest { + model: model.to_string(), + stages: (0..stages) + .map(|_| ChainStage { + prompt: "a cat walking".into(), + frames: 9, + source_image: None, + negative_prompt: None, + seed_offset: None, + }) + .collect(), + motion_tail_frames: 0, // simplifies frame accounting for the mock + width: 64, + height: 64, + fps: 12, + seed: Some(42), + steps: 4, + guidance: 3.0, + strength: 1.0, + output_format: OutputFormat::Apng, // avoid needing the mp4 feature in tests + placement: None, + prompt: None, + total_frames: None, + clip_frames: None, + source_image: None, + } + } + + #[tokio::test] + async fn chain_happy_path_returns_stage_count_and_video() { + let engine = ChainMockEngine::ready(); + let state = state_with_chain_engine(engine); + let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 3); + + let (resp, elapsed_ms) = run_chain(&state, req, None) + .await + .expect("chain run succeeds"); + + assert_eq!(resp.stage_count, 3, "response must report all 3 stages"); + assert_eq!(resp.video.fps, 12); + assert_eq!(resp.video.frames, 9 * 3, "3 stages × 9 frames with tail=0"); + assert_eq!(resp.video.format, OutputFormat::Apng); + assert!(!resp.video.data.is_empty(), "apng bytes written"); + // elapsed_ms is the sum of the mock's reported per-stage time (10ms each). 
+        assert_eq!(elapsed_ms, 30);
+    }
+
+    #[tokio::test]
+    async fn chain_stream_emits_progress_then_complete_in_order() {
+        let engine = ChainMockEngine::ready();
+        let state = state_with_chain_engine(engine);
+        let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 2);
+
+        let collected: Arc<Mutex<Vec<ChainProgressEvent>>> = Arc::new(Mutex::new(Vec::new()));
+        let collected_cb = collected.clone();
+        let cb: Box<dyn FnMut(ChainProgressEvent) + Send> = Box::new(move |ev| {
+            collected_cb.lock().unwrap().push(ev);
+        });
+        let (resp, _) = run_chain(&state, req, Some(cb))
+            .await
+            .expect("chain run succeeds");
+
+        assert_eq!(resp.stage_count, 2);
+        let events = collected.lock().unwrap();
+        assert!(!events.is_empty(), "progress events must flow");
+        assert!(
+            matches!(
+                events[0],
+                ChainProgressEvent::ChainStart { stage_count: 2, .. }
+            ),
+            "first event must be ChainStart, got {:?}",
+            events[0]
+        );
+        assert!(
+            matches!(events.last().unwrap(), ChainProgressEvent::Stitching { .. }),
+            "last event must be Stitching, got {:?}",
+            events.last()
+        );
+        // There must be exactly one StageStart + StageDone per stage.
+        let stage_starts = events
+            .iter()
+            .filter(|e| matches!(e, ChainProgressEvent::StageStart { .. }))
+            .count();
+        let stage_dones = events
+            .iter()
+            .filter(|e| matches!(e, ChainProgressEvent::StageDone { ..
}))
+            .count();
+        assert_eq!(stage_starts, 2);
+        assert_eq!(stage_dones, 2);
+    }
+
+    #[tokio::test]
+    async fn chain_mid_chain_failure_maps_to_bad_gateway() {
+        let engine = ChainMockEngine::failing_at(1);
+        let state = state_with_chain_engine(engine);
+        let req = chain_req_for_mock("ltx-2-19b-distilled:mock", 3);
+
+        let err = run_chain(&state, req, None)
+            .await
+            .expect_err("mid-chain failure must bubble up");
+        match err {
+            ChainRunError::Inference(msg) => {
+                assert!(
+                    msg.contains("simulated chain failure"),
+                    "inference error must carry renderer message, got: {msg}"
+                );
+            }
+            other => panic!("expected Inference error, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn chain_unsupported_model_rejects_with_validation() {
+        /// Engine that is fully capable of single-shot generate but refuses
+        /// chain rendering (mirrors every non-LTX-2 family).
+        struct NonChainEngine;
+        impl InferenceEngine for NonChainEngine {
+            fn generate(&mut self, _req: &GenerateRequest) -> Result<GenerateResponse> {
+                anyhow::bail!("no single-shot generate in this test either")
+            }
+            fn model_name(&self) -> &str {
+                "flux-dev:q8"
+            }
+            fn is_loaded(&self) -> bool {
+                true
+            }
+            fn load(&mut self) -> Result<()> {
+                Ok(())
+            }
+            // No override for as_chain_renderer — default returns None.
+ } + + let state = AppState::with_engine(NonChainEngine); + let mut req = chain_req_for_mock("flux-dev:q8", 2); + req.model = "flux-dev:q8".into(); + let err = run_chain(&state, req, None) + .await + .expect_err("non-chain model must fail"); + match err { + ChainRunError::UnsupportedModel(msg) => { + assert!( + msg.contains("does not support chained video generation"), + "unsupported-model error must name the constraint, got: {msg}" + ); + } + other => panic!("expected UnsupportedModel, got {other:?}"), + } + } + + #[tokio::test] + async fn chain_trims_frames_from_tail_when_total_frames_set() { + let engine = ChainMockEngine::ready(); + let state = state_with_chain_engine(engine); + let mut req = chain_req_for_mock("ltx-2-19b-distilled:mock", 2); + // Each stage produces 9 frames with tail=0 → 18 total. Trim to 10. + req.total_frames = Some(10); + + let (resp, _) = run_chain(&state, req, None).await.expect("chain runs"); + assert_eq!( + resp.video.frames, 10, + "total_frames must trim the stitched output length" + ); + } +} diff --git a/crates/mold-server/src/routes_test.rs b/crates/mold-server/src/routes_test.rs index a7c1237f..cb00dc93 100644 --- a/crates/mold-server/src/routes_test.rs +++ b/crates/mold-server/src/routes_test.rs @@ -1,3 +1,8 @@ +// Tests use `std::sync::Mutex<()>` to serialize process-global env-var +// mutations; holding the guard across `.await` is intentional under the +// current-thread tokio test runtime. 
+#![allow(clippy::await_holding_lock)]
+
 #[cfg(test)]
 mod tests {
     use axum::{
@@ -12,7 +17,7 @@ mod tests {
     use std::path::PathBuf;
     use std::sync::{
         atomic::{AtomicBool, AtomicUsize, Ordering},
-        Arc, Condvar, Mutex,
+        Arc, Condvar, Mutex, RwLock,
     };
     use std::time::Duration;
     use tower::ServiceExt;
@@ -277,6 +282,34 @@ mod tests {
         ))
     }
 
+    fn gpu_worker_stub(ordinal: usize) -> Arc<crate::gpu_pool::GpuWorker> {
+        let (job_tx, _job_rx) = std::sync::mpsc::sync_channel(1);
+        Arc::new(crate::gpu_pool::GpuWorker {
+            gpu: mold_inference::device::DiscoveredGpu {
+                ordinal,
+                name: format!("gpu{ordinal}"),
+                total_vram_bytes: 24_000_000_000,
+                free_vram_bytes: 24_000_000_000,
+            },
+            model_cache: Arc::new(Mutex::new(crate::model_cache::ModelCache::new(3))),
+            active_generation: Arc::new(RwLock::new(None)),
+            model_load_lock: Arc::new(Mutex::new(())),
+            shared_pool: Arc::new(Mutex::new(mold_inference::shared_pool::SharedPool::new())),
+            in_flight: AtomicUsize::new(0),
+            consecutive_failures: AtomicUsize::new(0),
+            degraded_until: RwLock::new(None),
+            job_tx,
+        })
+    }
+
+    fn app_with_worker_pool(engine: MockEngine, ordinals: &[usize]) -> axum::Router {
+        let mut state = AppState::with_engine(engine);
+        state.gpu_pool = Arc::new(crate::gpu_pool::GpuPool {
+            workers: ordinals.iter().copied().map(gpu_worker_stub).collect(),
+        });
+        create_router(state)
+    }
+
     fn generate_body(prompt: &str, width: u32, height: u32) -> String {
         // Use "mock-model" to match MockEngine::model_name() — avoids hot-swap path.
format!( @@ -381,6 +414,41 @@ mod tests { assert_eq!(status["current_generation"], serde_json::Value::Null); } + #[tokio::test] + async fn status_multi_gpu_current_generation_includes_prompt_hash_and_timestamp() { + let worker = gpu_worker_stub(1); + *worker.active_generation.write().unwrap() = Some(crate::gpu_pool::ActiveGeneration { + model: "flux-dev:q4".to_string(), + prompt_sha256: "abc123".to_string(), + started_at_unix_ms: 1_700_000_000_000, + started_at: std::time::Instant::now(), + }); + + let mut state = AppState::with_engine(MockEngine::ready()); + state.gpu_pool = Arc::new(crate::gpu_pool::GpuPool { + workers: vec![worker], + }); + let app = app_with_state(state); + + let resp = app + .oneshot(Request::get("/api/status").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), 1024 * 1024) + .await + .unwrap(); + let status: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(status["current_generation"]["model"], "flux-dev:q4"); + assert_eq!(status["current_generation"]["prompt_sha256"], "abc123"); + assert_eq!( + status["current_generation"]["started_at_unix_ms"], + serde_json::json!(1_700_000_000_000_u64) + ); + assert_eq!(status["gpus"][0]["ordinal"], serde_json::json!(1)); + } + #[tokio::test] async fn status_includes_hostname_and_memory_status() { let app = app_empty(); @@ -1196,6 +1264,7 @@ mod tests { start_time: std::time::Instant::now(), model_load_lock: Arc::new(tokio::sync::Mutex::new(())), pull_lock: Arc::new(tokio::sync::Mutex::new(())), + chain_lock: Arc::new(tokio::sync::Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new( mold_inference::shared_pool::SharedPool::new(), @@ -1249,6 +1318,7 @@ mod tests { start_time: std::time::Instant::now(), model_load_lock: Arc::new(tokio::sync::Mutex::new(())), pull_lock: Arc::new(tokio::sync::Mutex::new(())), + chain_lock: Arc::new(tokio::sync::Mutex::new(())), queue, 
shared_pool: Arc::new(std::sync::Mutex::new( mold_inference::shared_pool::SharedPool::new(), @@ -1505,6 +1575,7 @@ mod tests { start_time: std::time::Instant::now(), model_load_lock: Arc::new(tokio::sync::Mutex::new(())), pull_lock: Arc::new(tokio::sync::Mutex::new(())), + chain_lock: Arc::new(tokio::sync::Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new( mold_inference::shared_pool::SharedPool::new(), @@ -2263,6 +2334,77 @@ mod tests { resp.status() ); } + + #[tokio::test] + async fn put_model_placement_rejects_gpu_outside_worker_pool() { + let _lock = env_lock().lock().unwrap_or_else(|e| e.into_inner()); + let (tx, _rx) = tokio::sync::mpsc::channel(16); + let queue = crate::state::QueueHandle::new(tx); + let gpu_pool = Arc::new(crate::gpu_pool::GpuPool { + workers: vec![gpu_worker_stub(1)], + }); + let state = AppState::empty(mold_core::Config::default(), queue, gpu_pool, 200); + let app = crate::routes::create_router(state); + + let body = serde_json::json!({ + "text_encoders": { "kind": "auto" }, + "advanced": { + "transformer": { "kind": "gpu", "ordinal": 0 }, + "vae": { "kind": "auto" } + } + }); + + let resp = app + .oneshot( + Request::builder() + .method("PUT") + .uri("/api/config/model/flux-dev%3Aq4/placement") + .header("content-type", "application/json") + .body(Body::from(body.to_string())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY); + let body = json_body(resp).await; + assert!(body["error"].as_str().unwrap().contains("gpu:0")); + } + + #[tokio::test] + async fn generate_rejects_gpu_outside_worker_pool() { + let app = app_with_worker_pool(MockEngine::ready(), &[1]); + let body = serde_json::json!({ + "prompt": "a cat", + "model": "mock-model", + "width": 512, + "height": 512, + "steps": 4, + "batch_size": 1, + "output_format": "png", + "placement": { + "text_encoders": { "kind": "auto" }, + "advanced": { + "transformer": { "kind": "gpu", "ordinal": 0 }, + "vae": { "kind": 
"auto" }
+                }
+            }
+        });
+
+        let resp = app
+            .oneshot(
+                Request::builder()
+                    .method("POST")
+                    .uri("/api/generate")
+                    .header("content-type", "application/json")
+                    .body(Body::from(body.to_string()))
+                    .unwrap(),
+            )
+            .await
+            .unwrap();
+        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
+        let body = json_body(resp).await;
+        assert!(body["error"].as_str().unwrap().contains("gpu:0"));
+    }
 
     // ─── Downloads UI (Agent A) ─────────────────────────────────────────────
 
     #[tokio::test]
@@ -2508,6 +2650,7 @@ mod tests {
                 used_by_mold: 0,
                 used_by_other: 0,
             },
+            cpu: None,
        });
 
         let app = create_router(state);
diff --git a/crates/mold-server/src/state.rs b/crates/mold-server/src/state.rs
index 56213cca..2c9415be 100644
--- a/crates/mold-server/src/state.rs
+++ b/crates/mold-server/src/state.rs
@@ -138,6 +138,14 @@ pub struct AppState {
     pub model_load_lock: Arc<Mutex<()>>,
     /// Guards concurrent pulls — only one download at a time.
     pub pull_lock: Arc<Mutex<()>>,
+    /// Serializes chained video renders. The chain handler removes the
+    /// engine from `model_cache` and runs blocking work outside that
+    /// lock for the full multi-minute chain; without a dedicated lock two
+    /// concurrent chain requests race on `cache.take()` and the loser
+    /// surfaces "engine vanished from cache after ensure_model_ready".
+    /// Held for the entire chain (load + all stages + restore); other
+    /// single-clip requests continue to queue normally on `queue`.
+    pub chain_lock: Arc<Mutex<()>>,
     /// Generation request queue.
     pub queue: QueueHandle,
     /// Shared tokenizer pool for cross-engine caching.
@@ -188,6 +196,7 @@ impl AppState { start_time: Instant::now(), model_load_lock: Arc::new(Mutex::new(())), pull_lock: Arc::new(Mutex::new(())), + chain_lock: Arc::new(Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new(SharedPool::new())), shutdown_tx: Arc::new(tokio::sync::Mutex::new(None)), @@ -215,6 +224,7 @@ impl AppState { start_time: Instant::now(), model_load_lock: Arc::new(Mutex::new(())), pull_lock: Arc::new(Mutex::new(())), + chain_lock: Arc::new(Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new(SharedPool::new())), shutdown_tx: Arc::new(tokio::sync::Mutex::new(None)), @@ -264,6 +274,7 @@ impl AppState { start_time: Instant::now(), model_load_lock: Arc::new(Mutex::new(())), pull_lock: Arc::new(Mutex::new(())), + chain_lock: Arc::new(Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new(SharedPool::new())), shutdown_tx: Arc::new(tokio::sync::Mutex::new(None)), @@ -300,6 +311,7 @@ impl AppState { start_time: Instant::now(), model_load_lock: Arc::new(Mutex::new(())), pull_lock: Arc::new(Mutex::new(())), + chain_lock: Arc::new(Mutex::new(())), queue, shared_pool: Arc::new(std::sync::Mutex::new(SharedPool::new())), shutdown_tx: Arc::new(tokio::sync::Mutex::new(None)), diff --git a/crates/mold-tui/src/app.rs b/crates/mold-tui/src/app.rs index 0ec7661f..16659f9e 100644 --- a/crates/mold-tui/src/app.rs +++ b/crates/mold-tui/src/app.rs @@ -1452,6 +1452,7 @@ impl App { model_name_local.clone(), weights_path, mold_inference::LoadStrategy::Eager, + 0, )?; engine.set_on_progress(Box::new(move |event| { diff --git a/crates/mold-tui/src/ui/info.rs b/crates/mold-tui/src/ui/info.rs index 93f0260c..1202b126 100644 --- a/crates/mold-tui/src/ui/info.rs +++ b/crates/mold-tui/src/ui/info.rs @@ -55,7 +55,26 @@ pub fn render(frame: &mut Frame, app: &App, area: Rect) { // GPU info from remote server if let Some(ref status) = ri.server_status { - if let Some(ref gpu) = status.gpu_info { + if let Some(ref gpus) = status.gpus { + for 
gpu in gpus { + let vram_free = gpu.vram_total_bytes.saturating_sub(gpu.vram_used_bytes); + lines.push(Line::from(Span::styled( + format!( + "GPU {} {}: {:.1} GB free", + gpu.ordinal, + gpu.name, + vram_free as f64 / 1_073_741_824.0 + ), + theme.dim(), + ))); + } + if let (Some(depth), Some(capacity)) = (status.queue_depth, status.queue_capacity) { + lines.push(Line::from(Span::styled( + format!("Queue: {depth}/{capacity}"), + theme.dim(), + ))); + } + } else if let Some(ref gpu) = status.gpu_info { let vram_free = gpu.vram_total_mb.saturating_sub(gpu.vram_used_mb); lines.push(Line::from(Span::styled( format!("{}: {:.1} GB free", gpu.name, vram_free as f64 / 1024.0), diff --git a/flake.nix b/flake.nix index 2f5186b8..8a791107 100644 --- a/flake.nix +++ b/flake.nix @@ -109,6 +109,7 @@ pkgs.llvmPackages.libclang.lib ] ++ lib.optionals isLinux [ + pkgs.lld pkgs.cudaPackages.cuda_nvcc ]; buildInputs = [ @@ -136,8 +137,6 @@ opensslLibDir = "${pkgs.lib.getLib pkgs.openssl}/lib"; opensslIncludeDir = "${pkgs.openssl.dev}/include"; - cargoArtifacts = craneLib.buildDepsOnly commonArgs; - gpuFeature = if isLinux then "cuda" @@ -146,12 +145,34 @@ else ""; - # Features string for devshell commands: GPU + preview + discord + expand + tui + video formats + devProfile = "dev-fast"; + + # Fast local iteration defaults: GPU backend + preview + prompt expansion. devFeatures = if gpuFeature != "" then - "${gpuFeature},preview,discord,expand,tui,webp,mp4" + "${gpuFeature},preview,expand" + else + "preview,expand"; + + # Full shipping feature set used for release builds and feature coverage. 
+ releaseFeatures = + if gpuFeature != "" then + "${gpuFeature},preview,discord,expand,tui,webp,mp4,metrics" else - "preview,discord,expand,tui,webp,mp4"; + "preview,discord,expand,tui,webp,mp4,metrics"; + + cargoArtifacts = craneLib.buildDepsOnly ( + commonArgs + // { + cargoExtraArgs = "-p mold-ai --features ${releaseFeatures}"; + } + ); + + webEmbedSetup = '' + export SCCACHE_DIR="''${MOLD_SCCACHE_DIR:-$PWD/.cache/sccache}" + export MOLD_WEB_DIST="$PWD/web/dist" + ./scripts/ensure-web-dist.sh + ''; # Merged CUDA toolkit so bindgen_cuda can find both bin/nvcc and include/cuda.h cudaToolkit = pkgs.symlinkJoin { @@ -215,9 +236,7 @@ // { inherit cargoArtifacts meta; MOLD_WEB_DIST = "${mold-web}"; - cargoExtraArgs = - "-p mold-ai --features preview,discord,expand,tui,webp,mp4,metrics" - + lib.optionalString (gpuFeature != "") ",${gpuFeature}"; + cargoExtraArgs = "-p mold-ai --features ${releaseFeatures}"; postInstall = '' installShellCompletion --cmd mold \ --bash <($out/bin/mold completions bash) \ @@ -278,6 +297,7 @@ pkgs.pkg-config pkgs.openssl pkgs.nasm + pkgs.sccache pkgs.git pkgs.gh pkgs.jq @@ -296,6 +316,7 @@ pkgs.llvmPackages.libcxxClang ] ++ lib.optionals isLinux [ + pkgs.lld pkgs.cudaPackages.cuda_nvcc pkgs.cudaPackages.cuda_cudart pkgs.cudaPackages.libcublas.lib @@ -313,6 +334,14 @@ name = "MOLD_LTX_DEBUG"; value = "1"; } + { + name = "CARGO_INCREMENTAL"; + value = "1"; + } + { + name = "RUSTC_WRAPPER"; + value = "sccache"; + } { name = "PKG_CONFIG_PATH"; value = opensslPkgConfigPath; @@ -386,26 +415,44 @@ { category = "build"; name = "build"; - help = "cargo build (debug, all crates)"; + help = "fast local mold build with the web bundle embedded"; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo build --profile ${devProfile} -p mold-ai --features ${devFeatures} "$@" + ''; + } + { + category = "build"; + name = "build-workspace"; + help = "cargo build the full workspace in debug mode"; command = "cargo build \"$@\""; } { category = "build"; 
name = "build-release"; - help = "cargo build --release -p mold-ai --features ${devFeatures}"; - command = "cargo build --release -p mold-ai --features ${devFeatures} \"$@\""; + help = "shipping mold build with the full feature set and embedded web UI"; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo build --release -p mold-ai --features ${releaseFeatures} "$@" + ''; } { category = "build"; name = "build-server"; - help = "cargo build -p mold-ai --features ${devFeatures} (single binary with GPU + preview)"; - command = "cargo build -p mold-ai --features ${devFeatures} \"$@\""; + help = "fast local server build with GPU + preview + expand and embedded web UI"; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo build --profile ${devProfile} -p mold-ai --features ${devFeatures} "$@" + ''; } { category = "build"; name = "build-discord"; - help = "cargo build -p mold-ai --features discord"; - command = "cargo build -p mold-ai --features ${devFeatures} \"$@\""; + help = "fast local Discord-bot build"; + command = "cargo build --profile ${devProfile} -p mold-ai --features discord \"$@\""; } { category = "build"; @@ -434,8 +481,8 @@ { category = "check"; name = "clippy"; - help = "cargo clippy --workspace -- -D warnings (matches CI)"; - command = "cargo clippy --workspace \"$@\" -- -D warnings"; + help = "cargo clippy --workspace --all-targets -- -D warnings (matches CI)"; + command = "cargo clippy --workspace --all-targets \"$@\" -- -D warnings"; } { category = "check"; @@ -451,7 +498,7 @@ set -euo pipefail cargo fmt --all -- --check cargo check --workspace - cargo clippy --workspace -- -D warnings + cargo clippy --workspace --all-targets -- -D warnings cargo test --workspace cargo check -p mold-ai --features preview,discord,expand,tui,webp,mp4 ''; @@ -493,26 +540,38 @@ { category = "run"; name = "mold"; - help = "run mold CLI (e.g. 
mold list, mold ps, mold pull)"; - command = "cargo run -p mold-ai --features ${devFeatures} -- \"$@\""; + help = "run mold CLI with the fast local feature set"; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo run --profile ${devProfile} -p mold-ai --features ${devFeatures} -- "$@" + ''; } { category = "run"; name = "serve"; help = "start the mold server"; - command = "cargo run -p mold-ai --features ${devFeatures} -- serve \"$@\""; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo run --profile ${devProfile} -p mold-ai --features ${devFeatures} -- serve "$@" + ''; } { category = "run"; name = "generate"; help = "generate an image from a prompt"; - command = "cargo run -p mold-ai --features ${devFeatures} -- run \"$@\""; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo run --profile ${devProfile} -p mold-ai --features ${devFeatures} -- run "$@" + ''; } { category = "run"; name = "discord-bot"; help = "start the mold Discord bot"; - command = "cargo run -p mold-ai --features ${devFeatures} -- discord \"$@\""; + command = "cargo run --profile ${devProfile} -p mold-ai --features discord -- discord \"$@\""; } { category = "runpod"; @@ -548,13 +607,21 @@ category = "run"; name = "build-ltx2"; help = "build mold with the full feature set for LTX-2 work"; - command = "cargo build -p mold-ai --features ${devFeatures} \"$@\""; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo build --profile ${devProfile} -p mold-ai --features ${releaseFeatures} "$@" + ''; } { category = "run"; name = "smoke-ltx2"; help = "run a local LTX-2 / LTX-2.3 smoke inference"; - command = "cargo run -p mold-ai --features ${devFeatures} -- run --local \"$@\""; + command = '' + set -euo pipefail + ${webEmbedSetup} + cargo run --profile ${devProfile} -p mold-ai --features ${releaseFeatures} -- run --local "$@" + ''; } { category = "run"; diff --git a/scripts/ensure-web-dist.sh b/scripts/ensure-web-dist.sh new file mode 100755 index 00000000..443bf70e 
--- /dev/null +++ b/scripts/ensure-web-dist.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +set -euo pipefail + +repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +web_dir="${MOLD_WEB_ROOT:-$repo_root/web}" +dist_dir="$web_dir/dist" +stamp_file="$dist_dir/.mold-build-stamp" +install_stamp="$web_dir/node_modules/.mold-install-stamp" + +needs_build=0 + +if [ ! -f "$dist_dir/index.html" ] || [ ! -f "$stamp_file" ]; then + needs_build=1 +else + while IFS= read -r path; do + if [ "$path" -nt "$stamp_file" ]; then + needs_build=1 + break + fi + done < <( + find \ + "$web_dir/src" \ + "$web_dir/public" \ + -type f \ + -print + printf '%s\n' \ + "$web_dir/package.json" \ + "$web_dir/bun.lock" \ + "$web_dir/bun.nix" \ + "$web_dir/index.html" \ + "$web_dir/vite.config.ts" \ + "$web_dir/tsconfig.json" \ + "$web_dir/tsconfig.app.json" \ + "$web_dir/tsconfig.node.json" \ + "$web_dir/vitest.config.ts" + ) +fi + +if [ "$needs_build" -eq 0 ]; then + exit 0 +fi + +if [ ! -d "$web_dir/node_modules" ] \ + || [ ! -f "$install_stamp" ] \ + || [ "$web_dir/package.json" -nt "$install_stamp" ] \ + || [ "$web_dir/bun.lock" -nt "$install_stamp" ]; then + ( + cd "$web_dir" + bun install --frozen-lockfile + touch "$install_stamp" + ) +fi + +( + cd "$web_dir" + bun run build + touch "$stamp_file" +) diff --git a/scripts/tui-uat.sh b/scripts/tui-uat.sh index b8ca5c62..23566a5e 100755 --- a/scripts/tui-uat.sh +++ b/scripts/tui-uat.sh @@ -27,7 +27,11 @@ set -euo pipefail STATE_FILE="/tmp/mold-tui-uat.state" -MOLD_BIN="${MOLD_BIN:-./target/debug/mold}" +DEFAULT_MOLD_BIN="./target/dev-fast/mold" +if [ ! 
-x "$DEFAULT_MOLD_BIN" ]; then + DEFAULT_MOLD_BIN="./target/debug/mold" +fi +MOLD_BIN="${MOLD_BIN:-$DEFAULT_MOLD_BIN}" # ── Helpers ───────────────────────────────────────────────────────── diff --git a/tasks/render-chain-v1-handoff.md b/tasks/render-chain-v1-handoff.md new file mode 100644 index 00000000..540e04cc --- /dev/null +++ b/tasks/render-chain-v1-handoff.md @@ -0,0 +1,311 @@ +# render-chain-v1 — context handoff + +> Paste the prompt at the bottom of this file into a fresh Claude Code session +> to resume work on render-chain v1. Everything above it is reference material +> that the prompt points at. + +## Status + +Branch: `main` (local). **6 commits stacked ahead of `origin/main`, not pushed** +per plan convention (no mid-plan push): + +| # | Commit | Scope | Phase | +|---|-----------|----------|-------| +| 1 | `d13a554` | `fix(ltx2): use pure source latents as i2v denoise-mask target` | Fix A (prereq) | +| 2 | `b4ed487` | `feat(chain): add core wire types and request normalisation` | 0.1 | +| 3 | `0328e76` | `feat(core): MoldClient chain methods` | 0.2 | +| 4 | `e89826f` | `feat(ltx2): ChainTail type and latent-tail extraction helper` | 1a | +| 5 | `e917210` | `feat(ltx2): staged latent conditioning bypasses VAE encode` | 1b | +| 6 | `14801c7` | `feat(ltx2): chain orchestrator with motion-tail carryover loop` | 1c | + +Test status on commit 6: `mold-core` 617 pass, `mold-inference` 586 pass, +`cargo fmt --check` clean, no candle weights loaded by any test. + +Pre-existing clippy warnings on main (NOT introduced by this branch): +- `crates/mold-core/src/download.rs:1451` — `manual_repeat_n` +- `crates/mold-core/src/placement_test.rs:167` — `field_reassign_with_default` + +These only fire on newer clippy versions than CI pins and are unrelated to +the chain work. Don't "fix" them as part of render-chain. 
+ +## Signed-off design decisions (do NOT re-litigate) + +User confirmed these 2026-04-20 and they're recorded at the top of +`tasks/render-chain-v1-plan.md`: + +1. **Trim over-production from the tail** of the final clip, not the head. +2. **Per-stage seed derivation: `stage_seed = base_seed ^ ((stage_idx as u64) << 32)`.** + `ChainStage::seed_offset` overrides this; reserved for the v2 movie-maker. +3. **Fail closed on mid-chain failure.** 502 + discard all prior stages. No + partial stitch. +4. **Accept ~1 GB RAM ceiling** for accumulated `RgbImage` buffer. Streaming + encode revisited at 1000+ frames. +5. **Single-GPU per chain.** Multi-GPU stage fan-out is v2. + +The orchestrator already encodes 1, 2, 3 and Phase 2 server route handles 3. + +## What's done + +- **`mold_core::chain`** — wire types (`ChainRequest`, `ChainResponse`, + `ChainStage`, `ChainProgressEvent`, `SseChainCompleteEvent`) and + `ChainRequest::normalise()`. Re-exports from `mold_core`. +- **`MoldClient::generate_chain{,_stream}`** with 422 → Validation, 404-with- + body → ModelNotFound, empty-404 → hard error (non-streaming) / `Ok(None)` + (streaming). Wiremock integration tests pin all four paths. +- **`ltx2::chain::ChainTail` + `extract_tail_latents`** — pure tensor math, + VAE formula `((pixel - 1) / 8) + 1`. Errors (not panics) on rank + mismatch / oversize tail. +- **`StagedLatent` + `StagedConditioning.latents`** — threaded through + `maybe_load_stage_video_conditioning` in `runtime.rs`. When the latents + vec is non-empty, the function builds `VideoTokenReplacement`s straight + from pre-encoded tokens and **skips VAE load entirely** (conditional + `Option` — confirmed only loaded when images or reference video are + present). +- **`Ltx2ChainOrchestrator`** — fully tested against + a fake renderer. 
Handles seed derivation, motion-tail trim on
+  continuations (stage 0 keeps all frames, continuations drop leading K),
+  progress forwarding with `stage_idx` wrapping, fail-closed error handling.
+  Orchestrator does NOT trim to a target total or encode MP4 — those are
+  caller responsibilities.
+
+## What's remaining
+
+### Phase 1d — `impl ChainStageRenderer for Ltx2Engine` (engine integration)
+
+The one-sentence contract: given `stage_req`, optional `carry: &ChainTail`,
+and an optional stage-progress callback, return
+`StageOutcome { frames, tail, generation_time_ms }`.
+
+Three sub-tasks:
+
+1. **Tail capture slot.** Add a mechanism for `render_real_distilled_av`
+   (`crates/mold-inference/src/ltx2/runtime.rs:1722`) to clone the
+   pre-VAE-decode `latents` tensor into a caller-provided slot. The exact
+   capture point is immediately before `vae.decode(&latents...)` at
+   `runtime.rs:2010` — shape is `[1, 128, T_latent, H/32, W/32]` F32.
+   Preferred mechanism: a field on `Ltx2RuntimeSession` (or a method
+   argument threaded down) holding `Option<Arc<Mutex<Option<Tensor>>>>`.
+   Production non-chain callers leave it `None` and pay no overhead.
+
+2. **`Ltx2Engine::generate_with_carryover(&mut self, req, carry)`**:
+   - Validate the request is a supported family (v1 scope: distilled LTX-2
+     only — see `select_pipeline` at `crates/mold-inference/src/ltx2/pipeline.rs:108`).
+   - Build a `Ltx2GeneratePlan` via the existing `materialize_request` flow.
+     When `carry.is_some()`, wipe `source_image` and append a
+     `StagedLatent { latents: carry.latents.clone(), frame: 0, strength: 1.0 }`
+     to `plan.conditioning.latents`. The runtime already handles the rest
+     (`maybe_load_stage_video_conditioning` skips VAE, builds a frame-0
+     replacement from patchified tokens).
+   - Enable the tail-capture slot.
+   - Run the existing render → decode → encode pipeline.
+   - Pull the captured latents out of the slot.
+   - Call `ltx2::chain::extract_tail_latents(&captured, motion_tail_frames)`
+     to get the tail slice.
+ - Decode the stitched MP4 once to extract `last_rgb_frame` (or capture + it alongside the `frames` Vec from `decoded_video_to_frames`). + - Return `(GenerateResponse, ChainTail)`. + +3. **`impl ChainStageRenderer for Ltx2Engine`** that delegates to + `generate_with_carryover`. The orchestrator's fake-renderer tests + define the exact contract; no new test harness needed for the impl — + real-engine coverage is Phase 2's integration test. + +**Gotchas:** +- `CLAUDE.md` claims `[lib] test = false` on `mold-inference` and + `mold-server` — **this is stale.** Both have normal test configs. Verified + in Phase 1a/b/c by running 586 tests. +- `run_real_distilled_stage` already takes + `video_clean_latents: Option<&Tensor>` and `video_denoise_mask: Option<&Tensor>` — + don't add new parameters unnecessarily. The tail carryover rides on + `conditioning.replacements` via `StagedLatent`, not on `video_clean_latents`. +- VAE temporal ratio is **8× with causal first frame** (`model/shapes.rs:20`). + `extract_tail_latents` already encodes this; just call it. +- `motion_tail_frames` defaults to 4 per plan; orchestrator validates + `motion_tail < stage.frames` up front, but the engine should still + tolerate `motion_tail = 0` (simple concat, no latent carryover — `carry` + will be `None` for every stage in that configuration). + +### Phase 2 — `POST /api/generate/chain[/stream]` server route + +Plan §2. Handler flow: + +1. Parse + `ChainRequest::normalise()`. +2. Reject non-LTX-2 models with a clear error. +3. Grab the engine from `ModelCache` (`crates/mold-server/src/lib.rs` — + holds `AppState.model_cache: Arc>`). +4. Construct `Ltx2ChainOrchestrator` against it, call `run()`. +5. Trim accumulated frames to target total (the ChainRequest no longer + carries `total_frames` after normalise — if you want tail-trim support, + add a `target_total_frames: Option` field that normalise + populates). Per the sign-off: trim from the tail. +6. Encode stitched MP4. 
Reuse `ltx2::media::encode_frames_to_mp4` or the + existing `encode_native_video` path — scout during Phase 2. +7. Save via `save_video_to_dir` with an `OutputMetadata` synthesised from + `stages[0].prompt`; optionally add `chain_stage_count: Option` to + `OutputMetadata`. +8. Return `ChainResponse` JSON. + +**Do NOT go through the existing single-job queue.** A 10+ minute chain +would block the queue. Instead hold the `ModelCache` mutex directly for +the chain duration, same pattern as the multi-GPU pool. Reason in plan §2.1. + +SSE variant: same flow, stream `ChainProgressEvent` as `event: progress` +JSON frames and a final `SseChainCompleteEvent` as `event: complete`. + +Tests: route-level with a fake engine (same trait seam as Phase 1c). No +real weights. + +### Phase 3 — CLI auto-routing + flags + +When `--frames > clip_cap` (97 for LTX-2 19B/22B distilled), build a +`ChainRequest` from the CLI args and route to +`MoldClient::generate_chain_stream`. New flags: `--clip-frames N`, +`--motion-tail N` (default 4). + +Stacked progress bar: one parent bar per chain (estimated total frames), +one per-stage bar wiping between stages. + +`--local` parity: factor the orchestrator invocation so both server +handler and CLI local path use the same code. + +### Phase 4 — docs + +- `website/guide/video.md`: new "Chained video output" section explaining + `--frames N`, motion tail, and the server endpoint. +- `CHANGELOG.md`: Unreleased/Added entry. +- `.claude/skills/mold/SKILL.md`: new CLI flags + endpoint. 
+ +## Verification commands + +Run these in order after any Phase 1d change to verify nothing regressed: + +```bash +cargo fmt -p mold-ai-inference -- --check +cargo check -p mold-ai-inference +cargo test -p mold-ai-inference --lib ltx2::chain:: # orchestrator + tail helpers +cargo test -p mold-ai-inference --lib # full 586-test sweep (~35 s) +cargo test -p mold-ai-core # sanity +``` + +Phase 1d's own tests should live alongside existing `pipeline.rs::tests` +patterns (using `with_runtime_session` injection at +`crates/mold-inference/src/ltx2/pipeline.rs:1062` — the existing test +exercises the runtime without real weights). + +## File map — where everything lives now + +``` +NEW (this branch): + crates/mold-core/src/chain.rs # wire types + normalise + crates/mold-core/tests/chain_client.rs # wiremock integration + crates/mold-inference/src/ltx2/chain.rs # ChainTail + orchestrator + +MODIFIED (this branch): + crates/mold-core/src/lib.rs # re-exports + crates/mold-core/src/types.rs # pub(crate) base64_opt + crates/mold-core/src/client.rs # generate_chain{,_stream} + crates/mold-inference/src/ltx2/mod.rs # pub use chain::* + crates/mold-inference/src/ltx2/conditioning.rs # StagedLatent + crates/mold-inference/src/ltx2/runtime.rs # latents loop + Fix A + +TARGETS (Phase 1d): + crates/mold-inference/src/ltx2/pipeline.rs # Ltx2Engine::generate_with_carryover + crates/mold-inference/src/ltx2/runtime.rs # tail-capture slot on session + +TARGETS (Phase 2+): + crates/mold-server/src/routes_chain.rs # NEW + crates/mold-server/src/lib.rs # route registration + crates/mold-cli/src/main.rs # auto-route + crates/mold-cli/src/commands/generate.rs # chain path + local parity + website/guide/video.md # docs + CHANGELOG.md + .claude/skills/mold/SKILL.md +``` + +## Convention reminders + +- Feature branch: `feat/render-chain-v1` (currently committing directly to + local `main` since pre-push). PR target: `main`. 
+- Commit scopes: `feat(chain)`, `fix(chain)`, `test(chain)`, `docs(chain)` + (core), or `feat(ltx2)`, `feat(server)`, `feat(cli)` depending on crate. +- **No mid-plan push.** All work accumulates locally until Phase 4 ends. +- Every phase step ends with a commit; verification (`fmt`, `test`) + between every step. +- Tests must be weight-free. Use the trait-seam pattern (Phase 1c) or the + `with_runtime_session` injection pattern (`pipeline.rs:1062`). + +--- + +## The prompt + +Paste from here into a fresh Claude Code session: + +--- + +I'm continuing work on **render-chain v1** — server-side chained LTX-2 video +generation for the mold repo. + +## Read first, in this order + +1. `CLAUDE.md` (both global at `~/.claude-personal/CLAUDE.md` and + `/Users/jeffreydilley/github/mold/CLAUDE.md`). +2. `tasks/render-chain-v1-plan.md` — full design, signed-off decisions. +3. `tasks/render-chain-v1-handoff.md` — status, remaining work, gotchas. + **This is your primary briefing.** Read it end-to-end before writing code. + +## Status on entry + +- 6 commits stacked locally on `main`, not pushed (per plan convention). + Last commit: `14801c7 feat(ltx2): chain orchestrator with motion-tail carryover loop`. +- Phase 0 (core wire types + client) and Phase 1a/b/c (ltx2 chain types, + StagedLatent plumbing, orchestrator + fake-renderer tests) are done. +- `mold-inference` has 586 tests passing, `mold-core` 617. Nothing loads + candle weights. Fmt clean. +- `CLAUDE.md`'s claim that `mold-inference` has `[lib] test = false` is + **stale** — the previous session verified tests run normally. + +## What you're doing + +**Phase 1d** — the engine integration that makes the orchestrator actually +render. Spec in `render-chain-v1-handoff.md` under "Phase 1d". 
In one +sentence: implement `impl ChainStageRenderer for Ltx2Engine` by adding a +tail-capture slot to `Ltx2RuntimeSession` and a +`Ltx2Engine::generate_with_carryover` method that populates +`plan.conditioning.latents` from the `ChainTail` input and returns the +captured tail alongside the response. + +Key surgery points already scouted: +- Tail capture immediately before `vae.decode` at + `crates/mold-inference/src/ltx2/runtime.rs:2010` +- Plan's staged-latents plumbing already works — + `maybe_load_stage_video_conditioning` accepts pre-encoded latents when + you populate `plan.conditioning.latents` (Phase 1b). + +After Phase 1d, Phases 2 (server route), 3 (CLI), and 4 (docs) per the plan. + +## How to work + +- Use `superpowers:subagent-driven-development` — the plan is sized for it. +- Use `superpowers:verification-before-completion` before claiming any + phase done. The handoff doc has the exact verification commands. +- Every step ends with a commit. Commit scope `feat(ltx2)` for Phase 1d. +- Do NOT push anything — plan convention is no mid-plan push. +- Do NOT re-litigate the signed-off design decisions in the handoff doc. +- Tests must be weight-free (use the `with_runtime_session` injection + pattern from `pipeline.rs:1062` or the trait seam shipped in Phase 1c). + +## Start here + +1. Run `git status && git log --oneline -7` to confirm the 6 commits are + on the tree. +2. Read `tasks/render-chain-v1-handoff.md` end-to-end. +3. Delegate an Explore subagent to map `Ltx2RuntimeSession` and the full + `Ltx2Engine::generate` → `generate_inner` → `render_native_video` call + chain end-to-end before writing code. Cite file:line throughout. Keep + the report under 2000 words. +4. Then plan the tail-capture mechanism (decide: field on + `Ltx2RuntimeSession` vs. threaded parameter, ergonomics tradeoffs). +5. Implement. Commit. Then Phase 2. 
+ +If you hit a surprise that invalidates an assumption in the plan or +handoff doc, stop and re-plan rather than papering over it. diff --git a/tasks/render-chain-v1-plan.md b/tasks/render-chain-v1-plan.md new file mode 100644 index 00000000..d0324863 --- /dev/null +++ b/tasks/render-chain-v1-plan.md @@ -0,0 +1,410 @@ +# Render Chain v1 — Implementation Plan + +> Server-side chained video generation for LTX-2: generate videos of arbitrary length by stringing together multiple per-clip renders and stitching the results. v1 exposes a single-prompt/arbitrary-length UX; the request shape is **stages-based from day one** so the eventual movie-maker (multi-prompt, multi-keyframe) extends without a breaking change. + +## Confirmed design decisions (signed off 2026-04-20) + +1. **Trim over-production from the tail** of the final clip, not the head. The head carries the user's starting image anchor and is perceptually load-bearing; tail frames are the freshest continuation but cheapest to lose. +2. **Per-stage seed derivation: `stage_seed = base_seed ^ ((stage_idx as u64) << 32)`.** Deterministic, reproducible, avoids identical-noise artefacts when prompts match across stages. `ChainStage::seed_offset` stays reserved as the v2 movie-maker override hook. +3. **Fail closed on mid-chain failure.** If any stage errors, return 502 and discard all prior stages. No partial stitch is ever written to the gallery. Partial-resume is a v2 movie-maker feature. +4. **1 GB RAM ceiling for the accumulation buffer.** Hold decoded `RgbImage`s in memory through the stitch — acceptable for the 400-frame 1216×704 target. Revisit with streaming encode when someone pushes 1000+ frames. +5. **Single-GPU per chain.** The orchestrator runs every stage on the GPU the engine was loaded onto. Multi-GPU stage fan-out is a v2 perf win; docs mention it, code doesn't build it. 
+
+**Goal:** `mold run ltx-2-19b-distilled:fp8 "a cat walking" --image cat.png --frames 400` produces a single 400-frame MP4, stitched from ~5 coherent sub-clips, each seeded by a motion tail of latents from the prior clip.
+
+**Scope (v1):**
+
+- LTX-2 only (other video engines intentionally out of scope).
+- Single prompt replicated across all stages. Optional starting image on stage 0.
+- Motion-tail carryover **using cached latents in-process** (no VAE re-encode between clips).
+- Single stitched output to the gallery. No per-clip gallery rows, no `chain_id` grouping.
+- Sequential execution (clip N+1 waits for N). Multi-GPU fan-out is v2.
+- Server-side orchestration under a new `/api/generate/chain[/stream]` route. CLI auto-routes when `--frames > max_per_clip`.
+
+**Explicitly NOT in v1:**
+
+- Movie maker UI (that's v2, built on the same server API).
+- Per-stage prompts/keyframes (the request shape supports them; the CLI doesn't expose them yet).
+- Crossfade / colour-matching at clip boundaries.
+- Pause/resume/retry of a partial chain.
+
+**Base branch:** `main` · **Feature branch:** `feat/render-chain-v1` · **PR target:** `main`
+
+---
+
+## The compatibility contract
+
+The key architectural decision: **the wire format is already multi-stage.** v1 auto-synthesises the stages list from a single prompt + total length, but the server only ever sees the stages form. That means v2 (movie maker) is additive — the SPA just lets the user author the stages list by hand, no server breaking changes.
+ +```json +POST /api/generate/chain +{ + "model": "ltx-2-19b-distilled:fp8", + "stages": [ + { "prompt": "a cat walking", "frames": 97, "source_image": "" }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 }, + { "prompt": "a cat walking", "frames": 97 } + ], + "motion_tail_frames": 4, + "width": 1216, "height": 704, "fps": 24, + "seed": 42, "steps": 8, "guidance": 3.0, "strength": 1.0, + "output_format": "mp4" +} +``` + +Or the auto-expand form (what v1 CLI sends): + +```json +POST /api/generate/chain +{ + "model": "ltx-2-19b-distilled:fp8", + "prompt": "a cat walking", + "total_frames": 400, + "clip_frames": 97, + "source_image": "", + "motion_tail_frames": 4, + "width": 1216, "height": 704, "fps": 24, + "seed": 42, "steps": 8, "guidance": 3.0, "strength": 1.0, + "output_format": "mp4" +} +``` + +Server-side, a canonicalising function collapses the auto-expand form into stages. From the engine's POV there's only ever a `Vec`. + +--- + +## File map + +### New + +``` +crates/mold-core/src/chain.rs -- ChainStage, ChainRequest, ChainResponse types +crates/mold-inference/src/ltx2/chain.rs -- LTX-2 chain orchestrator + latent-tail carry +crates/mold-server/src/routes_chain.rs -- POST /api/generate/chain[/stream] +``` + +### Modified + +``` +crates/mold-core/src/lib.rs -- re-export chain types +crates/mold-core/src/client.rs -- MoldClient::generate_chain[_stream]() +crates/mold-inference/src/ltx2/mod.rs -- pub use chain::{Ltx2ChainOrchestrator, ChainTail} +crates/mold-inference/src/ltx2/pipeline.rs -- expose internal render path that returns (VideoData, ChainTail) +crates/mold-inference/src/ltx2/runtime.rs -- thread ChainTail through run_real_distilled_stage +crates/mold-server/src/lib.rs -- route registration +crates/mold-server/src/queue.rs -- chain handler uses ModelCache but does NOT enqueue via the existing video job queue (reason in §3) +crates/mold-cli/src/main.rs -- auto-route --frames > clip_max to /api/generate/chain 
+crates/mold-cli/src/commands/generate.rs -- chain client + progress rendering +CHANGELOG.md +website/guide/video.md -- document --frames N and the chain endpoint +``` + +--- + +## Conventions + +- All new Rust code gets unit tests where the logic is pure (stage expansion, tail shape math, concat-drop math). The orchestrator's end-to-end path is covered by an integration test that swaps in a fake engine. +- `mold-inference` crate has `test = false` on the `lib` target — new tests in `ltx2/chain.rs` must either run under `#[cfg(test)] mod tests` with logic that doesn't touch candle weights, or use the fake-engine pattern. Keep tests weight-free. +- CLI manual UAT runs against BEAST (`MOLD_HOST=http://beast:7680`) with `ltx-2-19b-distilled:fp8`. +- Commit scopes: `feat(chain): …`, `fix(chain): …`, `test(chain): …`, `docs(chain): …`. +- Every task ends with a commit. No mid-plan push. + +--- + +## Phases + +### Phase 0 — core types (no-op at runtime) + +**0.1. Add `mold-core::chain` module with wire types.** + +```rust +// crates/mold-core/src/chain.rs +pub struct ChainStage { + pub prompt: String, + pub frames: u32, + pub source_image: Option>, // PNG bytes + pub negative_prompt: Option, // future-proof; v1 ignores if Some + pub seed_offset: Option, // v2 hook; v1 derives from base seed +} + +pub struct ChainRequest { + pub model: String, + pub stages: Vec, // canonical form + #[serde(default)] + pub motion_tail_frames: u32, // 0 = single-frame handoff; >0 = multi-frame tail + pub width: u32, pub height: u32, pub fps: u32, + pub seed: Option, pub steps: u32, pub guidance: f64, + pub strength: f64, // applied to stage[0].source_image only + pub output_format: OutputFormat, + pub placement: Option, + // auto-expand form (server normalises): + pub prompt: Option, + pub total_frames: Option, + pub clip_frames: Option, + pub source_image: Option>, +} + +pub struct ChainResponse { pub video: VideoData, pub stage_count: u32, pub gpu: Option } +``` + +- Add a 
`normalise(self) -> Result` that collapses the auto-expand fields into stages when `stages.is_empty()`. +- Validation: at least one stage, each stage has `frames` satisfying 8k+1 and > 0, total stages × clip_frames ≤ 16 (early guardrail — users aren't generating feature films with this yet). +- Tests: `normalise_splits_single_prompt_into_stages`, `normalise_preserves_first_stage_image`, `normalise_rejects_empty`, `normalise_rejects_non_8k1_frames`. + +Commit: `feat(chain): add core wire types and request normalisation`. + +**0.2. Re-export from `mold_core`, add `MoldClient::generate_chain`/`generate_chain_stream`.** + +Mirror the existing `generate` / `generate_stream` shape. No server changes yet — client just has the surface area. + +Commit: `feat(core): MoldClient chain methods`. + +--- + +### Phase 1 — LTX-2 chain orchestrator (single GPU, in-process) + +**1.1. Define `ChainTail` as the latent-carryover payload.** + +```rust +// crates/mold-inference/src/ltx2/chain.rs +pub struct ChainTail { + pub frames: u32, // number of pixel frames this tail represents + pub latents: Tensor, // [1, C, tail_latent_frames, H/32, W/32] on the engine device + pub last_rgb_frame: RgbImage, // for fallback + debugging +} +``` + +The VAE temporal ratio is 8 with causal first frame, so `tail_latent_frames = ((tail_pixel_frames - 1) / 8 + 1).max(1)`. For `motion_tail_frames=4` this is 1 latent frame. For `motion_tail_frames=9` it's 2 latent frames. Tests cover the arithmetic. + +**1.2. Extend `Ltx2Engine` with a chain-aware generate path.** + +Add a method that `generate` proper delegates to: + +```rust +impl Ltx2Engine { + pub fn generate_with_carryover( + &mut self, + req: &GenerateRequest, + carry: Option<&ChainTail>, + ) -> Result<(GenerateResponse, ChainTail)>; +} +``` + +When `carry = None`, behaviour is identical to `self.generate(req)` (use the source_image path as today). When `carry = Some(tail)`, the engine: + +1. 
Skips VAE encode on `stage_conditioning` for the keyframe at frame 0. +2. Instead, threads `tail.latents` straight into `maybe_load_stage_video_conditioning` via a new optional parameter. The patchified tail tokens go into `StageVideoConditioning::replacements` with `strength = 1.0` and `start_token = 0..tail_token_count`. +3. Extracts the last `K = motion_tail_frames` pixel frames' worth of latents from the completed denoise (before VAE decode) and returns them as the new `ChainTail`. + +The new latent extraction hook needs to run **after the last denoise step, before `vae.decode`** in the distilled and two-stage paths. Surface it as a single helper `extract_tail_latents(&final_latents, motion_tail_frames) -> Tensor` that narrows along the time axis. + +- Tests for the helper: `extract_tail_computes_correct_latent_slice`, `extract_tail_preserves_device_and_dtype`, `extract_tail_handles_single_frame_edge_case`. + +**1.3. Stage conditioning: accept pre-encoded latents instead of a staged image.** + +Currently `maybe_load_stage_video_conditioning` (`runtime.rs:1215`) reads an image path, decodes, VAE-encodes. Add a sibling path that accepts `Option<&Tensor>` as pre-patchified tokens (or raw latents to be patchified in place). Route through it when the orchestrator passes carryover. + +Concretely: a new variant on `StagedImage` or a parallel `StagedLatent` struct carried through `StagedConditioning`. Prefer the latter — keeps the existing image path pristine. + +```rust +pub struct StagedLatent { + pub latents: Tensor, // [1, C, T, H/32, W/32] + pub frame: u32, // start frame (0 for chain carryover) + pub strength: f32, // 1.0 for chain +} + +pub struct StagedConditioning { + pub images: Vec, + pub latents: Vec, // NEW, empty for today's callers + pub audio_path: Option, + pub video_path: Option, +} +``` + +`maybe_load_stage_video_conditioning` iterates `images` then `latents`, patchifying the latter directly without calling `vae.encode`. 
All existing call sites pass an empty `latents` Vec. + +- Test: `staged_latent_produces_same_replacement_token_shape_as_image_for_single_latent_frame`. + +**1.4. Build `Ltx2ChainOrchestrator`.** + +```rust +// crates/mold-inference/src/ltx2/chain.rs +pub struct Ltx2ChainOrchestrator<'a> { + engine: &'a mut Ltx2Engine, +} + +impl<'a> Ltx2ChainOrchestrator<'a> { + pub fn run( + &mut self, + req: &ChainRequest, + progress: Option, + ) -> Result; +} +``` + +Internal loop: + +``` +let mut tail: Option = None; +let mut accumulated_frames: Vec = Vec::new(); +let tail_drop = req.motion_tail_frames as usize; + +for (idx, stage) in req.stages.iter().enumerate() { + let per_clip = build_clip_request(stage, &req, tail.is_some())?; + let (resp, new_tail) = self.engine.generate_with_carryover(&per_clip, tail.as_ref())?; + let frames = decode_video_frames_from_response(&resp)?; + if idx == 0 { + accumulated_frames.extend(frames); + } else { + // drop the leading `tail_drop` pixel frames; they duplicate the prior clip's tail + accumulated_frames.extend(frames.into_iter().skip(tail_drop)); + } + tail = Some(new_tail); + emit_progress(progress.as_ref(), ChainStageDone { idx, total: req.stages.len() }); +} + +let stitched = encode_mp4(&accumulated_frames, req.fps)?; +Ok(ChainResponse { video: stitched, ... }) +``` + +- Stage-1 request has `source_image = stage.source_image`, `keyframes = None`. +- Stage-N request (N ≥ 2) has `source_image = None`, `keyframes = None`; the carryover is passed via the `tail` parameter to `generate_with_carryover`, not through the request DTO. +- Progress events: forward engine events with an added `stage_idx`, plus emit `ChainStageStart` / `ChainStageDone` / `ChainStitching` / `ChainComplete`. + +- Tests (fake engine): `chain_runs_all_stages_and_drops_tail_prefix_from_continuations`, `chain_with_zero_tail_concats_full_clips_without_drop`, `chain_progress_forwards_engine_events_with_stage_idx`, `chain_empty_stages_errors`. 
+ +Commit: `feat(ltx2): chain orchestrator with latent-tail carryover`. + +--- + +### Phase 2 — server route + +**2.1. `POST /api/generate/chain` (non-streaming).** + +Handler flow: + +1. Parse & normalise the `ChainRequest`. +2. Validate model is an LTX-2 family (`anyhow::bail!` with a clear error otherwise). +3. Grab the model's engine from `ModelCache` (load if needed, same as the existing video path). +4. Construct `Ltx2ChainOrchestrator` against it and call `run()`. +5. Save the stitched MP4 via the same save path as single-clip videos (`save_video_to_dir`), populating `OutputMetadata` with a synthetic prompt (`stages[0].prompt` for v1) and a note in a new optional metadata field `chain_stage_count: Option`. +6. Return `ChainResponse` as JSON. + +Do **not** go through the existing single-job queue — a chain is a long-running compound job and would block the queue for 10+ minutes. Instead, the handler holds the `ModelCache` mutex the same way the multi-GPU worker does, for the full chain duration. This is OK because the multi-GPU pool already has per-GPU thread isolation. + +**2.2. `POST /api/generate/chain/stream` (SSE).** + +Same flow but progress events stream as `data:` frames. Event types: + +- `chain_start { stage_count, estimated_total_frames }` +- `stage_start { stage_idx }` +- `denoise_step { stage_idx, step, total }` (forwarded from engine with `stage_idx` wrapped in) +- `stage_done { stage_idx, frames_emitted }` +- `stitching { total_frames }` +- `complete { video_frames, video_fps, video_base64, filename, seed, ... }` (same shape as `/api/generate/stream` complete event) +- `error { message }` + +The existing SSE completion-event helper (`build_sse_complete_event` in `queue.rs`) is not reusable as-is because it takes a single `GenerateResponse`; write a sibling `build_chain_sse_complete_event(&ChainResponse)` that produces the same JSON structure plus `chain_stage_count`. 
+ +- Tests: route-level tests with a fake engine that exercise both non-streaming and SSE shapes; verify SSE emits events in the expected order. + +Commit: `feat(server): chain render endpoint and SSE stream`. + +--- + +### Phase 3 — CLI + +**3.1. Auto-route `mold run` to `/api/generate/chain` when `--frames > max_per_clip`.** + +Add a constant in `mold-cli` for LTX-2 clip caps (97 for 19B distilled, 97 for 22B — same as today's single-clip validation). When `frames > cap`: + +- Build a `ChainRequest` with `prompt=…`, `total_frames=…`, `clip_frames=cap`, `source_image=…`, `motion_tail_frames=4` (default). +- Call `MoldClient::generate_chain_stream`. +- Render a progress bar per stage stacked with a parent "chain" bar. + +When `frames ≤ cap`, path is unchanged (`/api/generate/stream`, single clip, today's behaviour). + +- New flag: `--clip-frames N` to let advanced users override the per-clip length (default = model cap). +- New flag: `--motion-tail N` to override the tail (default 4, 0 to disable). +- Help text for `--frames` updates to mention chained output when > cap. + +- Tests: `run_frames_above_cap_selects_chain_endpoint` (argparse-level; doesn't invoke the network). + +**3.2. `--local` chain mode.** + +For parity with `mold run --local`, the CLI should run the orchestrator in-process when `--local` is passed. Factor the orchestrator invocation into a helper so both the server handler and the CLI local path share it. + +Commit: `feat(cli): chain rendering for --frames above clip cap`. + +--- + +### Phase 4 — docs & changelog + +**4.1. Website.** Add a new section in `website/guide/video.md` explaining chained video output, how motion tail works, and the CLI flags. Link it from the LTX-2 model page. + +**4.2. CHANGELOG.** Unreleased / Added entry describing the `/api/generate/chain` route, the CLI auto-routing behaviour, and the motion-tail carryover. + +**4.3. Skill file.** Update `.claude/skills/mold/SKILL.md` with the new CLI flags and endpoint. 
+ +Commit: `docs(chain): guide, changelog, and skill updates`. + +--- + +## Integration test: a realistic end-to-end + +One integration test lives in `crates/mold-server/tests/chain_integration.rs` (or inline in `tests/` if an integration dir exists). It: + +1. Stands up an in-process server with a **fake LTX-2 engine** (not real weights) whose `generate_with_carryover` returns a deterministic gradient pattern + a synthetic `ChainTail` whose latents are zeros but whose RGB tail frame is the last frame of the emitted clip. +2. POSTs an auto-expand chain request with `total_frames=200`, `clip_frames=97`, `motion_tail_frames=4`. +3. Asserts: + - Three stages fired. + - The stitched MP4 has `ceil((200 - 97) / 93) * 93 + 97 = 97 + 93*2 = 283 ≥ 200` frames before trim; after trim it's 200 frames. + - SSE stream emitted events in the expected order. + - The gallery DB got one row with `chain_stage_count = 3`. + +The fake-engine pattern keeps this test out of the GPU path and makes it safe to run in CI. + +--- + +## Open design decisions I'm flagging for your sign-off + +1. **Trim policy.** If `total_frames = 400` and chain math produces 469 frames, should we trim from the tail (final clip's final frames get cut — but those are the freshest continuation) or from the head (stage-0 frames get cut — but those are the user-anchored ones)? I recommend **trim from tail** because the head is where the user's starting image landed and matters more perceptually. + +2. **Seed handling across stages.** Should each stage get the same seed (reproducible but with artifacts from identical noise when prompts match), or derive per-stage seeds (`base_seed ^ (stage_idx << 32)`)? I recommend **derive per-stage**. `seed_offset` on `ChainStage` lets the movie maker override. + +3. **Failure mode mid-chain.** If stage 3 of 4 fails, do we return a 502 and discard everything, or return the partial stitch of stages 1–3? I recommend **fail closed for v1** — no partial output. 
Partial resume is a v2 movie-maker feature where individual stage regen is first-class.
+
+4. **Memory.** 400 frames × 1216×704×3 ≈ 1 GB of RgbImages held in RAM before MP4 encode. Acceptable for v1. If users push to 1000+ frames we revisit with streaming encode.
+
+5. **Placement.** Chain always runs on a single GPU for v1 (the one the engine was loaded onto). Multi-GPU fan-out (stage N and N+1 on different cards) is a v2 perf win; mention in docs but don't build.
+
+---
+
+## What `mold run` looks like after this ships
+
+```console
+$ mold run ltx-2-19b-distilled:fp8 "a cat walking through autumn leaves" \
+    --image cat.png --frames 400
+
+⏳ Chain render: 5 stages × 97 frames (motion tail: 4) → 469 stitched frames, trimmed to 400
+▸ Stage 1/5 · denoise step 8/8 · 47s
+▸ Stage 2/5 · denoise step 8/8 · 44s (tail carried from stage 1)
+▸ Stage 3/5 · denoise step 8/8 · 44s
+▸ Stages 4/5 and 5/5 · denoise step 8/8 · 44s each
+▸ Stitching 469 frames @ 24fps … trimming tail to 400
+✔ Saved mold-ltx-2-19b-distilled-{ts}.mp4 (400 frames, 16.7s, 16MB)
+```
+
+---
+
+## Out-of-scope for v1 but in-scope for v2 (movie maker)
+
+- SPA route `/movie` with a timeline authoring UI.
+- Per-stage prompts and keyframes exposed in the request body (the server already supports this — only the UI needs to change).
+- Per-clip gallery rows with `chain_id` grouping so users can iterate on individual stages.
+- Selective stage regeneration (replace stage 2 without redoing 1/3/4).
+- Crossfade blending at clip boundaries.
+- Multi-GPU stage fan-out.
+
+The whole point of v1 is to ship a stable foundation these land on top of without breaking changes.
diff --git a/tasks/sd3-clip-77-truncation-handoff.md b/tasks/sd3-clip-77-truncation-handoff.md
new file mode 100644
index 00000000..093d84c6
--- /dev/null
+++ b/tasks/sd3-clip-77-truncation-handoff.md
@@ -0,0 +1,249 @@
+# SD3 CLIP-L/G 77-token truncation bug — handoff
+
+> Paste the prompt at the bottom of this file into a fresh Claude Code session
+> to pick up where this one left off.
+ +## TL;DR + +`sd3.5-large:q8` (and every other `sd3*` family) fails with +`"shape mismatch in broadcast_add, lhs: [1, N, 768], rhs: [1, 77, 768]"` +whenever the prompt tokenises to `N > 77` tokens. Observed three times in +killswitch's server log over the last few hours (seq lengths 130, 131, 132). +The shared SD1.5/SDXL encoder path already truncates to 77 correctly; the +SD3-specific wrapper regressed. + +## Repro + +Any sd3 family, any prompt long enough to exceed 77 CLIP tokens: + +```bash +# Remote: +curl -sS -X POST http://beast:7680/api/generate \ + -H 'content-type: application/json' \ + -d '{"model":"sd3.5-large:q8","prompt":"'"$(python3 -c 'print("a highly detailed " * 30)')"'","width":1024,"height":1024,"steps":20,"guidance":4.0}' | head -c 300 +# → HTTP 500 with "shape mismatch in broadcast_add, lhs: [1, 132, 768], rhs: [1, 77, 768]" +``` + +Same bug on `--local` once an sd3 engine builds. + +## Root cause (with file:line) + +**`crates/mold-inference/src/encoders/sd3_clip.rs:86-97`** — the +`ClipWithTokenizer::encode_text_to_embedding` method tokenises the prompt +and pads UP to `max_position_embeddings` (77) but never truncates DOWN when +the tokeniser returns more than 77 ids: + +```rust +let mut tokens = self.tokenizer.encode(prompt, true)...get_ids().to_vec(); +let eos_position = tokens.len() - 1; // ← overshoots when len > 77 +while tokens.len() < self.max_position_embeddings { + tokens.push(pad_id); // ← pads up only +} +let tokens = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; +let (_, text_embeddings_penultimate) = + clip.forward_until_encoder_layer(&tokens, usize::MAX, -2)?; +// … +last_hidden.i((0, eos_position, ..))? // ← also out-of-bounds when len > 77 +``` + +CLIP's position-embedding table holds exactly 77 entries. When the token +tensor is `[1, 132]`, the internal `embedding + position_embedding` add +hits `lhs: [1, 132, 768], rhs: [1, 77, 768]`, which is the error the user +surfaced. 
`eos_position = tokens.len() - 1` is also out-of-bounds for the +pooled-output slice when it happens to not panic before the add fires. + +**Compare to the correct shared path** at +`crates/mold-inference/src/encoders/clip.rs:105-107`: + +```rust +let mut tokens = ...get_ids().to_vec(); +// CLIP hard limit: 77 tokens (including BOS/EOS) +tokens.truncate(77); +``` + +Same CLIP tokeniser, same 77-limit constant — the SD3 wrapper just +forgot to call `truncate`. + +**Blast radius.** `ClipWithTokenizer` is used for BOTH SD3's CLIP-L +(`encoders/sd3_clip.rs:232` — `encode_text_to_embedding` on `clip_l`) AND +CLIP-G (`:236` — same method on `clip_g`). Fixing the helper once fixes +both. Every `sd3*` model is affected: + +- `sd3.5-large:q8` (observed — three failures) +- `sd3.5-large:fp16` +- `sd3.5-medium:*` +- `sd3-large:*` (base sd3, if still in the manifest) + +SD1.5, SDXL, FLUX, Flux.2, Z-Image, LTX-Video, LTX-2, Qwen-Image, +Wuerstchen — all unaffected (different encoders, different truncation +paths already verified). + +## The fix (sketch — do not paste blindly) + +```rust +let raw_tokens = self.tokenizer.encode(prompt, true) + .map_err(|e| anyhow::anyhow!("CLIP tokenization failed: {e}"))? + .get_ids() + .to_vec(); + +// Truncate to max_position_embeddings, preserving EOS as the last slot. +// CLIP's pooled output reads from the EOS position; losing EOS breaks +// the pooled branch silently. +let eos_id = *raw_tokens.last().unwrap_or(&pad_id); +let mut tokens = raw_tokens; +if tokens.len() > self.max_position_embeddings { + tokens.truncate(self.max_position_embeddings); + *tokens.last_mut().expect("non-empty after truncate") = eos_id; +} +let eos_position = tokens.len() - 1; +while tokens.len() < self.max_position_embeddings { + tokens.push(pad_id); +} +``` + +Three subtle points to not miss: + +1. **EOS preservation** — CLIP's pooled output pulls from the EOS-position + hidden state. 
If you just `tokens.truncate(77)` you lose EOS when the + raw length exceeds 77; the pooled branch then reads a content token's + hidden state, which changes output. The shared encoder at + `encoders/clip.rs:107` gets away with a bare `truncate(77)` because + that path doesn't compute a pooled-at-EOS output at all — it returns + only the last-layer `forward(...)` result. +2. **`eos_position` recompute** — it's fine to leave as + `tokens.len() - 1` after truncate-then-pad, but only because the pad + step happens after the truncate. If you flip the order, `eos_position` + lands in the pad region. +3. **Don't silently warn** on overlong prompts unless logging is cheap — + users may hit this with an expanded prompt that's 80 tokens. Prefer a + single `tracing::debug!` with the truncation count so the CLI doesn't + spam on every generation. + +## Verification + +Weight-free unit test (no candle weights — small synthetic model is fine): + +1. Build a `ClipWithTokenizer` against the real SDXL CLIP-L tokenizer JSON + (it's a file-based tokeniser, no weights needed to tokenise). +2. Run `encode_text_to_embedding` on a 50-token prompt and a 200-token + prompt in the same test. Assert both return `[1, 77, 768]` penultimate + and `[768]` pooled (the pooled shape is actually `[768]` after the `i` + slice at line 104 — confirm). +3. Actually this test path is weight-bearing (`clip.forward_until_encoder_layer` + needs real CLIP weights). Split into a pure tokeniser-level test that + just verifies `tokens.len() == 77` after the new truncation logic, with + `tokens[76] == eos_id` when the raw tokeniser output exceeded 77. + +Integration test: hit `http://beast:7680/api/generate` against +`sd3.5-large:q8` with a 300-char prompt after the fix ships. Expected: +200 + image bytes, no shape-mismatch error in `~/.mold/logs/server.log`. + +## Render-chain v1 context (the other active thread) + +This session just finished **render-chain v1**. 
The state: + +- Branch `feat/render-chain-v1` is pushed to origin, 12 commits ahead of + `origin/main`. PR not opened yet — the URL stub is + `https://github.com/utensils/mold/pull/new/feat/render-chain-v1`. +- Last commit on the branch: `766322e fix(cli): pass owned String to + create_engine in local chain path` — caught only when the CUDA build + ran on killswitch (Phase 3 feature-matrix check omitted `cuda`/`metal`). +- killswitch is running the new binary at `766322e` as PID 1199380 + (`./target/release/mold serve --bind 0.0.0.0 --gpus 0,1`), logs at + `~/.mold/logs/server.log`. `MOLD_HOST=http://beast:7680` reaches it. +- Local `main` is parity with `origin/main` at `1410d08`. The chain work + lives on the feature branch only. +- All four plan phases landed with tests green: `mold-ai-core` 611 pass, + `mold-ai-inference` lib 588 pass, `mold-ai-server` lib 186 pass (+5 + chain route tests), `mold-ai` 369 unit + 38 integration + 11 runpod + (+12 chain CLI tests). `cargo fmt/clippy --workspace -- -D warnings` + clean. Website `bun run fmt:check/verify/build` clean. + +This SD3 CLIP-L/G truncation bug is **NOT** related to render-chain — it +fires on the single-clip SD3 path (`mold_server::gpu_worker::process_job`). +Fixing it is an independent fix that can land on `main` directly, or +stack on top of `feat/render-chain-v1` and go out as part of the chain +PR. Your call. 
+ +## Branch / commit layout + +``` +6211182 (origin/feat/render-chain-v1~1) docs(chain): … +766322e (origin/feat/render-chain-v1, HEAD of feat branch) fix(cli): pass owned String … +1410d08 (origin/main, local main) fix(flux): surface city96-GGUF … +``` + +--- + +## The prompt + +Paste from here into a fresh Claude Code session: + +--- + +I need to fix a CLIP-L/G prompt-truncation bug in the mold repo +(`/Users/jeffreydilley/github/mold`) that currently makes every `sd3*` +model (SD 3.5 Large q8 confirmed) return HTTP 500 with +`"shape mismatch in broadcast_add, lhs: [1, N, 768], rhs: [1, 77, 768]"` +for any prompt that tokenises to more than 77 CLIP tokens. + +## Read first + +1. `tasks/sd3-clip-77-truncation-handoff.md` — full diagnosis, file:line + cites, repro curl, fix sketch, verification plan. Read end-to-end. +2. `crates/mold-inference/src/encoders/sd3_clip.rs:60-131` — the buggy + function (`ClipWithTokenizer::encode_text_to_embedding`). +3. `crates/mold-inference/src/encoders/clip.rs:100-115` — the shared + path that correctly truncates. Contrast to see the regression. + +## Status on entry + +- Branch `main` locally and on origin at `1410d08`. +- `feat/render-chain-v1` is on origin at `766322e` (12 commits ahead of + main). It's unrelated to this bug but is a parallel in-flight PR; + ignore unless you're explicitly asked to stack the fix on it. +- killswitch (BEAST, dual-3090) is running `766322e` as + `mold serve --bind 0.0.0.0 --gpus 0,1`. Reproduce against it via + `MOLD_HOST=http://beast:7680`. Log tail: + `ssh killswitch@192.168.1.67 "tail -f ~/.mold/logs/server.log"`. +- `CLAUDE.md` claims `mold-inference`/`mold-server` have + `[lib] test = false` — **stale, tests run normally**. Verified in + render-chain v1's Phase 1d–4 landings. +- Pre-existing clippy warnings unrelated to this bug (do NOT fix): + `manual_repeat_n` in `mold-core/src/download.rs:1451`, + `field_reassign_with_default` in `mold-core/src/placement_test.rs:167`. 
+ +## What you're doing + +Fix the bug per the handoff's "The fix" section. Short commit, targeted +test, verify on killswitch. One atomic commit, `fix(sd3): truncate CLIP +token sequences to 77 with EOS preserved` (or similar). + +Do NOT push unless the user asks. Do NOT stack on `feat/render-chain-v1` +unless asked — land on `main` as a standalone fix so it can merge +independently of the render-chain PR. + +Verify with: + +```bash +cargo fmt --all -- --check +cargo clippy --workspace -- -D warnings +cargo test -p mold-ai-inference --lib +cargo test -p mold-ai-core +# + a new unit test you add for the truncation helper +``` + +Then optionally rebuild on killswitch and retry the repro curl against +`http://beast:7680` to confirm the user-surfaced symptom is gone. + +## Process + +- Use `superpowers:systematic-debugging` — the root cause is already + in the handoff, but the skill will keep you honest about verifying. +- Use `superpowers:test-driven-development` — write a failing test for + the 200-token case first, then fix, then watch it go green. +- Use `superpowers:verification-before-completion` before claiming done. + +If you discover the bug is broader than the handoff claims (e.g. it +affects another encoder I missed), stop and re-scope rather than +silently widening the fix. diff --git a/web/src/App.vue b/web/src/App.vue index e3556d4c..9ae4ef75 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -1,6 +1,7 @@