diff --git a/examples/experiment.py b/examples/experiment.py index b0bfc14..dd526e0 100644 --- a/examples/experiment.py +++ b/examples/experiment.py @@ -28,19 +28,21 @@ def run_loop( params = agent.initial_params(next(rng)) learner_state = agent.initial_learner_state(params) + train_actor_state = agent.initial_actor_state() + print(f"Training agent for {train_episodes} episodes") for episode in range(train_episodes): # Prepare agent, environment and accumulator for a new episode. timestep = environment.reset() accumulator.push(timestep, None) - actor_state = agent.initial_actor_state() + while not timestep.last(): # Acting. - actor_output, actor_state = agent.actor_step( - params, timestep, actor_state, next(rng), evaluation=False) + actor_output, train_actor_state = agent.actor_step( + params, timestep, train_actor_state, next(rng), evaluation=False) # Agent-environment interaction. action = int(actor_output.actions) @@ -57,13 +59,13 @@ def run_loop( # Evaluation. if not episode % evaluate_every: returns = 0. + eval_actor_state = agent.initial_actor_state() for _ in range(eval_episodes): timestep = environment.reset() - actor_state = agent.initial_actor_state() while not timestep.last(): - actor_output, actor_state = agent.actor_step( - params, timestep, actor_state, next(rng), evaluation=True) + actor_output, eval_actor_state = agent.actor_step( + params, timestep, eval_actor_state, next(rng), evaluation=True) timestep = environment.step(int(actor_output.actions)) returns += timestep.reward diff --git a/examples/simple_dqn.py b/examples/simple_dqn.py index 66478f5..81a034d 100644 --- a/examples/simple_dqn.py +++ b/examples/simple_dqn.py @@ -19,6 +19,7 @@ from absl import app from absl import flags from bsuite.environments import catch +from examples import experiment import haiku as hk from haiku import nets import jax @@ -26,7 +27,6 @@ import numpy as np import optax import rlax -from rlax.examples import experiment Params = collections.namedtuple("Params", "online target") ActorState = collections.namedtuple("ActorState", "count")