Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,21 @@ def run_loop(
params = agent.initial_params(next(rng))
learner_state = agent.initial_learner_state(params)

train_actor_state = agent.initial_actor_state()

print(f"Training agent for {train_episodes} episodes")
for episode in range(train_episodes):

# Prepare agent, environment and accumulator for a new episode.
timestep = environment.reset()
accumulator.push(timestep, None)
actor_state = agent.initial_actor_state()


while not timestep.last():

# Acting.
actor_output, actor_state = agent.actor_step(
params, timestep, actor_state, next(rng), evaluation=False)
actor_output, train_actor_state = agent.actor_step(
params, timestep, train_actor_state, next(rng), evaluation=False)

# Agent-environment interaction.
action = int(actor_output.actions)
Expand All @@ -57,13 +59,13 @@ def run_loop(
# Evaluation.
if not episode % evaluate_every:
returns = 0.
eval_actor_state = agent.initial_actor_state()
for _ in range(eval_episodes):
timestep = environment.reset()
actor_state = agent.initial_actor_state()

while not timestep.last():
actor_output, actor_state = agent.actor_step(
params, timestep, actor_state, next(rng), evaluation=True)
actor_output, eval_actor_state = agent.actor_step(
params, timestep, eval_actor_state, next(rng), evaluation=True)
timestep = environment.step(int(actor_output.actions))
returns += timestep.reward

Expand Down
2 changes: 1 addition & 1 deletion examples/simple_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from absl import app
from absl import flags
from bsuite.environments import catch
from examples import experiment
import haiku as hk
from haiku import nets
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
from rlax.examples import experiment

Params = collections.namedtuple("Params", "online target")
ActorState = collections.namedtuple("ActorState", "count")
Expand Down