-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
149 lines (120 loc) · 4.8 KB
/
predict.py
File metadata and controls
149 lines (120 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
from network import NeuralNetwork
from PIL import Image, ImageOps
import os
def save_weights(nn, filepath="weights.npz"):
    """Persist every layer's weight matrix and bias vector to an .npz archive.

    Keys are 'W{i}' / 'b{i}' per layer index, matching what load_weights reads.
    """
    params = {}
    for i, w in enumerate(nn.weights):
        params[f'W{i}'] = w
        params[f'b{i}'] = nn.biases[i]
    np.savez(filepath, **params)
    print(f"Weights saved to {filepath}")
def load_weights(nn, filepath="weights.npz"):
    """Restore weights/biases previously written by save_weights.

    Overwrites nn.weights[i] / nn.biases[i] in place from the archive's
    'W{i}' / 'b{i}' entries and returns the same network object.
    """
    archive = np.load(filepath)
    for layer in range(len(nn.weights)):
        nn.weights[layer] = archive[f'W{layer}']
        nn.biases[layer] = archive[f'b{layer}']
    print(f"Weights loaded from {filepath}")
    return nn
def load_trained_network():
    """Return a ready-to-use NeuralNetwork.

    If weights.npz exists and the user confirms with 'Y', the saved weights
    are loaded. Otherwise the network is trained from scratch on MNIST
    (with a 10% validation split) and the new weights are saved.
    """
    nn = NeuralNetwork(hidden_sizes=[512, 384, 256, 128])
    if os.path.exists("weights.npz"):
        answer = input("Found saved weights. Use them? (Y/N): ").strip().upper()
        if answer == 'Y':
            return load_weights(nn)
    from train import load_mnist
    print("Training network...")
    X_train, y_train, _, _ = load_mnist()
    # Hold out the first 10% of the training data for validation.
    n_val = int(0.1 * len(X_train))
    X_val, y_val = X_train[:n_val], y_train[:n_val]
    nn.train(X_train[n_val:], y_train[n_val:], X_val, y_val,
             epochs=80, initial_lr=0.15, batch_size=128)
    save_weights(nn)
    return nn
def _preprocess_digit(image_path):
    """Load *image_path* and return a 28x28 float32 array in the 0-255 range.

    Normalizes arbitrary digit photos to MNIST conventions: light digit on a
    dark background, cropped and scaled into a 20x20 box (4-pixel margin),
    then centered by intensity center of mass.
    """
    from scipy import ndimage
    # Load and convert to grayscale
    img = Image.open(image_path).convert('L')
    img_array = np.array(img, dtype=np.float32)
    # Compare corner brightness against the central third of the image to
    # decide whether this is dark-on-light and must be inverted for MNIST.
    corners = [
        img_array[0:5, 0:5].mean(),
        img_array[0:5, -5:].mean(),
        img_array[-5:, 0:5].mean(),
        img_array[-5:, -5:].mean(),
    ]
    h, w = img_array.shape
    center_mean = img_array[h // 3:2 * h // 3, w // 3:2 * w // 3].mean()
    if np.mean(corners) > center_mean:
        img_array = 255 - img_array
    # Threshold at the global mean to locate the digit's bounding box.
    binary = img_array > np.mean(img_array)
    rows = np.any(binary, axis=1)
    cols = np.any(binary, axis=0)
    if not rows.any() or not cols.any():
        # Nothing above threshold: return an all-black canvas. Returning
        # early also avoids center_of_mass on all zeros, which yields NaN
        # and would crash the int() conversion below.
        return np.zeros((28, 28), dtype=np.float32)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    digit = img_array[rmin:rmax + 1, cmin:cmax + 1]
    # Scale the crop to fit a 20x20 box (4-pixel margin like MNIST).
    # max(1, ...) guards against a dimension rounding to zero for extreme
    # aspect ratios (e.g. a 1xN stroke), which would crash resize().
    dh, dw = digit.shape
    scale = min(20.0 / dh, 20.0 / dw)
    new_h = max(1, int(dh * scale))
    new_w = max(1, int(dw * scale))
    digit_img = Image.fromarray(digit.astype(np.uint8))
    digit_img = digit_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
    digit_resized = np.array(digit_img, dtype=np.float32)
    # Paste into the middle of a 28x28 canvas.
    canvas = np.zeros((28, 28), dtype=np.float32)
    y_offset = (28 - new_h) // 2
    x_offset = (28 - new_w) // 2
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = digit_resized
    # Fine-tune centering via center of mass, clamped to +/-3 pixels so the
    # digit cannot be shifted off the canvas.
    cy, cx = ndimage.center_of_mass(canvas)
    shift_y = int(np.clip(int(14 - cy), -3, 3))
    shift_x = int(np.clip(int(14 - cx), -3, 3))
    return ndimage.shift(canvas, (shift_y, shift_x), mode='constant', cval=0)


def predict_image(nn, image_path):
    """Predict the digit in *image_path* using network *nn*.

    Preprocesses the image to MNIST format, saves a debug copy of what the
    network sees, then averages the network output over five one-pixel
    shifts (test-time augmentation). Prints and returns the argmax class.
    """
    from scipy import ndimage
    img_array = _preprocess_digit(image_path)
    # Save what the network sees
    debug_img = Image.fromarray(img_array.astype(np.uint8))
    debug_img.save(f"debug_{os.path.basename(image_path)}")
    # Test-time augmentation: predict on multiple slight shifts and average.
    # img_array is in 0-255 range at this point.
    # NOTE(review): assumes NeuralNetwork checks this flag to disable
    # train-time behavior (e.g. dropout) — confirm against network.py.
    nn.training = False
    predictions = []
    shifts = [(0, 0), (-1, 0), (1, 0), (0, -1), (0, 1)]
    for dy, dx in shifts:
        shifted = ndimage.shift(img_array, (dy, dx), mode='constant', cval=0)
        # Normalize to [0, 1] only after shifting so cval=0 stays background.
        shifted_norm = (shifted / 255.0).reshape(1, 784)
        output = nn.forward(shifted_norm)
        predictions.append(output[0])
    # Average predictions across augmentations
    avg_output = np.mean(predictions, axis=0)
    prediction = np.argmax(avg_output)
    confidence = avg_output[prediction]
    print(f"{image_path}: predicted {prediction} ({confidence:.1%} confidence)")
    return prediction
if __name__ == "__main__":
    # Entry point: obtain a trained network, then run predictions over any
    # local files whose names start with 'test_digit'.
    network = load_trained_network()
    print("\n--- Predictions ---")
    candidates = sorted(f for f in os.listdir('.') if f.startswith('test_digit'))
    if not candidates:
        print("No test_digit files found. Add images like test_digit_1.png, test_digit_2.png, etc.")
    else:
        for path in candidates:
            predict_image(network, path)