test_dqn.cpp
45 lines (35 loc) · 1.27 KB
#include <iostream>
#include "DQN.h"
#include "SimpleGridEnvironment.h"
// Smoke-tests the DQN agent on a small grid-world environment.
//
// Runs `episodes` training episodes on a grid of the given size and
// prints the total reward collected in each episode. The defaults
// (1000 episodes, grid size 10) preserve the original behavior, so
// existing callers of TestDQN() are unaffected.
void TestDQN(int episodes = 1000, int gridSize = 10) {
// Initialize environment
SimpleGridEnvironment env(gridSize);
// Initialize DQN: state size of 1, action size of 2, two hidden layers of 2 units each
std::vector<int> hiddenLayers = {2, 2};
DQN dqn(1, 2, hiddenLayers);
// Training loop
for (int episode = 0; episode < episodes; ++episode) {
auto state = env.Reset();
bool done = false;
double totalReward = 0.0;
while (!done) {
// Epsilon-greedy action selection driven by the agent's current epsilon.
int action = dqn.SelectAction(state, dqn.GetEpsilon());
// C++17 structured bindings replace the manual std::get<> unpacking.
auto [nextState, reward, stepDone] = env.Step(action);
done = stepDone;
totalReward += reward;
dqn.Train(state, action, reward, nextState, dqn.GetGamma(), dqn.GetEpsilonDecay());
state = nextState;
}
// Sync the target network with the online network after each episode.
dqn.UpdateTargetNetwork();
// '\n' instead of std::endl: identical output without a flush per episode.
std::cout << "Episode " << episode + 1 << ": Total Reward = " << totalReward << '\n';
}
}
// Entry point: run the DQN smoke test.
int main() {
TestDQN();
// Implicit return 0 — main has a well-defined default return value.
}