diff --git a/orbitize/results.py b/orbitize/results.py index cecacaff..731df8ef 100644 --- a/orbitize/results.py +++ b/orbitize/results.py @@ -115,7 +115,8 @@ def save_results(self, filename): hf.attrs['version_number'] = self.version_number # Now add post and lnlike from the results object as datasets - hf.create_dataset('post', data=self.post) + if self.post is not None: + hf.create_dataset('post', data=self.post) # hf.create_dataset('data', data=self.data) if self.lnlike is not None: hf.create_dataset('lnlike', data=self.lnlike) @@ -151,8 +152,12 @@ def load_results(self, filename, append=False): version_number = str(hf.attrs['version_number']) except KeyError: version_number = "<= 1.13" - post = np.array(hf.get('post')) - lnlike = np.array(hf.get('lnlike')) + post = hf.get('post') + if post is not None: + post = np.array(post) + lnlike = hf.get('lnlike') + if lnlike is not None: + lnlike = np.array(lnlike) try: diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index cea41417..777c80ec 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -44,21 +44,23 @@ def test_mcmc_runs(num_temps=0, num_threads=1): } ) - # run it a little (tests 0 burn-in steps) - myDriver.sampler.run_sampler(100) - assert myDriver.sampler.results.post.shape[0] == 100 - - # run it a little more - myDriver.sampler.run_sampler(1000, burn_steps=1) - assert myDriver.sampler.results.post.shape[0] == 1100 - - # run it a little more (tests adding to results object, and periodic saving) + # run it some (tests adding to results object, and periodic saving) output_filename = os.path.join(orbitize.DATADIR, 'test_mcmc.hdf5') myDriver.sampler.run_sampler( - 400, burn_steps=1, output_filename=output_filename, periodic_save_freq=2 + 400, burn_steps=10, output_filename=output_filename, periodic_save_freq=2 ) - # test results object exists and has 2100*100 steps + # TODO: Add test for restarting from saved results when burn-in is interrupted (i.e. no regular steps have been done) + + # run it a little more (tests 0 burn-in steps) + myDriver.sampler.run_sampler(100) + assert myDriver.sampler.results.post.shape[0] == 500 + + # run it a little more and save + myDriver.sampler.run_sampler(1000, burn_steps=1, output_filename=output_filename, periodic_save_freq=2) + assert myDriver.sampler.results.post.shape[0] == 1500 + + # test results object exists and has 1500*100 steps assert os.path.exists(output_filename) saved_results = results.Results() saved_results.load_results(output_filename)