Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"max_env_steps": 200, "nn_type": "conv", "alpha": 0.001, "gamma": 0.98, "train_epsilon": 0.9, "test_epsilon": 0.05, "epsilon_decay": 0.999, "replay_buffer_size": 200, "batch_size": 32, "target_update_interval": 50, "numTraining": 2000, "verbose": false, "device": "cuda:4", "state_size": [10, 10], "action_size": 4}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"max_env_steps": 200, "nn_type": "conv", "alpha": 0.001, "gamma": 0.98, "train_epsilon": 0.9, "test_epsilon": 0.05, "epsilon_decay": 0.999, "replay_buffer_size": 200, "batch_size": 32, "target_update_interval": 50, "numTraining": 2000, "verbose": false, "device": "cuda:4", "state_size": [10, 10], "action_size": 4}
19 changes: 19 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def main(cfg):
environment_args = {
'environment_name': cfg.environment.type,
'grid_size': cfg.grid_size,
'max_steps': cfg.max_steps
}
drawer_args = {
'grid_size': cfg.grid_size,
Expand Down Expand Up @@ -42,6 +43,24 @@ def main(cfg):
'max_iterations': cfg.controller.max_iterations,
'model_path': cfg.controller.model_path,
})
elif controller == 'dqn':
controller_args.update({
'max_env_steps': cfg.controller.max_env_steps,
'model_path': cfg.controller.model_path,
'nn_type': cfg.controller.nn_type,
'alpha': cfg.controller.alpha,
'gamma': cfg.controller.gamma,
'train_epsilon': cfg.controller.train_epsilon,
'test_epsilon': cfg.controller.test_epsilon,
'epsilon_decay': cfg.controller.epsilon_decay,
'replay_buffer_size': cfg.controller.replay_buffer_size,
'batch_size': cfg.controller.batch_size,
'target_update_interval': cfg.controller.target_update_interval,
'numTraining': cfg.controller.numTraining,
'verbose': cfg.controller.verbose,
'max_env_steps': cfg.controller.max_env_steps,
'device': cfg.controller.device
})
else:
raise ValueError(f"Unknown controller: {controller}")

Expand Down
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ cell_size: 40
framerate: 2
eval: true
num_episodes: 1000
max_steps: 200
21 changes: 21 additions & 0 deletions config/controller/dqn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
type: dqn

# common training params
alpha: 0.001
batch_size: 32
numTraining: 2000
nn_type: conv # network architecture for the DQN agent; supported values are defined by the dqn controller

# RL training params
target_update_interval: 50
replay_buffer_size: 200
epsilon_decay: 0.999
gamma: 0.98
train_epsilon: 0.9
test_epsilon: 0.05

# other params
verbose: False
max_env_steps: 200
model_path: null
device: cuda:4
21 changes: 21 additions & 0 deletions config/controller/dqn_ghosts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
type: dqn

# common training params
alpha: 0.001
batch_size: 32
numTraining: 4000
nn_type: conv # network architecture for the DQN agent; supported values are defined by the dqn controller

# RL training params
target_update_interval: 50
replay_buffer_size: 200
epsilon_decay: 0.9995
gamma: 0.98
train_epsilon: 0.5
test_epsilon: 0.05

# other params
verbose: False
max_env_steps: 200
model_path: null
device: cuda:4
9 changes: 9 additions & 0 deletions config/controller/qlearn_ghosts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
type: qlearn
alpha: 0.25
train_epsilon: 0.9
test_epsilon: 0.05
gamma: 0.98
gamma_eps: 0.99995
numTraining: 100000
verbose: False
model_path: null
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ numpy = "^2.2.2"
pygame = "^2.6.1"
click = "^8.1.7"
hydra-core = "^1.3.2"
tqdm = "^4.67.1"
torch = "2.6.0"

[build-system]
requires = ["poetry-core"]
Expand Down
3 changes: 2 additions & 1 deletion src/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .controller import Controller
from .basic import BasicController
from .qlearn import QTable, QLearnAgent
from .value_iteration import ValueIterationAgent
from .value_iteration import ValueIterationAgent
from .dqn import DQNAgent
Loading