Deterministic Nix-flake packaging for JAX, JAX-Fluids, and NEK #257

Closed
ludgerpaehler wants to merge 20 commits into dynamicslab:main from ludgerpaehler:nix-flake-packaging

@ludgerpaehler

Collaborator

Summary

Replaces the Firedrake-only .packaging/ Dockerfile pile with a minimal Ubuntu 22.04 + Nix base image and per-backend × per-GPU-arch Nix flake outputs. flake.lock is the reproducibility anchor — same lockfile + --system x86_64-linux ⇒ bit-identical solver environment.

  • One base image (hydrogym/base:dev) carries only Ubuntu + multi-user Nix + a HydroGym checkout. Every backend-specific dep (CUDA, MPI, Python ecosystem, NEK5000 binary) lives in flake outputs.
  • Five devshells, including jax-cuda-{hopper-blackwell,turing-ampere} and jaxfluids-cuda-{hopper-blackwell,turing-ampere}. The NEK builder skeleton ships unwired (case files come in a follow-up).
  • Adds a CI workflow that runs nix flake check on every push.
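
The per-backend × per-GPU-arch naming above can be pictured as a cross product. The following is a rough Python sketch rather than the flake's actual Nix code: the function name `devshell_names` is hypothetical, and it only generates the four CUDA shell names spelled out in the summary.

```python
from itertools import product

# Backends and GPU arch groups as named in the summary above.
BACKENDS = ("jax", "jaxfluids")
GPU_ARCHS = ("hopper-blackwell", "turing-ampere")


def devshell_names(backends=BACKENDS, archs=GPU_ARCHS):
    """Return one '<backend>-cuda-<arch>' name per backend/arch pair."""
    return [f"{backend}-cuda-{arch}" for backend, arch in product(backends, archs)]


# Print the command a user would type to enter each shell.
for name in devshell_names():
    print(f"nix develop .#{name}")
```

Adding a backend or an arch group to either tuple extends the matrix without touching the naming logic, which is presumably the appeal of a shared shell factory.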

In scope

| Component | Status |
|---|---|
| .packaging/Dockerfile.base: Ubuntu 22.04 + Nix 2.34.7 (multi-user) + tini-shimmed daemon | Built + smoke-tested |
| nix/overlays/nvhpc-26.1.nix: NVHPC SDK 26.1 (CUDA 12.9 + 13.1) FOD, autoPatchelfHook | Builds clean (15 GB extract) |
| nix/lib/{cudaTargets,mkBackendShell}.nix: arch table + per-backend shell factory | All 5 shells eval clean |
| nix/backends/jax/default.nix: pip-installed jax[cuda12]==0.10.1 over a Nix-managed venv | GPU validated end-to-end |
| nix/backends/jaxfluids/default.nix: JAXFLUIDS umbrella from pinned upstream rev | Imports + Nozzle2D/3D resolve |
| nix/backends/nek/{default,nek5000}.nix + third_party/nek5000 (v19.0) | Skeleton (no case wired) |
| .github/workflows/nix-build.yml: nix flake check on push | Green |

Out of scope (deferred follow-ups)

  • Firedrake native Nix build: Phase 4, separate PR.
  • MAIA: EasyBuild module path; no Nix coverage planned.
  • NEK case files & flake wiring: case files (SIZE, <case>.usr, <case>.par) for TCFmini_3D_Re180 live outside this repo. The builder and factory ship ready; drop a case in and follow the recipe in nix/README.md to expose it.
  • Removal of .packaging/Dockerfile.{firedrake_env,hydrogym_env,hydrogym,devpod}: Phase 5; safe to delete once consumers have migrated.

Validation evidence (this branch, nix-flake-packaging)

  • nix flake check --no-build → all 5 devShells pass.
  • nix build .#devShells.x86_64-linux.jax-cuda-hopper-blackwell.inputDerivation --no-link → completes (NVHPC extract + autoPatchelf + venv bootstrap).
  • nix develop .#jax-cuda-hopper-blackwell -c python -c "import jax; print(jax.devices())" → [CudaDevice(id=0)] (validated on an RTX 5090; the PTX and cuDNN bundled with the wheel handle device init).
  • Kolmogorov example runs 5 steps end-to-end (CPU run validated; GPU run blocked only by JAX's partial sm_120 FP64 kernel coverage on consumer Blackwell — unrelated to packaging, doesn't affect the actual H100/B100/B200 deployment targets the hopper-blackwell arch table pins).
  • nix develop .#jaxfluids-cuda-hopper-blackwell -c python -c "from hydrogym.jaxfluids import Nozzle2D, Nozzle3D" → both classes resolve; JAXFLUIDS umbrella (jaxfluids + jaxfluids_rl + jaxfluids_thirdparty) installs from the pinned git rev.

Plan-gap discoveries

The spec didn't anticipate these; each is documented in the relevant commit message:

  1. .gitignore lib/ swallowed nix/lib/ — added !nix/lib/ exception.
  2. Multi-user Nix needs a daemon-startup shim under tini (matches nixos/nix image pattern).
  3. NVHPC SDK 26.1's only published variant is cuda_multi.tar.gz (no per-CUDA-version archives), and the archive has a …/install_components/ wrapper dir.
  4. NVHPC math libs need libgcc_s.so.1 (added stdenv.cc.cc.lib to buildInputs); Nsight Compute drags optional Qt/InfiniBand libs (handled with autoPatchelfIgnoreMissingDeps = true).
  5. python.withPackages produces a read-only env, so pip needs a venv layered over it; a nix-env.pth file in the venv (refreshed on every shell entry) lets pip upgrades take precedence over Nix-provided deps.
  6. CUDA_VISIBLE_DEVICES=all is invalid CUDA syntax (driver reports "no devices"); unset means "all visible", which is what we want.
  7. Host driver libs (libcuda.so.1, libnvidia-ml.so.1) need a per-shell symlink farm, not a bulk LD_LIBRARY_PATH mount of /usr/lib/x86_64-linux-gnu; the latter shadows nix-store libs and crashes the dynamic loader with SIGFPE.
  8. JAX pin pushed to 0.10.1: pre-Blackwell wheels don't ship sm_100/sm_120 kernels.
  9. nixpkgs-24.05 builds matplotlib/h5py/scipy/pandas/flax against numpy 1.x; jax[cuda12] needs numpy 2.x. Resolved by moving every numpy-C-extension dep to pip; pure-Python deps stay in Nix.
  10. jaxfluids_rl isn't a standalone repo — it's a sub-package of tumaer/JAXFLUIDS, added after the v0.2.1 tag. Pinned to main HEAD; pip-installed via git+rev rather than the originally-planned submodule + Nix overlay (the umbrella's install_requires would clash with the cuda12 wheels).
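
Item 6 can be illustrated with a small guard. This is a minimal Python sketch, not the shell hook the PR ships: the function name `sanitized_cuda_env` is hypothetical, and the pattern used to recognise UUID-style entries is a simplification. It assumes the usual CUDA convention that the variable holds comma-separated device indices or `GPU-…` / `MIG-…` identifiers, so a value such as `all` is dropped and the driver falls back to exposing every device.

```python
import os
import re

# Simplified check: a device index, or a GPU-/MIG- prefixed identifier.
_VALID_ENTRY = re.compile(r"^(\d+|(GPU|MIG)-[0-9A-Za-z/\-]+)$")


def sanitized_cuda_env(env):
    """Return a copy of env with an invalid CUDA_VISIBLE_DEVICES removed."""
    cleaned = dict(env)
    value = cleaned.get("CUDA_VISIBLE_DEVICES")
    if value:  # an empty string deliberately hides every GPU; leave it alone
        entries = [entry.strip() for entry in value.split(",")]
        if not all(_VALID_ENTRY.match(entry) for entry in entries):
            # Unset means "all devices visible", which is the intended default.
            del cleaned["CUDA_VISIBLE_DEVICES"]
    return cleaned


# Example: inspect what the current process environment would look like.
print(sanitized_cuda_env(os.environ).get("CUDA_VISIBLE_DEVICES"))
```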

Test plan

  • CI: confirm .github/workflows/nix-build.yml passes on this PR (greenfield workflow, first run on this branch).
  • Reviewer with H100 / B100 / B200 (or downstream user): nix develop .#jax-cuda-hopper-blackwell -c python examples/jax/getting_started/1_kolmogorov/test_kolmogorov_env.py minimize_tke --num-steps 5 should print step output with non-NaN rewards. Same on Turing/Ampere via .#jax-cuda-turing-ampere.
  • Reviewer with H100: same example via the JAX-Fluids shell.
  • Quick docs read-through: README.md Nix section + nix/README.md + .packaging/README.md should each stand on their own.
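
For the "non-NaN rewards" criterion, a reviewer could apply a check along these lines. It is only a sketch: the helper name `rewards_are_nan_free` is hypothetical, and it assumes the per-step rewards have already been collected into a Python sequence, since the example script's output format is not described here.

```python
import math


def rewards_are_nan_free(rewards):
    """Return True when no reward in the sequence is NaN."""
    return not any(math.isnan(float(reward)) for reward in rewards)


# Example with made-up reward values for a 5-step run.
sample_rewards = [-0.12, -0.10, -0.09, -0.11, -0.08]
print(rewards_are_nan_free(sample_rewards))
```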

🤖 Generated with Claude Code

ludgerpaehler and others added 20 commits June 2, 2026 15:22
Captures the brainstormed design for replacing the current ad-hoc
.packaging/ Dockerfiles with one minimal Ubuntu 22.04 + Nix base image
plus per-backend × per-GPU-arch Nix flake outputs. PR scope is phases
1-3 (JAX, JAX-Fluids, NEK); Firedrake is a follow-up PR; MAIA is
deferred until a portable binary distribution exists.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bite-sized tasks for the base Dockerfile, root flake, JAX/JAX-Fluids
GPU outputs (Hopper/Blackwell + Turing/Ampere), and NEK MiniChannel
case. Each task has exact file contents, commands, expected output,
and TDD-style validation steps. Spec coverage and placeholder/type
self-review pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
control + dmsuite aren't in nixpkgs at all. flax-0.8.3 in nixpkgs-24.05
pulls in einops → jupyter → qtconsole → ipython-genutils, the last of
which is incompatible with Python 3.12. Moving control, dmsuite, flax,
chex, and optax to the pip-install step in the shellHook keeps the JAX
ecosystem versions coherent and unblocks `nix flake check` for the JAX
devshells. The JAXFluids shells still error pending Task 12's overlay.
Surfaced by the first JAX Hopper/Blackwell smoke test against a real
device (RTX 5090, sm_120). Six independent issues, all in one commit
because they were all required to make
`nix develop .#jax-cuda-hopper-blackwell` produce a CUDA-visible
Python session.

NVHPC overlay
- libgcc_s.so.1 / libstdc++.so.6 added to buildInputs via
  stdenv.cc.cc.lib; required by the cublas / cusparse / cusolver /
  cufft / cutensor math libs that autoPatchelf otherwise couldn't
  satisfy.
- autoPatchelfIgnoreMissingDeps = true; Nsight Compute's optional Qt
  / InfiniBand / NVML plugins reference libs (libqxcb, libib*,
  libnvidia-ml) that aren't on the JAX runtime path. Letting
  autoPatchelf warn keeps the build green without dropping the
  profiler entirely.

mkBackendShell
- Bootstraps a writable venv layered over the Nix env. python.withPackages
  is read-only, so pip refused --user installs and editable installs
  failed; the venv accepts both. VENV_DIR is set BEFORE cudaSetup
  because cudaSetup writes the driver-libs symlink farm inside it.
- LD_LIBRARY_PATH seeded with stdenv.cc.cc.lib + zlib so PyPI wheels
  (numpy 2.x, jaxlib) find libstdc++/libz.
- libcuda.so.1 / libnvidia-ml.so.1 are exposed via a per-shell symlink
  farm at $VENV_DIR/.nv-driver-libs rather than putting the whole
  /usr/lib/x86_64-linux-gnu on LD_LIBRARY_PATH; the latter shadows
  nix-store libs and crashes the dynamic loader with SIGFPE. The shell
  hook probes the standard NixOS / Debian-multiarch / RHEL paths.
- CUDA_VISIBLE_DEVICES is no longer defaulted to "all"; on a bare host
  that's not a valid value and the driver reports "no devices".
  Unset means "all visible" — same effect, no breakage.
- nix-env.pth file in the venv's site-packages instead of PYTHONPATH
  manipulation. PYTHONPATH is processed before site.py and would put
  the Nix wrapped-env site-packages ahead of the venv's own — meaning
  pip upgrades (numpy 1.x → 2.x for jax 0.10) couldn't override the
  Nix-pinned versions. .pth entries land AFTER venv site-packages so
  pip wins.
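
The driver-libs symlink farm described in the list above can be sketched in Python as follows. This illustrates the idea only and is not the shell code in mkBackendShell: the function name `build_driver_farm` is hypothetical, and the candidate directories are assumptions standing in for the NixOS / Debian-multiarch / RHEL probe paths.

```python
import os
import tempfile

# Only these two host driver libraries are exposed to the shell.
DRIVER_LIBS = ("libcuda.so.1", "libnvidia-ml.so.1")

# Assumed probe locations; the real hook may use different paths.
CANDIDATE_DIRS = (
    "/run/opengl-driver/lib",     # NixOS
    "/usr/lib/x86_64-linux-gnu",  # Debian multiarch
    "/usr/lib64",                 # RHEL-style
)


def build_driver_farm(farm_dir, candidate_dirs=CANDIDATE_DIRS, libs=DRIVER_LIBS):
    """Symlink each driver lib from the first directory that provides it."""
    os.makedirs(farm_dir, exist_ok=True)
    linked = {}
    for lib in libs:
        for directory in candidate_dirs:
            source = os.path.join(directory, lib)
            if os.path.exists(source):
                target = os.path.join(farm_dir, lib)
                if os.path.lexists(target):
                    os.remove(target)  # replace a stale link from an earlier entry
                os.symlink(source, target)
                linked[lib] = source
                break
    return linked


# Example: build a farm in a scratch directory and report what was found.
farm = os.path.join(tempfile.mkdtemp(), ".nv-driver-libs")
print(build_driver_farm(farm))
```

Because only the named libraries are linked, everything else in the host library directory stays off the loader path, which is the point of preferring a farm over a bulk LD_LIBRARY_PATH entry.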

JAX backend pins
- jax[cuda12]==0.10.1 (was 0.4.34). 0.4.34 ships ptxas that predates
  CUDA 12.8 / Blackwell and CUBIN-rejects on H100 sm_90's runtime in
  some configs; 0.10.1 covers sm_90 (H100) + sm_100 (B100/B200) cleanly
  and partially covers sm_120 (consumer Blackwell). NVHPC ships
  everything except cuDNN, so the non-local jax[cuda12] wheel owns the
  JAX-side CUDA stack while NVHPC stays the source of nvcc/nvfortran/
  NCCL for native-code workflows.
- Same pin in the JAX-Fluids backend; mirrors the wider HydroGym pattern.

Validated end-to-end:
- `nix build .#devShells.x86_64-linux.jax-cuda-hopper-blackwell.inputDerivation`
  completes (15 GB NVHPC extract + autoPatchelf + venv bootstrap +
  jax[cuda12] wheels).
- `nix develop .#jax-cuda-hopper-blackwell -c python -c "import jax;
  print(jax.devices())"` -> `[CudaDevice(id=0)]`.
- Kolmogorov example completes 5 reset/step iterations with correct
  TKE/reward output (verified on CPU; GPU pass blocked only by JAX's
  partial sm_120 FP64 kernel coverage — fine on the actual H100 /
  B100 / B200 deployment targets that hopper-blackwell pins to).
jaxfluids_rl turns out to be a sub-package of the tumaer/JAXFLUIDS
umbrella (src/jaxfluids_rl/), not a standalone repo. setup.py calls
find_packages(where="src") so the whole tree ships as one install,
giving us `jaxfluids`, `jaxfluids_rl`, and `jaxfluids_thirdparty` —
all three of which HydroGym imports from. v0.2.1 (the latest tag)
predates jaxfluids_rl, so we pin to main HEAD instead:
9bb1e6c85371d445cbdaaaf9e5699495cff4b371.

The plan's intent (submodule + buildPythonPackage overlay) doesn't
fit cleanly: JAXFLUIDS's install_requires drags jax/jaxlib/flax/optax,
which conflict with the cuda12 wheels we already pip-install. Using
`pip install --no-deps git+...` instead keeps version coherence with
the rest of the JAX ecosystem we pip-install (jax 0.10.1, chex, flax,
optax, control, dmsuite). No submodule needed; no overlay file
needed.
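
As a rough illustration of the `pip install --no-deps git+...` approach, the sketch below assembles the pip invocation for the pinned revision. The helper name `pip_install_args` is hypothetical, and the repository URL is an assumption (that tumaer/JAXFLUIDS is hosted on github.com); only the revision hash and the `--no-deps` flag come from the text above.

```python
import sys

# Revision pinned in the commit message above.
JAXFLUIDS_REV = "9bb1e6c85371d445cbdaaaf9e5699495cff4b371"
# Assumption: the umbrella repo is reachable at this GitHub URL.
JAXFLUIDS_URL = "https://github.com/tumaer/JAXFLUIDS.git"


def pip_install_args(url=JAXFLUIDS_URL, rev=JAXFLUIDS_REV):
    """Build the argv for a dependency-free install of one pinned git rev."""
    # --no-deps stops pip from resolving install_requires, so the jax/jaxlib
    # versions already present in the venv are left untouched.
    return [sys.executable, "-m", "pip", "install", "--no-deps", f"git+{url}@{rev}"]


# Print the command instead of running it; pass the list to subprocess.run to execute.
print(" ".join(pip_install_args()))
```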

JAX-Fluids backend also restructured its pythonDeps split:
- Pure-Python deps stay in Nix (gymnasium, huggingface-hub, gitpython,
  omegaconf, toml) — no ABI risk.
- numpy-C-extension deps move to pip (scipy, pandas, h5py, matplotlib,
  pyvista) because nixpkgs-24.05's builds target numpy 1.x and crash
  against the numpy 2.x that jax[cuda12]==0.10.1 demands.

mkBackendShell venv .pth file is now refreshed on every shell entry,
not just on venv creation. pythonDeps changes produce a new pythonEnv
store path; a stale .pth would point at a no-longer-existent old env.
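
The .pth behaviour can be reproduced with the standard library alone. The sketch below is an illustration in Python, not the shellHook itself: the helper name `refresh_nix_pth` and the directory names are hypothetical. It rewrites the .pth file on each "shell entry" and then uses `site.addsitedir` to show that the path listed in the .pth lands after the site directory that contains it.

```python
import os
import site
import sys
import tempfile


def refresh_nix_pth(venv_site, nix_site, name="nix-env.pth"):
    """Rewrite the .pth file so it always points at the current Nix env."""
    pth_path = os.path.join(venv_site, name)
    with open(pth_path, "w", encoding="utf-8") as handle:
        handle.write(nix_site + "\n")
    return pth_path


# Scratch layout standing in for the venv and two generations of the Nix env.
root = os.path.realpath(tempfile.mkdtemp())
venv_site = os.path.join(root, "venv", "site-packages")
old_env = os.path.join(root, "store", "old-python-env")  # never created: stale
new_env = os.path.join(root, "store", "new-python-env")
os.makedirs(venv_site)
os.makedirs(new_env)

refresh_nix_pth(venv_site, old_env)  # what a venv-creation-only write would leave
refresh_nix_pth(venv_site, new_env)  # refreshed on the next shell entry

# site.addsitedir appends the directory first, then the entries from its .pth files.
site.addsitedir(venv_site)
venv_pos = sys.path.index(venv_site)
nix_pos = sys.path.index(new_env)
print(venv_pos < nix_pos)
```

Paths listed in a .pth file are skipped when they do not exist, so a stale entry would silently drop the Nix-provided packages rather than raise an error; refreshing on every entry avoids that.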

Validated end-to-end on the host:
  $ nix flake check --no-build  # all 5 devShells pass
  $ nix develop .#jaxfluids-cuda-hopper-blackwell -c python -c '
        from hydrogym.jaxfluids import Nozzle2D, Nozzle3D
        print(Nozzle2D.__name__, Nozzle3D.__name__)'
  Nozzle2D Nozzle3D
Pin to the v19.0 tag (sha 5afb6179daf69c5f1d0e5a6eab49d6f07e0c622f),
matching the existing HydroGym NEK examples which target the v19
ABI (NEK5000_v19 sentinel in hydrogym/data_manager.py:134).

Submodule only; the per-case builder (nix/backends/nek/{nek5000,
default}.nix) lands in the next commit. Case-specific shells are
deferred to a follow-up PR — the case files (SIZE, *.usr) for any
specific run live outside this repo.
Per-case nek5000 binary derivation + mkNekShell factory under
nix/backends/nek/. Mirrors the layout used by the JAX backends.
The flake.nix root does NOT instantiate any case in this PR — case
files (SIZE, *.usr, *.par) for the canonical TCFmini_3D_Re180
environment live outside this repo. Drop them under
nix/backends/nek/cases/<case>/ and follow the recipe in nix/README.md
to expose `.#nek-cpu-<case>`.

Naming convention: case directories use HuggingFace environment
names (`TCFmini_3D_Re180`, not the older `MiniChannel` alias) so
they pair 1:1 with `dynamicslab/HydroGym-environments` directories.

Python dep split mirrors the JAX-Fluids backend:
- Pure-Python in Nix: gymnasium, huggingface-hub, omegaconf, toml,
  pettingzoo.
- mpi4py from Nix — MUST share pkgs.mpich with the nek5000 binary.
- numpy-C-extension in pip: numpy, scipy, pandas (nixpkgs-24.05
  builds against numpy 1.x and break against pip-resolved numpy
  2.x; same pattern we already use for JAX-Fluids).
- Pure-pip: pymech (not in nixpkgs), stable-baselines3, supersuit,
  tensorboard, control, dmsuite.

`nix flake check` still passes — the new files are inert until a
case is wired in. CI workflow does not gain a nek build step yet
(no case to build against).
@ludgerpaehler

Collaborator Author

Closing — opened against the wrong repo; correct PR will be opened on ludgerpaehler/hydrogym.
