1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
import sys
import tkinter as tk
from tkinter import messagebox

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
# Saved weights (a state_dict) that load_model() reads at start-up.
MODEL_PATH = "digit_recognition_model.pth"
# Architecture selector checked by load_model(); only "MLP" is accepted.
MODEL_TYPE = "MLP"


def _pick_device():
    """Return the MPS device when PyTorch reports it usable, else the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _pick_device()
class MLP(nn.Module):
    """Fully connected classifier for 28×28 single-channel digit images.

    Layer sizes are 784 -> 256 -> 128 -> 10 with ReLU after each hidden
    layer. Keep the attribute names fc1/fc2/fc3 unchanged: they are the
    keys of the saved state_dict that load_model() loads into this class.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits of shape (N, 10).

        ``x`` is flattened into rows of 28*28 values, so (N, 1, 28, 28),
        (N, 28, 28) and (N, 784) inputs all work. ``reshape`` is used instead
        of ``view`` because ``view`` raises on non-contiguous tensors (e.g. a
        transposed batch); ``reshape`` copies only when it has to.
        """
        x = x.reshape(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
def load_model():
    """Build the network selected by MODEL_TYPE and load its saved weights.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: if MODEL_TYPE is not "MLP".

    If MODEL_PATH does not exist, an error dialog is shown and the process
    terminates with exit status 1.
    """
    if MODEL_TYPE == "MLP":
        model = MLP().to(device)
    else:
        raise ValueError("MODEL_TYPE 只能是 'MLP'")

    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; on torch >= 1.13 consider weights_only=True.
        state_dict = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        messagebox.showerror("错误", f"未找到模型文件:{MODEL_PATH}\n请先训练模型并保存,或修改 MODEL_PATH")
        # exit() is a `site` convenience (absent under `python -S`) and
        # would report success (status 0); signal the failure explicitly.
        sys.exit(1)

    model.load_state_dict(state_dict)
    model.eval()
    print(f"模型加载成功!使用设备:{device}")
    return model
# Load the network once, at import time; predict_digit() uses this
# module-level instance. Because this runs on import, a missing weights
# file makes load_model() show an error dialog and exit the process.
model = load_model()
def preprocess_image(image):
    """Turn a drawing into the model's input.

    Args:
        image: PIL image of the drawing (this app passes its 300×300 RGB
            canvas copy, black strokes on white).

    Returns:
        ``(img_tensor, small_img)`` where
        img_tensor -- float tensor of shape (1, 1, 28, 28) on ``device``,
            values in [0, 1], colours inverted so strokes are bright on a
            dark background;
        small_img -- the 28×28 grayscale PIL image the tensor was built
            from, i.e. exactly what the model sees (for previewing).
    """
    resize = transforms.Resize((28, 28))
    to_tensor = transforms.ToTensor()

    img = image.convert("L")
    img = img.point(lambda x: 255 - x)  # invert: black-on-white -> white-on-black
    # Resize and convert as two explicit steps so the 28×28 intermediate can
    # be returned as the preview instead of the full-size image.
    # NOTE(review): no Normalize() is applied; confirm this matches the
    # preprocessing used when the model was trained.
    small_img = resize(img)
    img_tensor = to_tensor(small_img).unsqueeze(0)
    return img_tensor.to(device), small_img
def predict_digit(image):
    """Classify a drawn digit with the module-level model.

    Args:
        image: PIL image of the drawing; handed unchanged to
            preprocess_image().

    Returns:
        ``(label, probability, preview)``: the arg-max class index (int),
        its softmax probability (float), and the image returned as the
        second value of preprocess_image().
    """
    batch, preview = preprocess_image(image)

    with torch.no_grad():
        logits = model(batch)
        label = logits.argmax(dim=1).item()
        probability = torch.softmax(logits, dim=1)[0, label].item()

    return label, probability, preview
class HandwritingApp:
    """Tk window for drawing a digit and classifying it with predict_digit().

    Every stroke is rendered twice: on the visible Tk canvas and on an
    off-screen PIL image of the same size. The PIL image is what gets
    classified, so both renderings use the same round brush.
    """

    CANVAS_SIZE = 300  # side length of the canvas and the off-screen image, px
    BRUSH_WIDTH = 15   # stroke width, px

    def __init__(self, root):
        """Build the widgets on ``root`` and bind the mouse handlers."""
        self.root = root
        self.root.title("手写数字识别(28×28 MNIST 风格)")
        self.root.geometry("400x400")

        self.canvas = tk.Canvas(
            root, width=self.CANVAS_SIZE, height=self.CANVAS_SIZE,
            bg="white", bd=2, relief=tk.SUNKEN,
        )
        self.canvas.pack(pady=20)

        self.btn_frame = tk.Frame(root)
        self.btn_frame.pack()

        self.predict_btn = tk.Button(
            self.btn_frame, text="识别数字", command=self.on_predict, font=("Arial", 12)
        )
        self.predict_btn.grid(row=0, column=0, padx=10)

        self.clear_btn = tk.Button(
            self.btn_frame, text="清空画布", command=self.on_clear, font=("Arial", 12)
        )
        self.clear_btn.grid(row=0, column=1, padx=10)

        # Pen state: whether the left button is held, and the previous point.
        self.drawing = False
        self.last_x = None
        self.last_y = None

        self._reset_image()

        self.canvas.bind("<Button-1>", self.start_drawing)
        self.canvas.bind("<B1-Motion>", self.draw_line)
        self.canvas.bind("<ButtonRelease-1>", self.stop_drawing)

    def _reset_image(self):
        """Start a fresh all-white off-screen image with nothing drawn on it."""
        self.pil_image = Image.new("RGB", (self.CANVAS_SIZE, self.CANVAS_SIZE), "white")
        self.draw = ImageDraw.Draw(self.pil_image)
        self.has_ink = False

    def _stamp_image(self, x, y):
        """Paint a brush-width disk centred on (x, y) into the off-screen image."""
        r = self.BRUSH_WIDTH // 2
        self.draw.ellipse((x - r, y - r, x + r, y + r), fill="black")
        self.has_ink = True

    def start_drawing(self, event):
        """Mouse press: put the pen down and leave a dot, so a click alone marks."""
        self.drawing = True
        self.last_x = event.x
        self.last_y = event.y
        r = self.BRUSH_WIDTH // 2
        self.canvas.create_oval(
            event.x - r, event.y - r, event.x + r, event.y + r,
            fill="black", outline="",
        )
        self._stamp_image(event.x, event.y)

    def draw_line(self, event):
        """Mouse drag: extend the stroke from the previous point to the pointer."""
        if not self.drawing:
            return
        self.canvas.create_line(
            self.last_x, self.last_y, event.x, event.y,
            fill="black", width=self.BRUSH_WIDTH, capstyle=tk.ROUND, smooth=tk.TRUE,
        )
        # ImageDraw.line draws no end caps, and its `joint` option only
        # understands "curve" (between points of one call), so each short
        # segment would end square. Stamping a disk at the segment end gives
        # the same round-capped stroke the canvas shows.
        self.draw.line(
            [(self.last_x, self.last_y), (event.x, event.y)],
            fill="black", width=self.BRUSH_WIDTH,
        )
        self._stamp_image(event.x, event.y)
        self.last_x = event.x
        self.last_y = event.y

    def stop_drawing(self, event):
        """Mouse release: lift the pen."""
        self.drawing = False

    def on_clear(self):
        """Wipe the canvas and the off-screen image."""
        self.canvas.delete("all")
        self._reset_image()

    def on_predict(self):
        """Classify the current drawing, then show the result and the model input."""
        if not self.has_ink:
            # A blank canvas would still produce some (meaningless) digit.
            messagebox.showwarning("提示", "请先在画布上书写数字")
            return
        try:
            pred_label, pred_prob, processed_img = predict_digit(self.pil_image)

            messagebox.showinfo(
                "预测结果", f"预测数字:{pred_label}\n预测概率:{pred_prob:.2%}"
            )

            # NOTE(review): matplotlib's default font may lack CJK glyphs
            # for this title; confirm it renders on the target platform.
            plt.figure(figsize=(4, 4))
            plt.imshow(processed_img, cmap="gray")
            plt.title(f"预处理后(28×28)→ 预测:{pred_label}")
            plt.axis("off")
            plt.show()
        except Exception as e:
            # UI boundary: surface any failure to the user in a dialog.
            messagebox.showerror("错误", f"识别失败:{str(e)}")
def main():
    """Open the drawing window and block in Tk's event loop until it closes."""
    window = tk.Tk()
    app = HandwritingApp(window)
    window.mainloop()


if __name__ == "__main__":
    main()
|