-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_gui.py
More file actions
81 lines (57 loc) · 1.77 KB
/
predict_gui.py
File metadata and controls
81 lines (57 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import tkinter as tk
from PIL import Image , ImageDraw
import numpy as np
import torch
import torchvision.transforms as transforms
from models.model import NeuralNet
# Load the trained classifier weights from disk.
model = NeuralNet()
# map_location="cpu": the model is never moved to another device in this
# script, so a checkpoint saved from a GPU run must be remapped onto the CPU;
# without this, torch.load fails on machines that have no CUDA device.
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()  # switch the module to evaluation (inference) mode

# Transform the image to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Build the main window and the drawing surface.
window = tk.Tk()
window.title("Handwritten Digit Classifier using MNIST Dataset")

canvas_width, canvas_height = 280, 280
canvas = tk.Canvas(
    window,
    width=canvas_width,
    height=canvas_height,
    bg="white",
)
canvas.pack()

# Off-screen grayscale ("L" mode) image with the same size as the canvas.
# Strokes are drawn on it as well as on the canvas, so the pixels can later
# be read back from this image.
image = Image.new("L", (canvas_width, canvas_height), "white")
draw = ImageDraw.Draw(image)
# Drawing on the canvas
def paint(event):
    """Stamp a filled black circle at the pointer position.

    The same circle is drawn twice: on the visible Tk ``canvas`` and on the
    off-screen PIL image (through ``draw``), so both stay in step.

    Args:
        event: Tk event whose ``x`` and ``y`` attributes give the pointer
            position in canvas coordinates.
    """
    brush_radius = 8
    # Bounding box of the circle: [left, top, right, bottom].
    bbox = [
        event.x - brush_radius,
        event.y - brush_radius,
        event.x + brush_radius,
        event.y + brush_radius,
    ]
    canvas.create_oval(bbox, fill="black")
    draw.ellipse(bbox, fill="black")
# <B1-Motion> only fires while the pointer moves with button 1 held down, so
# the initial press is bound too; otherwise a click without dragging (a single
# dot) would never be drawn.
canvas.bind("<Button-1>", paint)
canvas.bind("<B1-Motion>", paint)
# Predicting the digit
def predict():
    """Classify the current drawing and show the result in ``label``.

    The off-screen PIL ``image`` is shrunk to 28x28, inverted so that strokes
    become the high-valued pixels, converted to a tensor with ``transform``
    and passed to ``model`` as a batch of one. If nothing has been drawn, no
    inference is run and a hint is shown instead of a meaningless digit.
    """
    small = image.resize((28, 28))
    # The canvas is black-on-white; inverting makes the background 0 and the
    # strokes bright (presumably the convention the model was trained on,
    # given the MNIST window title - confirm against the training code).
    pixels = 255 - np.array(small)

    # An all-zero array means the canvas is blank: skip the model entirely.
    if not pixels.any():
        label.config(text="Nothing drawn yet - draw a digit and click on Predict")
        return

    batch = transform(pixels).unsqueeze(0)  # add the leading batch dimension
    with torch.no_grad():  # inference only, no gradient tracking needed
        output = model(batch)
    # Index of the largest score along dim 1 is reported as the class.
    _, predicted = torch.max(output, 1)
    label.config(text=f"Predicted Label : {predicted.item()}")
# Clearing the canvas
def clear():
    """Wipe the visible canvas and start over with a blank off-screen image.

    Rebinds the module-level ``image`` and ``draw`` to a fresh white image
    and its drawing context, so later strokes land on a clean surface.
    """
    global image, draw

    canvas.delete("all")
    blank = Image.new("L", (canvas_width, canvas_height), "white")
    image = blank
    draw = ImageDraw.Draw(blank)
# Controls: create the widgets first, then pack them top to bottom in the
# order Predict button, Clear button, status label.
predict_btn = tk.Button(window, text="Predict", command=predict)
clear_btn = tk.Button(window, text="Clear", command=clear)
label = tk.Label(window, text="Draw a digit and click on Predict")

predict_btn.pack()
clear_btn.pack()
label.pack()

# Hand control to Tk's event loop.
window.mainloop()