diff --git a/ptan/common/utils.py b/ptan/common/utils.py
index 4235518..3a065a5 100644
--- a/ptan/common/utils.py
+++ b/ptan/common/utils.py
@@ -349,9 +349,9 @@ def __enter__(self):
     def __exit__(self, *args):
         self.writer.close()
 
-    def reward(self, reward, frame, epsilon=None):
+    def reward(self, reward, frame, epsilon=None, avg_size=100):
         self.total_rewards.append(reward)
-        mean_reward = np.mean(self.total_rewards[-100:])
+        mean_reward = np.mean(self.total_rewards[-avg_size:])
         ts_diff = time.time() - self.ts
         if ts_diff > self.min_ts_diff:
             speed = (frame - self.ts_frame) / ts_diff