From 5b4989f913e3376a056d9121233cffa3bb7b0bab Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 10:34:09 -0500 Subject: [PATCH 1/8] Temporarily remove demo notebook --- .../demo_diffsky_recompute_from_mock.ipynb | 0 docs/source/demos.rst | 7 +------ 2 files changed, 1 insertion(+), 6 deletions(-) rename docs/{source => notebooks}/demo_diffsky_recompute_from_mock.ipynb (100%) diff --git a/docs/source/demo_diffsky_recompute_from_mock.ipynb b/docs/notebooks/demo_diffsky_recompute_from_mock.ipynb similarity index 100% rename from docs/source/demo_diffsky_recompute_from_mock.ipynb rename to docs/notebooks/demo_diffsky_recompute_from_mock.ipynb diff --git a/docs/source/demos.rst b/docs/source/demos.rst index e657f6c0..d73a42cf 100644 --- a/docs/source/demos.rst +++ b/docs/source/demos.rst @@ -3,12 +3,7 @@ Diffsky Code Demos Reading and analyzing Diffsky mocks ------------------------------------ -.. toctree:: - :maxdepth: 1 - :caption: Notebooks: - - demo_diffsky_recompute_from_mock.ipynb - +Docs coming soon! Generating Diffsky galaxy samples --------------------------------- From f151e1a2e8bd15ea5cded440402e0ef5d68bfce6 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:15:48 -0500 Subject: [PATCH 2/8] Working docs' --- docs/source/background.rst | 12 + docs/source/demos.rst | 6 +- docs/source/double_gauss_demo.ipynb | 327 ++++++++++++++++++++++ docs/source/double_gauss_page.rst | 13 + docs/source/double_gaussian.py | 90 +++++++ docs/source/double_gaussian_holder.rst | 6 + docs/source/index.rst | 17 +- docs/source/single_gaussian.py | 90 +++++++ docs/source/single_gaussian_holder.rst | 6 + docs/source/soft_hist_page.rst | 12 + docs/source/softhist_demo.ipynb | 357 +++++++++++++++++++++++++ 11 files changed, 927 insertions(+), 9 deletions(-) create mode 100644 docs/source/background.rst create mode 100644 docs/source/double_gauss_demo.ipynb create mode 100644 docs/source/double_gauss_page.rst create mode 100644 docs/source/double_gaussian.py create mode 100644 docs/source/double_gaussian_holder.rst create mode 100644 docs/source/single_gaussian.py create mode 100644 docs/source/single_gaussian_holder.rst create mode 100644 docs/source/soft_hist_page.rst create mode 100644 docs/source/softhist_demo.ipynb diff --git a/docs/source/background.rst b/docs/source/background.rst new file mode 100644 index 00000000..1c3e700b --- /dev/null +++ b/docs/source/background.rst @@ -0,0 +1,12 @@ +Background on differentiable methods +==================================== + +The tutorials in this section are based on toy models such as +single and double Gaussians, and some supplementary source code +that is not necessarily part of Diffsky. + +.. toctree:: + :maxdepth: 1 + + soft_hist_page.rst + double_gauss_page.rst diff --git a/docs/source/demos.rst b/docs/source/demos.rst index d73a42cf..66062b98 100644 --- a/docs/source/demos.rst +++ b/docs/source/demos.rst @@ -1,10 +1,14 @@ Diffsky Code Demos ================== +For a pedagogical introduction to the differentiable / probabilistic +methods used in Diffsky (soft histograms, PDF-weighted histograms, etc.), +see the :doc:`background` section. + Reading and analyzing Diffsky mocks ------------------------------------ Docs coming soon! Generating Diffsky galaxy samples --------------------------------- -Docs coming soon! \ No newline at end of file +Docs coming soon! diff --git a/docs/source/double_gauss_demo.ipynb b/docs/source/double_gauss_demo.ipynb new file mode 100644 index 00000000..e40a3dfc --- /dev/null +++ b/docs/source/double_gauss_demo.ipynb @@ -0,0 +1,327 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "494ee05f-bd9b-45eb-9816-1bacc370ce7c", + "metadata": {}, + "source": [ + "# Fitting a double Gaussian with soft histograms\n", + "\n", + "This notebook shows how to implement a double Gaussian model in JAX, and demonstrates how to optimize the parameters of the model by fitting to soft histograms with gradient descent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e0dcdd2-0f05-4de0-90bf-a23d49169a4e", + "metadata": {}, + "outputs": [], + "source": [ + "from jax import random as jran\n", + "ran_key = jran.key(0)" + ] + }, + { + "cell_type": "markdown", + "id": "bc19e244-4499-4b6c-a4a3-19723b21986c", + "metadata": {}, + "source": [ + "## Stochastic Monte Carlo predictions\n", + "\n", + "The `mc_double_gaussian` function generates a sample of 1d data by standard Monte Carlo methods:\n", + "1. Draw $N$ points from the first Gaussian, $\\{\\mu_0, \\sigma_0\\}$\n", + "2. Draw $N$ points from the second Gaussian, $\\{\\mu_1, \\sigma_1\\}$\n", + "3. Draw $N$ uniform random numbers, $u$\n", + "4. If $f$ is the model parameter controlling the relative height of the two Gaussians, then for points with $u`__ kernels used to predict the SEDs -of a population of galaxies +Diffsky is a python library providing +`JAX `__ kernels used to predict the SEDs +of a population of galaxies with `diffmah `__, -`diffstar `__, -and `dsps `__. +`diffstar `__, +and `dsps `__. Diffsky is open-source code that is publicly available on -`GitHub `__. -These docs show you how to use diffsky to predict the SEDs and photometry -of a population of galaxies and their co-evolving dark matter halos. +`GitHub `__. +These docs show you how to use diffsky to predict the SEDs and photometry +of a population of galaxies and their co-evolving dark matter halos. .. toctree:: :maxdepth: 1 @@ -22,6 +22,7 @@ of a population of galaxies and their co-evolving dark matter halos. installation.rst demos.rst + background.rst .. toctree:: :maxdepth: 1 diff --git a/docs/source/single_gaussian.py b/docs/source/single_gaussian.py new file mode 100644 index 00000000..0cecdeb3 --- /dev/null +++ b/docs/source/single_gaussian.py @@ -0,0 +1,90 @@ +"""Demo code to fit a 1d Gaussian model with soft histograms and jax.grad""" + +from jax import numpy as jnp +from jax import jit as jjit +from jax import random as jran +from collections import namedtuple +from diffsky.signdhist_lomem import nnsig_ndhist +from jax import value_and_grad + +GParams = namedtuple("GParams", ("mu", "sig")) +DEFAULT_PARAMS = GParams(mu=-1.0, sig=1.0) + +NPTS = 20_000 + + +@jjit +def mc_single_gaussian(params, ran_key): + """Draw a Monte Carlo realization of a Gaussian""" + xdata = jran.normal(ran_key, shape=(NPTS,)) * params.sig + params.mu + return xdata + + +@jjit +def mc_predict_hard_edged_xhist(params, xbins, ran_key): + """Predict histogram counts by applying jnp.histogram to + a Monte Carlo realization of a Gaussian""" + xdata = mc_single_gaussian(params, ran_key) + xhist, __ = jnp.histogram(xdata, bins=xbins) + return xhist + + +@jjit +def mc_predict_soft_xhist(params, xbins, ran_key): + """Predict histogram counts by applying a soft histogram to + a Monte Carlo realization of a Gaussian""" + xdata = mc_single_gaussian(params, ran_key) + n = xdata.shape[0] + xdata = xdata.reshape((n, 1)) + xhist = soft_xhist(xdata, xbins) + return xhist + + +@jjit +def soft_xhist(xdata, xbins): + """Soft histogram function + This is a wrapper around diffsky.nnsig_ndhist for 1d data""" + nbins = xbins.shape[0] + xbins_lo = xbins[:-1].reshape((nbins - 1, 1)) + xbins_hi = xbins[1:].reshape((nbins - 1, 1)) + dx = jnp.diff(xbins).mean() + ndsig = jnp.zeros_like(xbins_lo) + dx / 2 + xdata = xdata.reshape((-1, 1)) + xhist = nnsig_ndhist(xdata, ndsig, xbins_lo, xbins_hi) + return xhist + + +@jjit +def _mae_kern(x, y): + """Mean absolute error""" + abs_diff = jnp.abs(y - x) + return jnp.mean(abs_diff) + + +@jjit +def hard_edged_xhist_loss(params, loss_data): + """Loss function based on a histogram with hard-edged bins""" + xhist_target, xbins, ran_key = loss_data + xhist_pred = mc_predict_hard_edged_xhist(params, xbins, ran_key) + loss = _mae_kern(xhist_pred, xhist_target) + return loss + + +@jjit +def soft_xhist_loss(params, loss_data): + """Loss function based on a soft histogram""" + xhist_target, xbins, ran_key = loss_data + xhist_pred = mc_predict_soft_xhist(params, xbins, ran_key) + loss = _mae_kern(xhist_pred, xhist_target) + return loss + + +@jjit +def param_update(params, grads, learning_rate): + """Update namedtuple params by taking a small step down the gradient""" + new_params = params._make(jnp.array(params) - jnp.array(grads) * learning_rate) + return new_params + + +hard_edged_xhist_loss_and_grad = jjit(value_and_grad(hard_edged_xhist_loss, argnums=0)) +soft_xhist_loss_and_grad = jjit(value_and_grad(soft_xhist_loss, argnums=0)) diff --git a/docs/source/single_gaussian_holder.rst b/docs/source/single_gaussian_holder.rst new file mode 100644 index 00000000..f9e8b505 --- /dev/null +++ b/docs/source/single_gaussian_holder.rst @@ -0,0 +1,6 @@ +Supplementary source code for soft histograms +============================================= + +.. literalinclude:: single_gaussian.py + :language: python + :linenos: diff --git a/docs/source/soft_hist_page.rst b/docs/source/soft_hist_page.rst new file mode 100644 index 00000000..eca8073f --- /dev/null +++ b/docs/source/soft_hist_page.rst @@ -0,0 +1,12 @@ +Fitting probability distributions with soft histograms +====================================================== + +This section demonstrates differentiable techniques for fitting a probability distribution, +using a simple unimodal Gaussian to introduce the need for soft histograms. +We recommend reading the tutorial notebook alongside the supplementary source code linked below. + +.. toctree:: + :maxdepth: 1 + + single_gaussian_holder.rst + softhist_demo.ipynb diff --git a/docs/source/softhist_demo.ipynb b/docs/source/softhist_demo.ipynb new file mode 100644 index 00000000..de58f7ed --- /dev/null +++ b/docs/source/softhist_demo.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "02646743-5e32-47c9-b942-99d8e303185a", + "metadata": {}, + "source": [ + "# Introduction to soft histograms\n", + "\n", + "This notebook demonstrates the basic principles of using a soft histogram when fitting a model for a probability distribution. We'll use a one-dimensional Gaussian distribution as our model, and the first thing we'll do is try to fit the model parameters using a loss function based on a standard histogram. As discussed below, computing standard histograms with hard-edged bins is a non-differentiable calculation, and so our first attempt to fit the model will fail. We'll then see how using a soft histogram solves the problem.\n", + "\n", + "This demo is based on the `single_gaussian.py` module, which implements a Gaussian model. Let's start out by using the module to generate some target data with the `mc_single_gaussian` function, and visually inspect the distributions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e69d23e5-4a4c-455c-9d5a-ef793b81d0c2", + "metadata": {}, + "outputs": [], + "source": [ + "from jax import random as jran\n", + "ran_key = jran.key(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b84cf21-5d5a-425c-bbe1-83938f79f5d3", + "metadata": {}, + "outputs": [], + "source": [ + "import single_gaussian as sg\n", + "\n", + "NBINS = 50\n", + "XBOUNDS = (-10.0, 10.0)\n", + "XBINS = np.linspace(*XBOUNDS, NBINS)[:-1]\n", + "\n", + "\n", + "PARAMS_INIT = sg.DEFAULT_PARAMS._replace()\n", + "ran_key, init_key = jran.split(ran_key, 2)\n", + "XDATA_INIT = sg.mc_single_gaussian(PARAMS_INIT, init_key)\n", + "\n", + "PARAMS_TARGET = sg.DEFAULT_PARAMS._replace(mu=-2.0, sig=2.0)\n", + "ran_key, target_key = jran.split(ran_key, 2)\n", + "XDATA_TARGET = sg.mc_single_gaussian(PARAMS_TARGET, target_key)\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "__=ax.hist(XDATA_TARGET, bins=XBINS, \n", + " alpha=0.7, label=r'target population')\n", + "__=ax.hist(XDATA_INIT, bins=XBINS, \n", + " alpha=0.7, label=r'initial population')\n", + "leg = ax.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "882f0f64-fd68-4dcb-96b8-dd18c1e4301b", + "metadata": {}, + "source": [ + "### Predicting a histogram from a population\n", + "\n", + "The `mc_predict_hard_edged_xhist` function is just a wrapper around `mc_single_gaussian` that first generates target data, and then uses `jnp.histogram` to bin the data into a predicted histogram." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7428dde5-e537-4e31-a73b-6a58fc54c0a9", + "metadata": {}, + "outputs": [], + "source": [ + "ran_key, init_key = jran.split(ran_key, 2)\n", + "XHIST_INIT = sg.mc_predict_hard_edged_xhist(\n", + " PARAMS_INIT, XBINS, init_key)\n", + "\n", + "ran_key, target_key = jran.split(ran_key, 2)\n", + "XHIST_TARGET = sg.mc_predict_hard_edged_xhist(\n", + " PARAMS_TARGET, XBINS, target_key)" + ] + }, + { + "cell_type": "markdown", + "id": "181e6fde-bb21-4b9d-955c-820d8dc7c533", + "metadata": {}, + "source": [ + "### Running gradient descent\n", + "\n", + "The `hard_edged_xhist_loss_and_grad` function is the loss function we will try to minimize with gradient descent. This loss function uses `mc_predict_hard_edged_xhist` to predict a histogram, and then computes the mean absolute error between the predicted and target histogram. We use `jax.value_and_grad` to additionally return the gradients of the loss in addition to the value.\n", + "\n", + "The next cell takes 100 steps of gradient descent, where at each step, we take a tiny step down the gradient to update the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95275efa-dd29-4f05-bff5-33a16374ba63", + "metadata": {}, + "outputs": [], + "source": [ + "ran_key, loss_key = jran.split(ran_key, 2)\n", + "loss_data = XHIST_TARGET, XBINS, loss_key\n", + "\n", + "learn_rate = 0.0005\n", + "\n", + "nsteps = 100\n", + "loss_collector = []\n", + "p_best = PARAMS_INIT._replace()\n", + "for istep in range(nsteps):\n", + " loss, grads = sg.hard_edged_xhist_loss_and_grad(\n", + " p_best, loss_data)\n", + " p_best = sg.param_update(\n", + " p_best, grads, learn_rate)\n", + " loss_collector.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "952233c7-fd78-4d42-a6e1-fa6eb65818ef", + "metadata": {}, + "source": [ + "### Inspect the results\n", + "\n", + "We'll now inspect the results by plotting the loss curve, and comparing the best-fit histogram to the target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ce1adc0-c970-4b89-9d40-e81290138011", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1)\n", + "xlabel = ax.set_xlabel('step')\n", + "ylabel = ax.set_ylabel('log10 loss')\n", + "__=ax.plot(np.log10(loss_collector))\n", + "\n", + "ran_key, pred_key = jran.split(ran_key, 2)\n", + "xhist_best = sg.mc_predict_hard_edged_xhist(\n", + " p_best, XBINS, pred_key)\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "__=ax.plot(XBINS[1:], XHIST_TARGET, \n", + " label='target')\n", + "__=ax.plot(XBINS[1:], XHIST_INIT, '--', \n", + " label='initial guess')\n", + "__=ax.plot(XBINS[1:], xhist_best, ':', \n", + " label='best fit MC method')\n", + "leg = ax.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "416d9d9b-9d41-4547-8539-979609f03548", + "metadata": {}, + "source": [ + "### That didn't work at all - what happened?\n", + "\n", + "It looks like our best-fit histogram didn't move from the initial histogram. Let's see how the best-fit points compare to the target and initial points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "077abebd-b557-4215-b722-8895795afbf9", + "metadata": {}, + "outputs": [], + "source": [ + "gen = zip(p_best._fields, p_best, PARAMS_INIT, PARAMS_TARGET)\n", + "for key, val_best, val_init, val_target in gen:\n", + " print(f\"Init {key} = {val_init:.2f}\")\n", + " print(f\"Best {key} = {val_best:.2f}\")\n", + " print(f\"True {key} = {val_target:.2f}\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "1e47b5bc-9285-4c8f-a791-97c050ab8f79", + "metadata": {}, + "source": [ + "### Hmmm, the parameters didn't move at all...\n", + "\n", + "Let's inspect the gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ec7fa72-6280-4de6-ae60-967cef8617c4", + "metadata": {}, + "outputs": [], + "source": [ + "loss_best_mc, grads = sg.hard_edged_xhist_loss_and_grad(\n", + " p_best, loss_data)\n", + "print(grads)" + ] + }, + { + "cell_type": "markdown", + "id": "d6fca23e-1d12-4af0-aeb0-a6698a7c4f44", + "metadata": {}, + "source": [ + "### All the parameters have zero gradients!\n", + "\n", + "That's why the parameter did not move from its initial position during our gradient descent. The problem comes from trying to differentiate through a histogram with hard-edged bins.\n", + "\n", + "#### Why don't hard-edged histograms work with autodiff?\n", + "\n", + "Let's think about how autodiff works with hard-edged histograms to understand what happened. The way a standard histogram calculation works is that for each bin $i$, we loop over each point $x_{\\rm j}$ in our dataset, and if the point falls within the boundaries of bin $i$, we increment our histogram by 1, otherwise we increment by 0. Now consider how the dataset changes with an _infinitesimal_ change to the position of each point, ${\\rm d}x$. If $x_{\\rm j}$ is within the bin boundaries, then the perturbed position $x_{\\rm j}+{\\rm d}x$ also within the bin boundaries, because the point and the boundary are some _finite_ distance away, and we have only perturbed our point by an infinitesimal amount. The same goes for points outside the bin. The only points with non-zero gradients will be those points that just so happen to fall exactly on the boundary of some bin, a set of measure zero. Thus it makes sense that we get zero-valued gradients for predictions made with hard-edged histograms." + ] + }, + { + "cell_type": "markdown", + "id": "6cb14d0c-3862-4175-aa06-58abad394cc9", + "metadata": {}, + "source": [ + "## Introducing soft histograms\n", + "\n", + "The solution to this problem is to use soft histograms. In standard histograms, each point in the dataset $x_{\\rm j}$ contributes either 0 or 1 to the result for each bin. In soft histograms, each point contributes a continuously-valued _weight_, $w_{\\rm j},$ to the result for each bin. Each $w_{\\rm j}$ is computed by integrating a Gaussian kernel across the edges of the bin. For histogram bins with width ${\\rm d}x,$ we typically choose a kernel width $\\sigma\\lesssim{\\rm d}x.$\n", + "\n", + "There are several soft histogram calculators in diffsky. All of the calculators are written to support N-dimensional data, and so the `soft_xhist` function in `single_gaussian.py` just provides some wrapper behavior that reshapes the data and the bins to ${\\rm (n, 1)}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6ba0bf0-cc8a-4a75-a0e3-07f4ec4e6905", + "metadata": {}, + "outputs": [], + "source": [ + "XHIST_TARGET, __ = jnp.histogram(XDATA_TARGET, bins=XBINS)\n", + "XHIST_TARGET_SOFT = sg.soft_xhist(XDATA_TARGET, XBINS)\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "__=ax.plot(XBINS[1:], XHIST_TARGET, \n", + " label='standard histogram')\n", + "__=ax.plot(XBINS[1:], XHIST_TARGET_SOFT,'--', \n", + " label='soft histogram')\n", + "leg = ax.legend()" + ] + }, + { + "cell_type": "markdown", + "id": "976682e5-b257-47df-8370-608e26ba0588", + "metadata": {}, + "source": [ + "### Running gradient descent with a soft histogram\n", + "\n", + "The next cell takes 100 steps of gradient descent, this time with a loss function based on `soft_xhist_loss_and_grad`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b79da38-ec57-4c40-aea9-2e765cc6804c", + "metadata": {}, + "outputs": [], + "source": [ + "ran_key, loss_key = jran.split(ran_key, 2)\n", + "loss_data = XHIST_TARGET, XBINS, loss_key\n", + "\n", + "learn_rate = 0.0005\n", + "\n", + "nsteps = 100\n", + "soft_loss_collector = []\n", + "p_best_soft = PARAMS_INIT._replace()\n", + "for istep in range(nsteps):\n", + " loss, grads = sg.soft_xhist_loss_and_grad(\n", + " p_best_soft, loss_data)\n", + " p_best_soft = sg.param_update(\n", + " p_best_soft, grads, learn_rate)\n", + " soft_loss_collector.append(loss)" + ] + }, + { + "cell_type": "markdown", + "id": "e8369b95-3a4b-40b9-b2d4-36c0225c09d3", + "metadata": {}, + "source": [ + "### Inspect the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20fc856d-31d8-4d73-98f9-abedd5ecbb4d", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1)\n", + "xlabel = ax.set_xlabel('step')\n", + "ylabel = ax.set_ylabel('log10 loss')\n", + "__=ax.plot(np.log10(soft_loss_collector))\n", + "\n", + "ran_key, pred_key = jran.split(ran_key, 2)\n", + "xhist_best_soft = sg.mc_predict_soft_xhist(\n", + " p_best_soft, XBINS, pred_key)\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "__=ax.plot(XBINS[1:], XHIST_TARGET, \n", + " label='target')\n", + "__=ax.plot(XBINS[1:], XHIST_INIT, '--', \n", + " label='initial guess')\n", + "__=ax.plot(XBINS[1:], xhist_best_soft, ':', \n", + " label='best fit MC method')\n", + "\n", + "leg = ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5b40c02-2305-4fc5-832f-db09ce68c51c", + "metadata": {}, + "outputs": [], + "source": [ + "gen = zip(p_best_soft._fields, p_best_soft, PARAMS_INIT, PARAMS_TARGET)\n", + "for key, val_best, val_init, val_target in gen:\n", + " print(f\"Init {key} = {val_init:.2f}\")\n", + " print(f\"Best {key} = {val_best:.2f}\")\n", + " print(f\"True {key} = {val_target:.2f}\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "4b6d1777-8fcd-49dd-945d-1cce2e93bc03", + "metadata": {}, + "source": [ + "### It worked! \n", + "\n", + "With a soft histogram, when we perturb each point by some infinitesimal amount, the weight of each point also changes infinitesimally, and so we get non-zero gradients." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 6a9168fe636d7372c1077d136ac7c53566418724 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:20:21 -0500 Subject: [PATCH 3/8] Working organization --- docs/source/double_gauss_page.rst | 4 ++-- ...le_gaussian_holder.rst => double_gaussian_code_holder.rst} | 0 ...le_gaussian_holder.rst => single_gaussian_code_holder.rst} | 0 docs/source/soft_hist_page.rst | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename docs/source/{double_gaussian_holder.rst => double_gaussian_code_holder.rst} (100%) rename docs/source/{single_gaussian_holder.rst => single_gaussian_code_holder.rst} (100%) diff --git a/docs/source/double_gauss_page.rst b/docs/source/double_gauss_page.rst index 81599dde..5c00987c 100644 --- a/docs/source/double_gauss_page.rst +++ b/docs/source/double_gauss_page.rst @@ -9,5 +9,5 @@ We recommend reading the tutorial notebook alongside the supplementary source co .. toctree:: :maxdepth: 1 - double_gaussian_holder.rst - double_gauss_demo.ipynb \ No newline at end of file + double_gauss_demo.ipynb + double_gaussian_code_holder.rst diff --git a/docs/source/double_gaussian_holder.rst b/docs/source/double_gaussian_code_holder.rst similarity index 100% rename from docs/source/double_gaussian_holder.rst rename to docs/source/double_gaussian_code_holder.rst diff --git a/docs/source/single_gaussian_holder.rst b/docs/source/single_gaussian_code_holder.rst similarity index 100% rename from docs/source/single_gaussian_holder.rst rename to docs/source/single_gaussian_code_holder.rst diff --git a/docs/source/soft_hist_page.rst b/docs/source/soft_hist_page.rst index eca8073f..56884272 100644 --- a/docs/source/soft_hist_page.rst +++ b/docs/source/soft_hist_page.rst @@ -8,5 +8,5 @@ We recommend reading the tutorial notebook alongside the supplementary source co .. toctree:: :maxdepth: 1 - single_gaussian_holder.rst softhist_demo.ipynb + single_gaussian_code_holder.rst From 465d988185c4c97cd6cef60cb2f7fd193967d5aa Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:23:23 -0500 Subject: [PATCH 4/8] working reorg --- docs/source/background.rst | 4 ++-- docs/source/{ => gaussian_examples}/double_gauss_demo.ipynb | 0 docs/source/{ => gaussian_examples}/double_gauss_page.rst | 0 docs/source/{ => gaussian_examples}/double_gaussian.py | 0 .../{ => gaussian_examples}/double_gaussian_code_holder.rst | 0 docs/source/{ => gaussian_examples}/single_gaussian.py | 0 .../{ => gaussian_examples}/single_gaussian_code_holder.rst | 0 docs/source/{ => gaussian_examples}/soft_hist_page.rst | 0 docs/source/{ => gaussian_examples}/softhist_demo.ipynb | 0 9 files changed, 2 insertions(+), 2 deletions(-) rename docs/source/{ => gaussian_examples}/double_gauss_demo.ipynb (100%) rename docs/source/{ => gaussian_examples}/double_gauss_page.rst (100%) rename docs/source/{ => gaussian_examples}/double_gaussian.py (100%) rename docs/source/{ => gaussian_examples}/double_gaussian_code_holder.rst (100%) rename docs/source/{ => gaussian_examples}/single_gaussian.py (100%) rename docs/source/{ => gaussian_examples}/single_gaussian_code_holder.rst (100%) rename docs/source/{ => gaussian_examples}/soft_hist_page.rst (100%) rename docs/source/{ => gaussian_examples}/softhist_demo.ipynb (100%) diff --git a/docs/source/background.rst b/docs/source/background.rst index 1c3e700b..2a2a5803 100644 --- a/docs/source/background.rst +++ b/docs/source/background.rst @@ -8,5 +8,5 @@ that is not necessarily part of Diffsky. .. toctree:: :maxdepth: 1 - soft_hist_page.rst - double_gauss_page.rst + gaussian_examples/soft_hist_page.rst + gaussian_examples/double_gauss_page.rst diff --git a/docs/source/double_gauss_demo.ipynb b/docs/source/gaussian_examples/double_gauss_demo.ipynb similarity index 100% rename from docs/source/double_gauss_demo.ipynb rename to docs/source/gaussian_examples/double_gauss_demo.ipynb diff --git a/docs/source/double_gauss_page.rst b/docs/source/gaussian_examples/double_gauss_page.rst similarity index 100% rename from docs/source/double_gauss_page.rst rename to docs/source/gaussian_examples/double_gauss_page.rst diff --git a/docs/source/double_gaussian.py b/docs/source/gaussian_examples/double_gaussian.py similarity index 100% rename from docs/source/double_gaussian.py rename to docs/source/gaussian_examples/double_gaussian.py diff --git a/docs/source/double_gaussian_code_holder.rst b/docs/source/gaussian_examples/double_gaussian_code_holder.rst similarity index 100% rename from docs/source/double_gaussian_code_holder.rst rename to docs/source/gaussian_examples/double_gaussian_code_holder.rst diff --git a/docs/source/single_gaussian.py b/docs/source/gaussian_examples/single_gaussian.py similarity index 100% rename from docs/source/single_gaussian.py rename to docs/source/gaussian_examples/single_gaussian.py diff --git a/docs/source/single_gaussian_code_holder.rst b/docs/source/gaussian_examples/single_gaussian_code_holder.rst similarity index 100% rename from docs/source/single_gaussian_code_holder.rst rename to docs/source/gaussian_examples/single_gaussian_code_holder.rst diff --git a/docs/source/soft_hist_page.rst b/docs/source/gaussian_examples/soft_hist_page.rst similarity index 100% rename from docs/source/soft_hist_page.rst rename to docs/source/gaussian_examples/soft_hist_page.rst diff --git a/docs/source/softhist_demo.ipynb b/docs/source/gaussian_examples/softhist_demo.ipynb similarity index 100% rename from docs/source/softhist_demo.ipynb rename to docs/source/gaussian_examples/softhist_demo.ipynb From a3305e2c19fe6fa380537e035af76e0798158182 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:26:19 -0500 Subject: [PATCH 5/8] Add docstrings --- docs/source/gaussian_examples/double_gaussian.py | 13 ++++++++++--- docs/source/gaussian_examples/single_gaussian.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/source/gaussian_examples/double_gaussian.py b/docs/source/gaussian_examples/double_gaussian.py index 2826e3db..a1f99fed 100644 --- a/docs/source/gaussian_examples/double_gaussian.py +++ b/docs/source/gaussian_examples/double_gaussian.py @@ -1,20 +1,21 @@ """Demo code to fit a 2d Gaussian model with soft histograms and jax.grad""" from jax import numpy as jnp -from collections import namedtuple from jax import random as jran from jax import jit as jjit -from diffsky.signdhist_lomem import nnsig_ndhist from jax import value_and_grad +from collections import namedtuple +from diffsky.signdhist_lomem import nnsig_ndhist DGParams = namedtuple("DGParams", ("mu0", "sig0", "mu1", "sig1", "frac0")) - DEFAULT_PARAMS = DGParams(mu0=-1.0, sig0=0.5, mu1=1.0, sig1=1.0, frac0=0.75) + NPTS = 20_000 @jjit def mc_double_gaussian(params, ran_key): + """Draw a stochastic Monte Carlo realization of a double Gaussian""" u_key, n0_key, n1_key = jran.split(ran_key, 3) uran = jran.uniform(u_key, minval=0, maxval=1, shape=(NPTS,)) n0 = jran.normal(n0_key, shape=(NPTS,)) * params.sig0 + params.mu0 @@ -26,6 +27,8 @@ def mc_double_gaussian(params, ran_key): @jjit def predict_soft_xhist_mc(params, xbins, ran_key): + """Predict histogram counts by applying soft histogram to + a stochastic Monte Carlo realization of a double Gaussian""" xdata = mc_double_gaussian(params, ran_key) xhist = soft_xhist(xdata, xbins) return xhist @@ -33,6 +36,8 @@ def predict_soft_xhist_mc(params, xbins, ran_key): @jjit def predict_soft_xhist_weighted(params, xbins, ran_key): + """Predict histogram counts by applying soft histogram to + a PDF-weighted Monte Carlo realization of a double Gaussian""" n0_key, n1_key = jran.split(ran_key, 2) n0 = jran.normal(n0_key, shape=(NPTS,)) * params.sig0 + params.mu0 n1 = jran.normal(n1_key, shape=(NPTS,)) * params.sig1 + params.mu1 @@ -65,6 +70,7 @@ def _mae_kern(x, y): @jjit def weighted_mae_loss(params, loss_data): + """Loss function based on a PDF-weighted soft histogram""" xhist_target, xbins, ran_key = loss_data xhist_pred = predict_soft_xhist_weighted(params, xbins, ran_key) loss = _mae_kern(xhist_pred, xhist_target) @@ -73,6 +79,7 @@ def weighted_mae_loss(params, loss_data): @jjit def mc_mae_loss(params, loss_data): + """Loss function based on a stochastic Monte Carlo with a soft histogram""" xhist_target, xbins, ran_key = loss_data xhist_pred = predict_soft_xhist_mc(params, xbins, ran_key) loss = _mae_kern(xhist_pred, xhist_target) diff --git a/docs/source/gaussian_examples/single_gaussian.py b/docs/source/gaussian_examples/single_gaussian.py index 0cecdeb3..1034d097 100644 --- a/docs/source/gaussian_examples/single_gaussian.py +++ b/docs/source/gaussian_examples/single_gaussian.py @@ -3,9 +3,9 @@ from jax import numpy as jnp from jax import jit as jjit from jax import random as jran +from jax import value_and_grad from collections import namedtuple from diffsky.signdhist_lomem import nnsig_ndhist -from jax import value_and_grad GParams = namedtuple("GParams", ("mu", "sig")) DEFAULT_PARAMS = GParams(mu=-1.0, sig=1.0) From 2421d31a0e1971a2f518384ddaea9ed25b7abed4 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:33:44 -0500 Subject: [PATCH 6/8] Reinstate demo_diffsky_recompute_from_mock.ipynb --- .../demo_diffsky_recompute_from_mock.ipynb | 0 docs/source/demos.rst | 5 ++++- 2 files changed, 4 insertions(+), 1 deletion(-) rename docs/{notebooks => source}/demo_diffsky_recompute_from_mock.ipynb (100%) diff --git a/docs/notebooks/demo_diffsky_recompute_from_mock.ipynb b/docs/source/demo_diffsky_recompute_from_mock.ipynb similarity index 100% rename from docs/notebooks/demo_diffsky_recompute_from_mock.ipynb rename to docs/source/demo_diffsky_recompute_from_mock.ipynb diff --git a/docs/source/demos.rst b/docs/source/demos.rst index 66062b98..176973fe 100644 --- a/docs/source/demos.rst +++ b/docs/source/demos.rst @@ -7,7 +7,10 @@ see the :doc:`background` section. Reading and analyzing Diffsky mocks ------------------------------------ -Docs coming soon! +.. toctree:: + :maxdepth: 1 + + demo_diffsky_recompute_from_mock.ipynb Generating Diffsky galaxy samples --------------------------------- From 60b989eb42c4127361a9dafbf49c49493a4a9313 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:42:29 -0500 Subject: [PATCH 7/8] Add forgotten imports to gaussian_examples in docs --- docs/source/gaussian_examples/double_gauss_demo.ipynb | 5 ++++- docs/source/gaussian_examples/softhist_demo.ipynb | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/gaussian_examples/double_gauss_demo.ipynb b/docs/source/gaussian_examples/double_gauss_demo.ipynb index e40a3dfc..9057d53b 100644 --- a/docs/source/gaussian_examples/double_gauss_demo.ipynb +++ b/docs/source/gaussian_examples/double_gauss_demo.ipynb @@ -13,11 +13,14 @@ { "cell_type": "code", "execution_count": null, - "id": "6e0dcdd2-0f05-4de0-90bf-a23d49169a4e", + "id": "981d72a6-0042-4f7a-9397-2f7e64429469", "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", + "\n", "ran_key = jran.key(0)" ] }, diff --git a/docs/source/gaussian_examples/softhist_demo.ipynb b/docs/source/gaussian_examples/softhist_demo.ipynb index de58f7ed..3e6b413a 100644 --- a/docs/source/gaussian_examples/softhist_demo.ipynb +++ b/docs/source/gaussian_examples/softhist_demo.ipynb @@ -15,11 +15,14 @@ { "cell_type": "code", "execution_count": null, - "id": "e69d23e5-4a4c-455c-9d5a-ef793b81d0c2", + "id": "db9b4c72-fd3b-4e54-bbab-8b3e7c925fac", "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", + "\n", "ran_key = jran.key(0)" ] }, From f8cc7c4d4af739337fdf51920fec9e67115224a0 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Wed, 20 May 2026 11:48:27 -0500 Subject: [PATCH 8/8] Add forgotten jnp imports to tutorials --- docs/source/gaussian_examples/double_gauss_demo.ipynb | 1 + docs/source/gaussian_examples/softhist_demo.ipynb | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/source/gaussian_examples/double_gauss_demo.ipynb b/docs/source/gaussian_examples/double_gauss_demo.ipynb index 9057d53b..b3380d16 100644 --- a/docs/source/gaussian_examples/double_gauss_demo.ipynb +++ b/docs/source/gaussian_examples/double_gauss_demo.ipynb @@ -20,6 +20,7 @@ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", + "from jax import numpy as jnp\n", "\n", "ran_key = jran.key(0)" ] diff --git a/docs/source/gaussian_examples/softhist_demo.ipynb b/docs/source/gaussian_examples/softhist_demo.ipynb index 3e6b413a..1987c720 100644 --- a/docs/source/gaussian_examples/softhist_demo.ipynb +++ b/docs/source/gaussian_examples/softhist_demo.ipynb @@ -22,6 +22,7 @@ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", + "from jax import numpy as jnp\n", "\n", "ran_key = jran.key(0)" ]