Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/background.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Background on differentiable methods
====================================

The tutorials in this section are based on toy models such as
single and double Gaussians, and some supplementary source code
that is not necessarily part of Diffsky.

.. toctree::
:maxdepth: 1

gaussian_examples/soft_hist_page.rst
gaussian_examples/double_gauss_page.rst
12 changes: 7 additions & 5 deletions docs/source/demos.rst
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
Diffsky Code Demos
==================

For a pedagogical introduction to the differentiable / probabilistic
methods used in Diffsky (soft histograms, PDF-weighted histograms, etc.),
see the :doc:`background` section.

Reading and analyzing Diffsky mocks
------------------------------------
.. toctree::
:maxdepth: 1
:caption: Notebooks:

demo_diffsky_recompute_from_mock.ipynb
:maxdepth: 1

demo_diffsky_recompute_from_mock.ipynb

Generating Diffsky galaxy samples
---------------------------------
Docs coming soon!
Docs coming soon!
331 changes: 331 additions & 0 deletions docs/source/gaussian_examples/double_gauss_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "494ee05f-bd9b-45eb-9816-1bacc370ce7c",
"metadata": {},
"source": [
"# Fitting a double Gaussian with soft histograms\n",
"\n",
"This notebook shows how to implement a double Gaussian model in JAX, and demonstrates how to optimize the parameters of the model by fitting to soft histograms with gradient descent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "981d72a6-0042-4f7a-9397-2f7e64429469",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from jax import random as jran\n",
"from jax import numpy as jnp\n",
"\n",
"ran_key = jran.key(0)"
]
},
{
"cell_type": "markdown",
"id": "bc19e244-4499-4b6c-a4a3-19723b21986c",
"metadata": {},
"source": [
"## Stochastic Monte Carlo predictions\n",
"\n",
"The `mc_double_gaussian` function generates a sample of 1d data by standard Monte Carlo methods:\n",
"1. Draw $N$ points from the first Gaussian, $\\{\\mu_0, \\sigma_0\\}$\n",
"2. Draw $N$ points from the second Gaussian, $\\{\\mu_1, \\sigma_1\\}$\n",
"3. Draw $N$ uniform random numbers, $u$\n",
"4. If $f$ is the model parameter controlling the relative height of the two Gaussians, then for points with $u<f,$ select the random draw from the first Gaussian, otherwise select from the second Gaussian.\n",
"\n",
"This demo is based on the `double_gaussian.py` module, which implements the above algorithm with the `mc_double_gaussian` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5f395ee-28c0-4c0c-9c37-3cb263553268",
"metadata": {},
"outputs": [],
"source": [
"import double_gaussian as dg\n",
"\n",
"NBINS = 50\n",
"XBOUNDS = (-5.0, 5.0)\n",
"XBINS = np.linspace(*XBOUNDS, NBINS)[:-1]\n",
"\n",
"\n",
"PARAMS_INIT = dg.DEFAULT_PARAMS._replace()\n",
"ran_key, init_key = jran.split(ran_key, 2)\n",
"XDATA_INIT = dg.mc_double_gaussian(PARAMS_INIT, init_key)\n",
"\n",
"PARAMS_TARGET = dg.DEFAULT_PARAMS._replace(\n",
" mu0=-2.0, sig0=1.0, mu1=2.0, sig1=1.0, frac0=0.25)\n",
"ran_key, target_key = jran.split(ran_key, 2)\n",
"XDATA_TARGET = dg.mc_double_gaussian(PARAMS_TARGET, target_key)\n",
"\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.hist(XDATA_INIT, bins=XBINS, alpha=0.7, label=r'initial population')\n",
"__=ax.hist(XDATA_TARGET, bins=XBINS, alpha=0.7, label=r'target population')\n",
"leg = ax.legend()"
]
},
{
"cell_type": "markdown",
"id": "e7a1840a-046b-4806-ab55-541aa47ca860",
"metadata": {},
"source": [
"### Predicting a histogram from a population\n",
"\n",
"The `predict_soft_xhist_mc` function is just a wrapper around `mc_double_gaussian` that first generates sample data, and then computes a soft histogram of the sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1339958-9d9f-40d0-943c-af59a60bd3c1",
"metadata": {},
"outputs": [],
"source": [
"ran_key, init_key = jran.split(ran_key, 2)\n",
"XHIST_INIT = dg.predict_soft_xhist_mc(PARAMS_INIT, XBINS, init_key)\n",
"\n",
"ran_key, target_key = jran.split(ran_key, 2)\n",
"XHIST_TARGET = dg.predict_soft_xhist_mc(PARAMS_TARGET, XBINS, target_key)"
]
},
{
"cell_type": "markdown",
"id": "9ee6974d-eecc-4c6a-9aa4-5d2692324926",
"metadata": {},
"source": [
"### Running gradient descent\n",
"\n",
"The `mc_mae_loss_and_grad` function is the loss function we will try to minimize with gradient descent. This loss function uses `predict_soft_xhist_mc` to predict a histogram, and then computes the mean absolute error between the predicted and target histogram. We use `jax.value_and_grad` to additionally return the gradients of the loss in addition to the value.\n",
"\n",
"The next cell takes 100 steps of gradient descent, where at each step, we take a tiny step down the gradient to update the parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e14afe99-649f-4416-b891-4ffb9b669c46",
"metadata": {},
"outputs": [],
"source": [
"ran_key, loss_key = jran.split(ran_key, 2)\n",
"loss_data = XHIST_TARGET, XBINS, loss_key\n",
"\n",
"learn_rate = 0.001\n",
"\n",
"nsteps = 100\n",
"\n",
"p_best_mc = dg.DEFAULT_PARAMS._replace()\n",
"collector_mc = []\n",
"for istep in range(nsteps):\n",
" loss, grads = dg.mc_mae_loss_and_grad(p_best_mc, loss_data)\n",
" p_best_mc = dg.param_update(p_best_mc, grads, learn_rate)\n",
" collector_mc.append(loss)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"xlabel = ax.set_xlabel('step')\n",
"ylabel = ax.set_ylabel('log10 loss')\n",
"__=ax.plot(np.log10(collector_mc))\n",
"\n",
"xhist_best_mc = dg.predict_soft_xhist_mc(p_best_mc, XBINS, target_key)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.plot(XBINS[1:], XHIST_TARGET, label='target')\n",
"__=ax.plot(XBINS[1:], XHIST_INIT, '--', label='initial guess')\n",
"__=ax.plot(XBINS[1:], xhist_best_mc, '--', label='best fit MC method')\n",
"\n",
"leg = ax.legend()"
]
},
{
"cell_type": "markdown",
"id": "03f6d81e-1ac0-44d6-b9e3-3478498c00e6",
"metadata": {},
"source": [
"## That fit is not so great - what happened?\n",
"\n",
"Let's see how the best-fit points compare to the target and initial points"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00aac980-79fc-4121-926f-ec121a860478",
"metadata": {},
"outputs": [],
"source": [
"for key, val_best, val_init, val_target in zip(p_best_mc._fields, p_best_mc, dg.DEFAULT_PARAMS, PARAMS_TARGET):\n",
" print(f\"Init {key} = {val_init:.2f}\")\n",
" print(f\"Best {key} = {val_best:.2f}\")\n",
" print(f\"True {key} = {val_target:.2f}\\n\")"
]
},
{
"cell_type": "markdown",
"id": "a2c7b466-bca1-47f2-9b0b-75db420b8d47",
"metadata": {},
"source": [
"### Hmmm, the `frac0` parameter didn't move\n",
"\n",
"Let's inspect the gradient"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "225d5dc0-a13c-477c-a8ed-2f03c4ff9b75",
"metadata": {},
"outputs": [],
"source": [
"loss_best_mc, grads = dg.mc_mae_loss_and_grad(p_best_mc, loss_data)\n",
"print(grads)"
]
},
{
"cell_type": "markdown",
"id": "191b59a8-e4d1-4132-b6c5-ba6c4b0eee02",
"metadata": {},
"source": [
"### The `frac0` parameter has zero gradient!\n",
"\n",
"That's why the parameter did not move from its initial position during our gradient descent. Since the `frac0` parameter has zero gradient, during our gradient descent, the other parameters adjust as best as they can to improve the agreement with the target histogram, and so the loss improves, but the fit still converges to the wrong model since the value of `frac0` never departs from its initial value.\n",
"\n",
"\n",
"#### Why doesn't noisy Monte Carlo with autodiff?\n",
"\n",
"For the case of a unimodal Gaussian, it's no problem at all to fit the model perfectly with predictions based on a noisy MC realization, so what gives?\n",
"\n",
"The `frac0` parameter in this model is different from the others: it controls the relative abundance of the two Gaussians, and we use `frac0` together with additional draws from a uniform random, so those are some clues why this parameter requires different treatment from the rest.\n",
"\n",
"Let's first consider why we get non-zero gradients for the other four parameters, $\\{\\mu_0, \\sigma_0, \\mu_1, \\sigma_1\\}$. Imagine how an infinitesimal change to $\\mu_0$ induces a change to the result of the histogram counts in bin $i$. As $\\mu_0$ changes, the positions of points $x_{\\rm j}$ drawn from the first Gaussian move, thereby smoothly changing the weights $w_{\\rm j}$ of those points. And so we get non-zero gradients of $\\mu_0$ for our loss function, and similarly for the other three parameters.\n",
"\n",
"Now consider what happens when we perturb the `frac0` parameter by some infinitesimal amount, $\\delta f_0$. If our random draw $u_{\\rm j}-f_0\\equiv\\Delta_{\\rm j},$ then since $\\delta f_0<\\Delta_{\\rm j}$ for an finite difference $\\Delta_{\\rm j}$, points to the left and right of $f_0$ remain to the left and right after the perturbation, and so gradients with respect to $f_0$ are zero.\n",
"\n",
"#### What is the solution?\n",
"The root problem comes from the stochastic Monte Carlo method we used to choose a particular Gaussian for each point. Instead, we need to compute a $f_0$-weighted sum of the soft histogram result of each Gaussian. The `predict_soft_xhist_weighted` function in `double_gaussian.py` implements this calculation.\n",
"\n",
"Let's first observe that the two methods of computing a histogram are equivalent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1794f3b-dff3-4fba-aa54-e7286853a864",
"metadata": {},
"outputs": [],
"source": [
"XHIST_TARGET = dg.predict_soft_xhist_mc(PARAMS_TARGET, XBINS, target_key)\n",
"XHIST_TARGET_WEIGHTED = dg.predict_soft_xhist_weighted(PARAMS_TARGET, XBINS, target_key)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.plot(XBINS[1:], XHIST_TARGET, label='noisy MC')\n",
"__=ax.plot(XBINS[1:], XHIST_TARGET_WEIGHTED, '--', label='weighted MC')\n",
"\n",
"leg = ax.legend()"
]
},
{
"cell_type": "markdown",
"id": "7815c117-382e-47f0-bff0-3ab1de9f7201",
"metadata": {},
"source": [
"### Run gradient descent with weighted soft histograms"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50544508-8638-4c9c-8ec8-9a2892e4b9a1",
"metadata": {},
"outputs": [],
"source": [
"params_init = dg.DEFAULT_PARAMS._replace()\n",
"loss_init, grads_init = dg.weighted_mae_loss_and_grad(params_init, loss_data)\n",
"grads_init"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6d7932e-538e-4ed4-910a-c3d3c5d2300e",
"metadata": {},
"outputs": [],
"source": [
"learn_rate = 0.00005\n",
"\n",
"nsteps = 300\n",
"collector = []\n",
"p_best = dg.DEFAULT_PARAMS._replace()\n",
"for istep in range(nsteps):\n",
" loss, grads = dg.weighted_mae_loss_and_grad(p_best, loss_data)\n",
" p_best = dg.param_update(p_best, grads, learn_rate)\n",
" collector.append(loss)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"xlabel = ax.set_xlabel('step')\n",
"ylabel = ax.set_ylabel('log10 loss')\n",
"__=ax.plot(np.log10(collector))\n",
"\n",
"xhist_best = dg.predict_soft_xhist_weighted(p_best, XBINS, target_key)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"__=ax.plot(XBINS[1:], XHIST_TARGET, label='target')\n",
"__=ax.plot(XBINS[1:], XHIST_INIT, '--', label='initial guess')\n",
"__=ax.plot(XBINS[1:], xhist_best, '--', label='best fit')\n",
"\n",
"leg = ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b471e66-b039-4706-81ae-82bf476ff74e",
"metadata": {},
"outputs": [],
"source": [
"for key, val_best, val_init, val_target in zip(p_best._fields, p_best, dg.DEFAULT_PARAMS, PARAMS_TARGET):\n",
" print(f\"Init {key} = {val_init:.2f}\")\n",
" print(f\"Best {key} = {val_best:.2f}\")\n",
" print(f\"True {key} = {val_target:.2f}\\n\")"
]
},
{
"cell_type": "markdown",
"id": "55dcc1d5-7e86-439d-8206-0f989bf20c5d",
"metadata": {},
"source": [
"## It worked!\n",
"\n",
"Upshot: whenever using autodiff to fit models of a multi-modal PDF, we need to compute our soft histograms separately for each mode, and then calculate a probability-weighted sum of the results. Otherwise we get zero-valued gradients for parameters that control the relative abundance of the different modes of the PDF."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
13 changes: 13 additions & 0 deletions docs/source/gaussian_examples/double_gauss_page.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Fitting multi-modal probability distributions
======================================================

This section demonstrates differentiable methods for fitting a
multi-modal probability distribution,
using a double Gaussian to demonstrate the PDF-weighting technique.
We recommend reading the tutorial notebook alongside the supplementary source code linked below.

.. toctree::
:maxdepth: 1

double_gauss_demo.ipynb
double_gaussian_code_holder.rst
Loading
Loading