Deterministic Nix-flake packaging for JAX, JAX-Fluids, and NEK#257
Closed
ludgerpaehler wants to merge 20 commits into
Closed
Deterministic Nix-flake packaging for JAX, JAX-Fluids, and NEK#257ludgerpaehler wants to merge 20 commits into
ludgerpaehler wants to merge 20 commits into
Conversation
Captures the brainstormed design for replacing the current ad-hoc .packaging/ Dockerfiles with one minimal Ubuntu 22.04 + Nix base image plus per-backend × per-GPU-arch Nix flake outputs. PR scope is phases 1-3 (JAX, JAX-Fluids, NEK); Firedrake is a follow-up PR; MAIA is deferred until a portable binary distribution exists. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bite-sized tasks for the base Dockerfile, root flake, JAX/JAX-Fluids GPU outputs (Hopper/Blackwell + Turing/Ampere), and NEK MiniChannel case. Each task has exact file contents, commands, expected output, and TDD-style validation steps. Spec coverage and placeholder/type self-review pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
control + dmsuite aren't in nixpkgs at all. flax-0.8.3 in nixpkgs-24.05 pulls in einops → jupyter → qtconsole → ipython-genutils, the last of which is incompatible with Python 3.12. Moving control, dmsuite, flax, chex, optax to the pip-install step in the shellHook — keeps the JAX ecosystem versions coherent and unblocks `nix flake check` for the JAX devshells. The JAXFluids shells still error pending Task 12's overlay.
Surfaced by the first JAX Hopper/Blackwell smoke test against a real device (RTX 5090, sm_120). Six independent issues, all in one commit because they were all required to make `nix develop .#jax-cuda-hopper- blackwell` produce a CUDA-visible Python session. NVHPC overlay - libgcc_s.so.1 / libstdc++.so.6 added to buildInputs via stdenv.cc.cc.lib; required by the cublas / cusparse / cusolver / cufft / cutensor math libs that autoPatchelf otherwise couldn't satisfy. - autoPatchelfIgnoreMissingDeps = true; Nsight Compute's optional Qt / InfiniBand / NVML plugins reference libs (libqxcb, libib*, libnvidia-ml) that aren't on the JAX runtime path. Letting autoPatchelf warn keeps the build green without dropping the profiler entirely. mkBackendShell - Bootstraps a writable venv layered over the Nix env. python.withPackages is read-only, so pip refused --user installs and editable installs failed; the venv accepts both. VENV_DIR is set BEFORE cudaSetup because cudaSetup writes the driver-libs symlink farm inside it. - LD_LIBRARY_PATH seeded with stdenv.cc.cc.lib + zlib so PyPI wheels (numpy 2.x, jaxlib) find libstdc++/libz. - libcuda.so.1 / libnvidia-ml.so.1 are exposed via a per-shell symlink farm at $VENV_DIR/.nv-driver-libs rather than putting the whole /usr/lib/x86_64-linux-gnu on LD_LIBRARY_PATH — the latter shadows nix-store libs and SIGFPE's the dynamic loader. Probes the standard NixOS / Debian-multiarch / RHEL paths. - CUDA_VISIBLE_DEVICES is no longer defaulted to "all"; on a bare host that's not a valid value and the driver reports "no devices". Unset means "all visible" — same effect, no breakage. - nix-env.pth file in the venv's site-packages instead of PYTHONPATH manipulation. PYTHONPATH is processed before site.py and would put the Nix wrapped-env site-packages ahead of the venv's own — meaning pip upgrades (numpy 1.x → 2.x for jax 0.10) couldn't override the Nix-pinned versions. .pth entries land AFTER venv site-packages so pip wins. JAX backend pins - jax[cuda12]==0.10.1 (was 0.4.34). 0.4.34 ships ptxas that predates CUDA 12.8 / Blackwell and CUBIN-rejects on H100 sm_90's runtime in some configs; 0.10.1 covers sm_90 (H100) + sm_100 (B100/B200) cleanly and partially covers sm_120 (consumer Blackwell). NVHPC ships everything except cuDNN, so the non-local jax[cuda12] wheel owns the JAX-side CUDA stack while NVHPC stays the source of nvcc/nvfortran/ NCCL for native-code workflows. - Same pin in the JAX-Fluids backend; mirrors the wider HydroGym pattern. Validated end-to-end: - `nix build .#devShells.x86_64-linux.jax-cuda-hopper-blackwell.inputDerivation` completes (15 GB NVHPC extract + autoPatchelf + venv bootstrap + jax[cuda12] wheels). - `nix develop .#jax-cuda-hopper-blackwell -c python -c "import jax; print(jax.devices())"` -> `[CudaDevice(id=0)]`. - Kolmogorov example completes 5 reset/step iterations with correct TKE/reward output (verified on CPU; GPU pass blocked only by JAX's partial sm_120 FP64 kernel coverage — fine on the actual H100 / B100 / B200 deployment targets that hopper-blackwell pins to).
jaxfluids_rl turns out to be a sub-package of the tumaer/JAXFLUIDS
umbrella (src/jaxfluids_rl/), not a standalone repo. setup.py calls
find_packages(where="src") so the whole tree ships as one install,
giving us `jaxfluids`, `jaxfluids_rl`, and `jaxfluids_thirdparty` —
all three of which HydroGym imports from. v0.2.1 (the latest tag)
predates jaxfluids_rl, so we pin to main HEAD instead:
9bb1e6c85371d445cbdaaaf9e5699495cff4b371.
The plan's intent (submodule + buildPythonPackage overlay) doesn't
fit cleanly: JAXFLUIDS's install_requires drags jax/jaxlib/flax/optax,
which conflict with the cuda12 wheels we already pip-install. Using
`pip install --no-deps git+...` instead keeps version coherence with
the rest of the JAX ecosystem we pip-install (jax 0.10.1, chex, flax,
optax, control, dmsuite). No submodule needed; no overlay file
needed.
JAX-Fluids backend also restructured its pythonDeps split:
- Pure-Python deps stay in Nix (gymnasium, huggingface-hub, gitpython,
omegaconf, toml) — no ABI risk.
- numpy-C-extension deps move to pip (scipy, pandas, h5py, matplotlib,
pyvista) because nixpkgs-24.05's builds target numpy 1.x and crash
against the numpy 2.x that jax[cuda12]==0.10.1 demands.
mkBackendShell venv .pth file is now refreshed on every shell entry,
not just on venv creation. pythonDeps changes produce a new pythonEnv
store path; a stale .pth would point at a no-longer-existent old env.
Validated end-to-end on the host:
$ nix flake check --no-build # all 5 devShells pass
$ nix develop .#jaxfluids-cuda-hopper-blackwell -c python -c '
from hydrogym.jaxfluids import Nozzle2D, Nozzle3D
print(Nozzle2D.__name__, Nozzle3D.__name__)'
Nozzle2D Nozzle3D
Pin to the v19.0 tag (sha 5afb6179daf69c5f1d0e5a6eab49d6f07e0c622f),
matching the existing HydroGym NEK examples which target the v19
ABI (NEK5000_v19 sentinel in hydrogym/data_manager.py:134).
Submodule only; the per-case builder (nix/backends/nek/{nek5000,
default}.nix) lands in the next commit. Case-specific shells are
deferred to a follow-up PR — the case files (SIZE, *.usr) for any
specific run live outside this repo.
Per-case nek5000 binary derivation + mkNekShell factory under nix/backends/nek/. Mirrors the layout used by the JAX backends. The flake.nix root does NOT instantiate any case in this PR — case files (SIZE, *.usr, *.par) for the canonical TCFmini_3D_Re180 environment live outside this repo. Drop them under nix/backends/nek/cases/<case>/ and follow the recipe in nix/README.md to expose `.#nek-cpu-<case>`. Naming convention: case directories use HuggingFace environment names (`TCFmini_3D_Re180`, not the older `MiniChannel` alias) so they pair 1:1 with `dynamicslab/HydroGym-environments` directories. Python dep split mirrors the JAX-Fluids backend: - Pure-Python in Nix: gymnasium, huggingface-hub, omegaconf, toml, pettingzoo. - mpi4py from Nix — MUST share pkgs.mpich with the nek5000 binary. - numpy-C-extension in pip: numpy, scipy, pandas (nixpkgs-24.05 builds against numpy 1.x and break against pip-resolved numpy 2.x; same pattern we already use for JAX-Fluids). - Pure-pip: pymech (not in nixpkgs), stable-baselines3, supersuit, tensorboard, control, dmsuite. `nix flake check` still passes — the new files are inert until a case is wired in. CI workflow does not gain a nek build step yet (no case to build against).
Collaborator
Author
|
Closing — opened against the wrong repo; correct PR will be opened on ludgerpaehler/hydrogym. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Replaces the Firedrake-only
.packaging/Dockerfile pile with a minimal Ubuntu 22.04 + Nix base image and per-backend × per-GPU-arch Nix flake outputs.flake.lockis the reproducibility anchor — same lockfile +--system x86_64-linux⇒ bit-identical solver environment.hydrogym/base:dev) carries only Ubuntu + multi-user Nix + a HydroGym checkout. Every backend-specific dep (CUDA, MPI, Python ecosystem, NEK5000 binary) lives in flake outputs.jax-cuda-{hopper-blackwell,turing-ampere},jaxfluids-cuda-{hopper-blackwell,turing-ampere}. The NEK builder skeleton ships unwired (case files come in a follow-up).nix flake checkon every push.In scope
.packaging/Dockerfile.base— Ubuntu 22.04 + Nix 2.34.7 (multi-user) + tini-shimmed daemonnix/overlays/nvhpc-26.1.nix— NVHPC SDK 26.1 (CUDA 12.9 + 13.1) FOD, autoPatchelfHooknix/lib/{cudaTargets,mkBackendShell}.nix— arch table + per-backend shell factorynix/backends/jax/default.nix— pip-installedjax[cuda12]==0.10.1over a Nix-managed venvnix/backends/jaxfluids/default.nix— JAXFLUIDS umbrella from pinned upstream revnix/backends/nek/{default,nek5000}.nix+third_party/nek5000(v19.0).github/workflows/nix-build.yml—nix flake checkon pushOut of scope (deferred follow-ups)
<case>.usr,<case>.par) for TCFmini_3D_Re180 live outside this repo. The builder & factory ship ready; drop a case in and follownix/README.mdrecipe to expose it..packaging/Dockerfile.{firedrake_env,hydrogym_env,hydrogym,devpod}— phase 5; safe to delete once consumers have migrated.Validation evidence (this branch on
nix-flake-packaging)nix flake check --no-build→ all 5 devShells pass.nix build .#devShells.x86_64-linux.jax-cuda-hopper-blackwell.inputDerivation --no-link→ completes (NVHPC extract + autoPatchelf + venv bootstrap).nix develop .#jax-cuda-hopper-blackwell -c python -c "import jax; print(jax.devices())"→[CudaDevice(id=0)](validated on an RTX 5090; PTX/cuDNN-bundled-with-wheel handles device init).hopper-blackwellarch table pins).nix develop .#jaxfluids-cuda-hopper-blackwell -c python -c "from hydrogym.jaxfluids import Nozzle2D, Nozzle3D"→ both classes resolve; JAXFLUIDS umbrella (jaxfluids + jaxfluids_rl + jaxfluids_thirdparty) installs from the pinned git rev.Plan-gap discoveries
The spec didn't anticipate these; each is documented in the relevant commit message:
.gitignorelib/swallowednix/lib/— added!nix/lib/exception.nixos/niximage pattern).cuda_multi.tar.gz(no per-CUDA-version), and the archive has a…/install_components/wrapper dir.libgcc_s.so.1(addedstdenv.cc.cc.libto buildInputs); Nsight Compute drags optional Qt/InfiniBand libs (handled withautoPatchelfIgnoreMissingDeps = true).python.withPackagesproduces a read-only env — pip needs a venv layered over it; venvnix-env.pthfile (refreshed every shell entry) lets pip-upgrades take precedence over Nix-provided deps.CUDA_VISIBLE_DEVICES=allis invalid CUDA syntax (driver reports "no devices"); unset means "all visible", which is what we want.libcuda.so.1,libnvidia-ml.so.1) need a per-shell symlink farm, not bulkLD_LIBRARY_PATHmount of/usr/lib/x86_64-linux-gnu— the latter shadows nix-store libs and SIGFPEs the dynamic loader.jaxfluids_rlisn't a standalone repo — it's a sub-package oftumaer/JAXFLUIDS, added after the v0.2.1 tag. Pinned to main HEAD; pip-installed viagit+revrather than the originally-planned submodule + Nix overlay (the umbrella'sinstall_requireswould clash with the cuda12 wheels).Test plan
.github/workflows/nix-build.ymlpasses on this PR (greenfield workflow, first run on this branch).nix develop .#jax-cuda-hopper-blackwell -c python examples/jax/getting_started/1_kolmogorov/test_kolmogorov_env.py minimize_tke --num-steps 5should print step output with non-NaN rewards. Same on Turing/Ampere via.#jax-cuda-turing-ampere.README.mdNix section +nix/README.md+.packaging/README.mdshould each stand on their own.🤖 Generated with Claude Code