-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_training.py
More file actions
80 lines (64 loc) · 2.59 KB
/
model_training.py
File metadata and controls
80 lines (64 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import pickle
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
def pad_sequences_custom(data, maxlen, pad_value=0):
    """Pad or truncate every sequence in *data* to exactly *maxlen* items.

    Sequences shorter than *maxlen* are right-padded with *pad_value*;
    longer ones are truncated to their first *maxlen* items.

    Args:
        data: Iterable of sequences (lists, tuples, or 1-D NumPy arrays).
        maxlen: Target length for every sequence; must be non-negative.
        pad_value: Value appended to short sequences (default 0).

    Returns:
        A NumPy array with one row per input sequence, each row holding
        ``maxlen`` items. An empty *data* yields an array of shape
        ``(0, maxlen)``.

    Raises:
        ValueError: If *maxlen* is negative.
    """
    if maxlen < 0:
        raise ValueError(f"maxlen must be non-negative, got {maxlen}")
    padded_data = []
    for seq in data:
        # Convert to a list first: `tuple + list` raises TypeError and
        # `ndarray + list` performs element-wise addition rather than
        # concatenation, so the original `seq + [0] * k` only worked for lists.
        items = list(seq)[:maxlen]
        items.extend([pad_value] * (maxlen - len(items)))
        padded_data.append(items)
    if not padded_data:
        # np.array([]) would have shape (0,); keep the 2-D contract instead.
        return np.zeros((0, maxlen))
    return np.array(padded_data)
# Load processed data.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only load a data.pickle file that you produced yourself or otherwise trust.
try:
    # `with` guarantees the file handle is closed even if unpickling fails.
    with open('data.pickle', 'rb') as f:
        data_dict = pickle.load(f)
    data = data_dict['data']
    labels = data_dict['labels']
    # Inspect the first few data points to see the structure
    print("First few data points:")
    for i, point in enumerate(data[:5]):  # Inspecting up to 5 data points
        print(f"Data point {i}: {len(point)} values")
except Exception as e:
    # Script-level boundary: report the problem, then re-raise so the run stops.
    print(f"Error loading data: {e}")
    raise
# Check whether all sequences in data have the same length.
data_lengths = [len(d) for d in data]
print(f"Lengths of data points: {data_lengths[:10]}")  # Display the first 10 lengths
if not data_lengths:
    # max() on an empty list would raise an opaque "empty sequence" error.
    raise ValueError("No data points found in data.pickle!")
# Find the length of the longest sequence
max_len = max(data_lengths)
print(f"Maximum sequence length: {max_len}")
if len(set(data_lengths)) > 1:
    print(f"Sequences vary in length (min {min(data_lengths)}, max {max_len}); "
          f"padding all to {max_len}.")
# Pad the sequences to ensure all sequences have the same length
padded_data = pad_sequences_custom(data, max_len)
# Convert data and labels to numpy arrays
data = np.asarray(padded_data)
labels = np.asarray(labels)
# Check the shapes of data and labels
print(f"Data shape: {data.shape}")
print(f"Labels shape: {labels.shape}")
# Ensure labels are correctly formatted (one label per sample)
if len(data) != len(labels):
    print("Mismatch between number of samples and number of labels!")
    raise ValueError("Data and labels do not align!")
# Seed for the train/test split and the forest. None keeps the original
# non-deterministic behavior; set an int (e.g. 42) for reproducible runs.
RANDOM_STATE = None

# Split data into training and testing sets. Stratifying keeps the class
# proportions equal in both splits (requires >= 2 samples per class).
x_train, x_test, y_train, y_test = train_test_split(
    data, labels,
    test_size=0.2,
    stratify=labels,
    shuffle=True,
    random_state=RANDOM_STATE,
)
print(f"Training samples: {len(x_train)}, Test samples: {len(x_test)}")

# Train Random Forest model
model = RandomForestClassifier(random_state=RANDOM_STATE)
model.fit(x_train, y_train)

# Evaluate model accuracy. accuracy_score takes (y_true, y_pred); the ground
# truth must come first.
y_predict = model.predict(x_test)
accuracy = accuracy_score(y_test, y_predict)
print(f'{accuracy * 100:.2f}% of samples were classified correctly!')
# Persist the trained classifier to model.p, wrapped in a dict under the
# 'model' key; any failure is reported and then re-raised.
try:
    payload = {'model': model}
    with open('model.p', 'wb') as model_file:
        pickle.dump(payload, model_file)
    print("Model saved successfully.")
except Exception as e:
    print(f"Error saving model: {e}")
    raise