-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.go
More file actions
365 lines (323 loc) · 8.77 KB
/
main.go
File metadata and controls
365 lines (323 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
package main
import (
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"google.golang.org/protobuf/proto"
pb "tinychat/proto"
)
// Client WebSocket 客户端
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
id int32
name string
lastMsgTime time.Time // 限流: 上次消息时间
msgCount int // 限流: 当前秒内消息数
ip string // 客户端 IP
}
// Hub 管理所有客户端连接
type Hub struct {
clients map[*Client]bool
userCount map[int32]int // 每个用户ID的连接数
onlineCount int32 // 在线唯一用户数
broadcast chan []byte
register chan *Client
unregister chan *Client
mu sync.RWMutex
ipConnCount map[string]int // 每个IP的连接数 (限流)
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// NewHub 创建新的 Hub
func NewHub() *Hub {
return &Hub{
broadcast: make(chan []byte, 512),
register: make(chan *Client),
unregister: make(chan *Client),
clients: make(map[*Client]bool),
userCount: make(map[int32]int),
ipConnCount: make(map[string]int),
}
}
// Run Hub 主循环
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
count := len(h.clients)
h.mu.Unlock()
log.Printf("[连接] 新客户端连接, 当前连接数: %d", count)
// 不立即广播加入消息,等收到第一条消息后再广播
case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.send)
count := len(h.clients)
// 减少 IP 连接计数
if client.ip != "" {
h.ipConnCount[client.ip]--
if h.ipConnCount[client.ip] <= 0 {
delete(h.ipConnCount, client.ip)
}
}
h.mu.Unlock()
// 只有已识别的用户才处理离开逻辑
if client.id != 0 {
isLastConnection := h.userLeave(client.id)
log.Printf("[离开] %s (ID: %d), 当前在线: %d", client.name, client.id, h.onlineCount)
// 只有完全离开才广播(防止多标签页刷屏)
if isLastConnection {
leaveMsg := &pb.ChatMessage{
Type: pb.MessageType_LEAVE,
SenderName: client.name,
Content: client.name + " 离开了聊天室",
}
h.broadcastProtoMessage(leaveMsg)
}
} else {
log.Printf("[断开] 未识别的客户端断开, 当前连接数: %d", count)
}
// 广播在线人数
h.broadcastOnlineCount()
} else {
h.mu.Unlock()
}
case message := <-h.broadcast:
h.mu.RLock()
var failedClients []*Client
for client := range h.clients {
select {
case client.send <- message:
default:
// 发送失败,记录需要移除的客户端
failedClients = append(failedClients, client)
}
}
h.mu.RUnlock()
// 在锁外处理失败的客户端,触发 unregister 流程
for _, client := range failedClients {
h.unregister <- client
}
}
}
}
// broadcastOnlineCount 广播在线人数
func (h *Hub) broadcastOnlineCount() {
h.mu.RLock()
count := h.onlineCount
h.mu.RUnlock()
msg := &pb.ChatMessage{
Type: pb.MessageType_ONLINE_COUNT,
OnlineCount: count,
}
h.broadcastProtoMessage(msg)
}
// userJoin 用户加入,返回是否是该用户的首次连接
func (h *Hub) userJoin(id int32) bool {
h.mu.Lock()
defer h.mu.Unlock()
isFirstConnection := h.userCount[id] == 0
if isFirstConnection {
h.onlineCount++
}
h.userCount[id]++
return isFirstConnection
}
// userLeave 用户离开,返回是否是该用户的最后一个连接
func (h *Hub) userLeave(id int32) bool {
h.mu.Lock()
defer h.mu.Unlock()
h.userCount[id]--
isLastConnection := h.userCount[id] <= 0
if isLastConnection {
delete(h.userCount, id)
h.onlineCount = max(h.onlineCount-1, 0)
}
return isLastConnection
}
// broadcastProtoMessage 广播 Protobuf 消息
func (h *Hub) broadcastProtoMessage(msg *pb.ChatMessage) {
data, err := proto.Marshal(msg)
if err != nil {
log.Printf("Protobuf 编码失败: %v", err)
return
}
h.broadcast <- data
}
// readPump 读取客户端消息
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
c.conn.SetReadLimit(128 * 1024) // 限制单条消息最大 128KB
c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket 错误: %v", err)
}
break
}
// 收到任何消息都重置超时(包括心跳)
c.conn.SetReadDeadline(time.Now().Add(90 * time.Second))
// 限流检查: 每秒最多 5 条消息
now := time.Now()
if now.Sub(c.lastMsgTime) < time.Second {
c.msgCount++
if c.msgCount > 5 {
log.Printf("[限流] %s (IP: %s) 发送过快, 已丢弃", c.name, c.ip)
continue // 丢弃消息
}
} else {
c.msgCount = 1
c.lastMsgTime = now
}
// 解析 Protobuf 消息
var msg pb.ChatMessage
if err := proto.Unmarshal(message, &msg); err != nil {
log.Printf("Protobuf 解码失败: %v", err)
continue
}
// 检查消息大小,防止恶意大包攻击 (限制 60KB)
if len(msg.ImageData) > 60*1024 {
log.Printf("[安全] %s 发送的图片过大 (%dKB),已丢弃", c.name, len(msg.ImageData)/1024)
continue
}
// 提取用户信息(第一次收到消息时)
if c.id == 0 && msg.SenderId != 0 {
c.id = msg.SenderId
c.name = msg.SenderName
isFirstConnection := c.hub.userJoin(c.id)
log.Printf("[加入] %s (ID: %d)", c.name, c.id)
// 只有首次加入才广播(防止多标签页刷屏)
if isFirstConnection {
joinMsg := &pb.ChatMessage{
Type: pb.MessageType_JOIN,
SenderName: c.name,
Content: c.name + " 加入了聊天室",
}
data, err := proto.Marshal(joinMsg)
if err == nil {
c.hub.mu.RLock()
for client := range c.hub.clients {
if client != c {
select {
case client.send <- data:
default:
}
}
}
c.hub.mu.RUnlock()
}
}
c.hub.broadcastOnlineCount()
}
// 广播聊天消息给其他人(不包括发送者自己)
// 过滤掉空消息(握手消息),但保留有图片的消息
if msg.Type == pb.MessageType_MESSAGE && (msg.Content != "" || len(msg.ImageData) > 0) {
c.hub.mu.RLock()
for client := range c.hub.clients {
if client != c {
select {
case client.send <- message:
default:
log.Printf("发送消息失败: %s", client.name)
}
}
}
c.hub.mu.RUnlock()
}
}
}
// writePump 发送消息到客户端
func (c *Client) writePump() {
defer c.conn.Close()
for message := range c.send {
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteMessage(websocket.BinaryMessage, message); err != nil {
return
}
}
// send 通道关闭时发送关闭帧
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
}
// handleWebSocket 处理 WebSocket 连接
func handleWebSocket(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 获取客户端 IP
ip := r.Header.Get("X-Forwarded-For")
if ip == "" {
ip = r.Header.Get("X-Real-IP")
}
if ip == "" {
ip = r.RemoteAddr
}
// IP 连接数限制: 每个 IP 最多 3 个连接
hub.mu.Lock()
if hub.ipConnCount[ip] >= 3 {
hub.mu.Unlock()
log.Printf("[限流] IP %s 连接数过多, 拒绝连接", ip)
http.Error(w, "Too many connections from this IP", http.StatusTooManyRequests)
return
}
hub.ipConnCount[ip]++
hub.mu.Unlock()
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("WebSocket 升级失败:", err)
// 连接失败,减少计数
hub.mu.Lock()
hub.ipConnCount[ip]--
if hub.ipConnCount[ip] <= 0 {
delete(hub.ipConnCount, ip)
}
hub.mu.Unlock()
return
}
client := &Client{
hub: hub,
conn: conn,
send: make(chan []byte, 256),
ip: ip,
}
client.hub.register <- client
go client.writePump()
go client.readPump()
}
func main() {
hub := NewHub()
go hub.Run()
// WebSocket 路由
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
handleWebSocket(hub, w, r)
})
// 静态文件服务
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
http.ServeFile(w, r, "chat.html")
} else if r.URL.Path == "/proto/message.js" {
w.Header().Set("Content-Type", "application/javascript")
http.ServeFile(w, r, "proto/message.js")
} else {
http.NotFound(w, r)
}
})
addr := ":8080"
log.Printf("🚀 TinyChat 服务器启动在 http://localhost%s", addr)
if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatal("服务器启动失败: ", err)
}
}