-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTest.py
More file actions
186 lines (151 loc) · 5.97 KB
/
Test.py
File metadata and controls
186 lines (151 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import serial
import torch
import numpy as np
import time
from Model import EEGCNN
import sys
import requests
# ==== Robot address ====
ROBOT_IP = "http://192.168.4.1/"
# Endpoint that receives preset commands (a GET with a `pm` query parameter).
# Derived from ROBOT_IP so the host is defined in exactly one place.
BASE_URL = f"{ROBOT_IP}controller"
# ==== Configuration ====
SERIAL_PORT = 'COM3'
BAUD_RATE = 57600
SAMPLES_PER_SEGMENT = 256  # raw samples collected per prediction round
WINDOW_SIZE = 128  # samples per classifier window (multi_predict)
STEP_SIZE = 32  # stride between successive windows (multi_predict)
label_map = {0: 'blink', 1: 'frown', 2: 'rest'}  # model output index -> label
# ==== Global state (updated by the main loop from status packets) ====
attention = 0
# Start as "very poor" so control stays disabled until a status packet arrives
# (control requires signal_quality <= 10 and attention > 0).
signal_quality = 200
can_control = False  # whether commands may be sent to the robot
# ==== Load the trained classifier ====
# One output per entry in label_map, so every argmax index maps to a label and
# a mismatch with the checkpoint fails loudly in load_state_dict.
model = EEGCNN(num_classes=len(label_map))
# The path is relative, i.e. resolved against the current working directory;
# map_location='cpu' loads the weights onto the CPU whatever device saved them.
model.load_state_dict(torch.load("model_dev/model.pth", map_location='cpu'))
model.eval()  # inference mode (e.g. disables dropout, if the model has any)
# ==== Open the serial port ====
try:
    # timeout=1: read() gives up after 1 s and returns whatever arrived (possibly b'').
    ser = serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=1)
except serial.SerialException as e:
    print(f"[×] 串口连接失败: {e}")
    # Exit with a non-zero status so shells/launchers see the failure;
    # a bare sys.exit() reports success (status 0).
    sys.exit(1)
else:
    print(f"[✓] 串口连接成功:{SERIAL_PORT}")
# ==== Send commands ====
def send_preset_command(command_id, *, timeout=2.0):
    """Send a preset motion command (e.g. turn left, move forward) to the robot.

    Does nothing except print a warning when the module-level ``can_control``
    flag is False (attention is zero or the signal quality is poor).

    Args:
        command_id: preset id, sent to BASE_URL as the ``pm`` query parameter.
        timeout: seconds to wait for the robot before giving up, so an
            unreachable robot cannot stall the real-time loop indefinitely.
    """
    if not can_control:  # no attention / poor signal: do not send
        print("⚠️ 注意力不足/信号差,停止发送指令")
        return
    try:
        response = requests.get(BASE_URL, params={"pm": command_id}, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as err:
        # Covers connection errors, timeouts and HTTP error statuses alike.
        print(f"发送指令失败:{err}")
        return
    print(f"指令发送成功,响应:{response.text}")
# ==== Checksum ====
def valid_checksum(packet):
    """Return True if an 8-byte raw-value packet carries a correct checksum.

    The checksum is the bitwise inverse of the low byte of the payload sum,
    where the payload is assumed to be 0x80, 0x02, high byte, low byte
    (bytes 3-4 are taken as those constants; only bytes 5-6 are read).
    Packets of any other length are rejected.
    """
    if len(packet) != 8:
        return False
    payload_sum = 0x80 + 0x02 + packet[5] + packet[6]
    expected = ~payload_sum & 0xFF
    return packet[7] == expected
# ==== Read one aligned packet ====
def _thinkgear_checksum(payload):
    """Return the packet checksum for *payload*: the inverted low byte of its sum."""
    return ~sum(payload) & 0xFF


def read_aligned_packet(port=None):
    """Block until one valid packet has been read from the serial stream.

    Packet layout: ``0xAA 0xAA <length> <payload: length bytes> <checksum>``.
    Two payload lengths are accepted:

    * ``0x04`` -> ``"small"`` (8 bytes total): raw-value packet whose payload
      must start with ``0x80 0x02``.
    * ``0x20`` -> ``"big"`` (36 bytes total: 2 sync + 1 length + 32 payload
      + 1 checksum), so byte 4 and byte 32 of the returned packet are the
      fields the main loop reads.

    Packets with another length, a bad checksum, or a short read (serial
    timeout) are discarded and scanning resumes.

    Args:
        port: object with a ``read(n) -> bytes`` method; defaults to the
            module-level serial port ``ser``.

    Returns:
        tuple[str, bytes]: ``("small", packet)`` or ``("big", packet)``, where
        *packet* includes the sync, length and checksum bytes.
    """
    if port is None:
        port = ser
    sync = b'\xaa'
    while True:
        # Short-circuit: the second byte is only read after a first 0xAA.
        if port.read(1) != sync or port.read(1) != sync:
            continue
        # A further 0xAA means we are still inside a run of sync bytes (e.g. a
        # stray 0xAA right before the real sync pair); keep reading until the
        # length byte appears instead of dropping the packet.
        length_byte = port.read(1)
        while length_byte == sync:
            length_byte = port.read(1)
        if not length_byte:
            continue  # serial read timed out
        length = length_byte[0]
        if length not in (0x04, 0x20):
            continue
        # Read the whole payload plus the trailing checksum byte, so the
        # stream stays aligned on the next packet.
        body = port.read(length + 1)
        if len(body) != length + 1:
            continue
        payload, checksum = body[:-1], body[-1]
        if _thinkgear_checksum(payload) != checksum:
            continue
        packet = sync + sync + length_byte + body
        if length == 0x20:
            return ("big", packet)
        if payload[0] == 0x80 and payload[1] == 0x02:
            return ("small", packet)
# ==== Extract the raw value ====
def extract_raw_value(packet):
    """Decode the signed 16-bit raw sample from a small packet.

    Bytes 5 (high) and 6 (low) hold the value in big-endian two's complement,
    so the result lies in [-32768, 32767].
    """
    value = (packet[5] << 8) | packet[6]
    # Two's complement: 0x8000 (32768) and above are negative. Using
    # `> 32768` would wrongly leave 0x8000 as +32768.
    if value >= 0x8000:
        value -= 0x10000
    return value
# ==== Single-window prediction ====
def predict(buffer):
    """Classify one window of raw samples with the loaded model.

    The samples become a float32 tensor reshaped to (1, 1, N), where N is the
    total number of samples, and are run through ``model`` with gradient
    tracking disabled.

    Returns:
        str: the ``label_map`` entry for the highest-scoring class, or "未知"
        ("unknown") if that index has no entry.
    """
    samples = np.array(buffer, dtype=np.float32).reshape(1, 1, -1)
    batch = torch.tensor(samples)
    with torch.no_grad():
        scores = model(batch)
    best_index = scores.argmax(dim=1).item()
    return label_map.get(best_index, "未知")
# ==== Multi-window prediction + majority vote ====
def multi_predict(buffer, window_size=128, step_size=32):
    """Classify *buffer* by majority vote over sliding windows.

    Windows of ``window_size`` samples start every ``step_size`` samples;
    only windows that fit entirely inside the buffer are used. Each window is
    classified with ``predict`` and the most frequent label wins. Ties go to
    the label that comes first in ``label_map``.

    Raises:
        ValueError: if ``window_size`` or ``step_size`` is < 1, or the buffer
            is shorter than one window.
    """
    # Validate first: with zero windows every vote would be 0 and max() would
    # silently return the first label ('blink'), i.e. a motion command.
    if window_size < 1 or step_size < 1:
        raise ValueError(
            f"window_size and step_size must be >= 1, got {window_size}, {step_size}")
    if len(buffer) < window_size:
        raise ValueError(
            f"buffer has {len(buffer)} samples, need at least {window_size}")
    votes = {lab: 0 for lab in label_map.values()}
    for start in range(0, len(buffer) - window_size + 1, step_size):
        result = predict(buffer[start:start + window_size])
        # .get tolerates a label outside label_map (predict's "unknown" fallback).
        votes[result] = votes.get(result, 0) + 1
    return max(votes, key=votes.get)
# ==== Main program ====
def _collect_segment():
    """Read packets until SAMPLES_PER_SEGMENT raw samples have been collected.

    Raw ("small") packets contribute one sample each, except the extreme
    values 2047 and -2048, which are dropped. Status ("big") packets update
    the module-level ``signal_quality`` (byte 4) and ``attention`` (byte 32).

    Returns:
        list[int]: the collected raw samples.
    """
    global attention, signal_quality
    segment = []
    while len(segment) < SAMPLES_PER_SEGMENT:
        pkt_type, pkt = read_aligned_packet()
        if pkt_type == "small":  # raw-signal packet
            val = extract_raw_value(pkt)
            if val not in (2047, -2048):  # drop extreme values
                segment.append(val)
        elif pkt_type == "big":  # status packet
            signal_quality = pkt[4]
            attention = pkt[32]
    return segment


def main(iterations=1000):
    """Run the real-time EEG -> robot control loop.

    Each iteration collects one segment of raw samples and refreshes the
    control gate (attention > 0 and signal_quality <= 10). Only when control
    is allowed, it classifies the segment by windowed majority vote and sends
    the matching preset command. The serial port is closed on exit, including
    on Ctrl+C or an error.

    Args:
        iterations: number of segments to process.
    """
    global can_control
    # label -> (console tag, preset command id); ids per the original
    # comments: 1 = stop, 6 = turn left, 2 = forward.
    actions = {'rest': ('rest', 1), 'frown': ('clench', 6), 'blink': ('blink', 2)}
    try:
        print(f"开始实时预测({iterations}次,每次约1秒)...\n")
        for i in range(iterations):
            print(f"\n--- 第 {i + 1} 次 ---")
            segment = _collect_segment()
            can_control = attention > 0 and signal_quality <= 10
            print(
                f"[INFO] Attention={attention} SignalQuality={signal_quality}, Control={can_control}")
            print("原始数据样本:", segment)
            # Only predict and send commands while control is allowed.
            if can_control:
                result = multi_predict(segment, window_size=WINDOW_SIZE, step_size=STEP_SIZE)
                action = actions.get(result)
                if action is not None:
                    tag, command_id = action
                    print(tag)
                    send_preset_command(command_id)
                print("多数投票预测结果:", result)
            else:
                print("⚠️ 注意力为0或信号差,跳过预测和指令发送")
            time.sleep(1)
            # The port is not read during the sleep, so data presumably keeps
            # queuing up. Discard it so the next segment is fresh data rather
            # than an ever-growing backlog; read_aligned_packet resyncs on the
            # next sync bytes.
            ser.reset_input_buffer()
        print("\n[✓] 实时预测结束。")
    finally:
        ser.close()
# Start the real-time loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()