Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion prospect/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""


import time
import time, sys
from functools import partial as argfix

import numpy as np
Expand Down Expand Up @@ -431,6 +431,19 @@ def run_nested(observations, model, sps,
nested_sampler="dynesty",
nested_nlive=1000,
nested_target_n_effective=1000,
nested_n_like_max=sys.maxsize,
nested_dlogz=0.01,
nested_nlive_batch=200,
nested_maxiter=sys.maxsize,
nested_maxcall=sys.maxsize,
nested_maxiter_init=sys.maxsize,
nested_maxcall_init=sys.maxsize,
nested_maxiter_batch=sys.maxsize,
nested_maxcall_batch=sys.maxsize,
nested_maxbatch=sys.maxsize,
nested_wt_kwargs={'pfrac': 1.0},
nested_filepath=None,
nested_resume=True,
verbose=False,
**kwargs):
"""Thin wrapper on :py:class:`prospect.fitting.nested.run_nested_sampler`
Expand Down Expand Up @@ -461,6 +474,24 @@ def run_nested(observations, model, sps,
Number of live points for the nested sampler. Meaning somewhat
dependent on the chosen sampler

nautilus-specific parameters
----------------------------
nested_n_like_max : int
Maximum number of likelihood evaluations for nautilus.
nested_filepath : str or None
Path to the checkpointing file. Must have a `.h5` or `.hdf5` extension. If None, no checkpointing are performed. Default is None.
nested_resume : bool
If True, resume from previous run if filepath exists. If False, start from scratch and overwrite any previous file. Default is True.

dynesty-specific parameters
----------------------------
nested_maxcall : int
Maximum number of likelihood evaluations for dynesty.
nested_maxiter : int
Maximum number of iterations for dynesty.

----------------------------

Returns
--------
result: Dictionary
Expand All @@ -481,6 +512,16 @@ def run_nested(observations, model, sps,
verbose=verbose,
nested_nlive=nested_nlive,
nested_neff=nested_target_n_effective,
nested_n_like_max=nested_n_like_max,
nested_wt_kwargs=nested_wt_kwargs,
nested_dlogz=nested_dlogz,
nested_nlive_batch=nested_nlive_batch,
nested_maxiter=nested_maxiter, nested_maxcall=nested_maxcall,
nested_maxiter_init=nested_maxiter_init, nested_maxcall_init=nested_maxcall_init,
nested_maxiter_batch=nested_maxiter_batch, nested_maxcall_batch=nested_maxcall_batch,
nested_maxbatch=nested_maxbatch,
nested_filepath=nested_filepath,
nested_resume=nested_resume,
**kwargs)
info, result_obj = output

Expand Down
50 changes: 45 additions & 5 deletions prospect/fitting/nested.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import numpy as np
import time
import time, sys
import warnings

__all__ = ["run_nested_sampler"]
Expand All @@ -11,6 +11,19 @@ def run_nested_sampler(model,
nested_sampler="dynesty",
nested_nlive=1000,
nested_neff=1000,
nested_n_like_max=sys.maxsize,
nested_dlogz=0.01,
nested_nlive_batch=200,
nested_maxiter=sys.maxsize,
nested_maxcall=sys.maxsize,
nested_maxiter_init=sys.maxsize,
nested_maxcall_init=sys.maxsize,
nested_maxiter_batch=sys.maxsize,
nested_maxcall_batch=sys.maxsize,
nested_maxbatch=sys.maxsize,
nested_wt_kwargs={'pfrac': 1.0},
nested_filepath=None,
nested_resume=True,
verbose=False,
**kwargs):
"""We give a model -- parameter discription and prior transform -- and a
Expand All @@ -26,6 +39,24 @@ def run_nested_sampler(model,
Number of live points.
nested_neff : float
Minimum effective sample size.

nautilus-specific parameters
----------------------------
nested_n_like_max : int
Maximum number of likelihood evaluations for nautilus.
nested_filepath : str or None
Path to the file where results are saved. Must have a `.h5` or `.hdf5` extension. If None, no results are written. Default is None.
nested_resume : bool
If True, resume from previous run if filepath exists. If False, start from scratch and overwrite any previous file. Default is True.

dynesty-specific parameters
----------------------------
nested_maxcall : int
Maximum number of likelihood evaluations for dynesty.
nested_maxiter : int
Maximum number of iterations for dynesty.

----------------------------
verbose : bool
Whether to output sampler progress.

Expand All @@ -48,7 +79,7 @@ def run_nested_sampler(model,
sampler_init = Sampler
init_args = (model.prior_transform, likelihood_function)
init_kwargs = dict(pass_dict=False, n_live=nested_nlive,
n_dim=model.ndim)
n_dim=model.ndim, filepath=nested_filepath, resume=nested_resume)
elif nested_sampler == 'ultranest':
from ultranest import ReactiveNestedSampler
sampler_init = ReactiveNestedSampler
Expand All @@ -59,7 +90,10 @@ def run_nested_sampler(model,
from dynesty import DynamicNestedSampler
sampler_init = DynamicNestedSampler
init_args = (likelihood_function, model.prior_transform, model.ndim)
init_kwargs = dict(nlive=nested_nlive)
init_kwargs = dict(nlive=nested_nlive,
dlogz=nested_dlogz,
maxcall=nested_maxcall_init,
maxiter=nested_maxiter_init, )
elif nested_sampler == 'nestle':
import nestle
init_kwargs = dict()
Expand All @@ -80,7 +114,7 @@ def run_nested_sampler(model,
if nested_sampler == 'nautilus':
sampler_run = sampler.run
run_args = ()
run_kwargs = dict(n_eff=nested_neff, verbose=verbose)
run_kwargs = dict(n_eff=nested_neff, n_like_max=nested_n_like_max, verbose=verbose)
elif nested_sampler == 'ultranest':
sampler_run = sampler.run
run_args = ()
Expand All @@ -90,7 +124,13 @@ def run_nested_sampler(model,
elif nested_sampler == 'dynesty':
sampler_run = sampler.run_nested
run_args = ()
run_kwargs = dict(n_effective=nested_neff, print_progress=verbose)
run_kwargs = dict(n_effective=nested_neff,
wt_kwargs=nested_wt_kwargs,
nlive_batch=nested_nlive_batch,
maxiter=nested_maxiter, maxcall=nested_maxcall,
maxiter_batch=nested_maxiter_batch,
maxcall_batch=nested_maxcall_batch, maxbatch=nested_maxbatch,
print_progress=verbose)
elif nested_sampler == 'nestle':
sampler_run = nestle.sample
run_args = (likelihood_function, model.prior_transform, model.ndim)
Expand Down