-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRecurrentNetwork.cpp
More file actions
77 lines (63 loc) · 1.87 KB
/
RecurrentNetwork.cpp
File metadata and controls
77 lines (63 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/*
* RecurrentNetwork.cpp
* ABQ
*
* Created by Elliot Meyerson on 3/24/15.
* Copyright 2015 Elliot Meyerson. All rights reserved.
*
*/
#include "RecurrentNetwork.h"
#include <iostream>
// Builds a fixed-topology network with 24 inputs, 2 hidden units and
// 3 output units, and sizes the activation and weight storage for it.
//
// Node layout inside activation_:
//   [0, input_end_)              slot 0 followed by the 24 input slots
//   [input_end_, output_start_)  hidden units
//   [output_start_, num_nodes_)  output units
//
// weights_ gets one row per node and one column per hidden/output unit.
//
// NOTE(review): slot 0 looks like a bias node, but nothing visible in this
// file ever writes to it — confirm whether it is meant to hold 1.0.
RecurrentNetwork::RecurrentNetwork() {
    const int input_count = 24;
    const int hidden_count = 2;
    const int output_count = 3;

    // Index boundaries; each one builds on the previous.
    input_end_ = 1 + input_count;
    output_start_ = hidden_count + input_end_;
    num_units_ = output_count + hidden_count;
    num_nodes_ = output_count + output_start_;

    weights_.resize(num_nodes_);
    activation_.resize(num_nodes_);

    int row = 0;
    while (row < num_nodes_) {
        weights_[row].resize(num_units_);
        ++row;
    }
}
void RecurrentNetwork::SetWeights(const vector<double> &genome) {
for (int i = 0; i < genome.size(); i++) {
//std::cout << genome[i];
//std::cout << "\n";
weights_[i/num_units_][i%num_units_] = genome[i];
}
}
void RecurrentNetwork::SetInput (const vector<double> &input_values) {
for (int i = 0; i < input_values.size(); i++) {
activation_[i+1] = input_values[i];
}
}
int RecurrentNetwork::GetAction () {
int action = 0;
double max_activation = -10.0;
for (int i = output_start_; i < num_nodes_; i++) {
if (activation_[i] > max_activation) {
max_activation = activation_[i];
action = i;
}
}
return action - output_start_;
}
void RecurrentNetwork::Step () {
vector<double> updated_activation(num_units_, 0.0);
for (int i = 0; i < num_nodes_; i++) {
for (int j = 0; j < num_units_; j++) {
updated_activation[j] += activation_[i] * weights_[i][j];
}
}
for (int i = 0; i < num_units_; i++) {
activation_[i+input_end_] = tanh(updated_activation[i]);
}
}
// Resets the recurrent state: every hidden and output activation (indices
// input_end_ up to num_nodes_) is set to zero. Slot 0 and the input slots
// are left untouched.
void RecurrentNetwork::Flush () {
    int node = input_end_;
    while (node < num_nodes_) {
        activation_[node] = 0.0;
        ++node;
    }
}