Python WebSocket
大约 10 分钟约 2876 字
Python WebSocket
简介
WebSocket 是一种全双工通信协议,允许服务端主动向客户端推送数据。Python 中通过 FastAPI/WebSocket、websockets 库和 Socket.IO 等方案实现实时通信,广泛应用于即时消息、实时数据推送、协同编辑和在线游戏等场景。
特点
实现
FastAPI WebSocket 实时聊天
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import Dict, Set
import json
import asyncio
app = FastAPI()
class ConnectionManager:
"""WebSocket 连接管理器"""
def __init__(self):
self.active_connections: Dict[str, Set[WebSocket]] = {} # room -> connections
self.user_info: Dict[WebSocket, dict] = {}
async def connect(self, websocket: WebSocket, room: str, user: dict):
await websocket.accept()
if room not in self.active_connections:
self.active_connections[room] = set()
self.active_connections[room].add(websocket)
self.user_info[websocket] = {**user, "room": room}
# 通知房间内其他用户
await self.broadcast(room, {
"type": "user_joined",
"user": user["name"],
"members": len(self.active_connections[room]),
}, exclude=websocket)
def disconnect(self, websocket: WebSocket):
info = self.user_info.pop(websocket, None)
if info:
room = info["room"]
self.active_connections[room].discard(websocket)
if not self.active_connections[room]:
del self.active_connections[room]
async def broadcast(self, room: str, message: dict, exclude: WebSocket = None):
connections = self.active_connections.get(room, set())
for connection in connections:
if connection != exclude:
try:
await connection.send_json(message)
except Exception:
self.disconnect(connection)
async def send_to_user(self, websocket: WebSocket, message: dict):
try:
await websocket.send_json(message)
except Exception:
self.disconnect(websocket)
manager = ConnectionManager()
@app.websocket("/ws/chat/{room_id}")
async def chat_endpoint(websocket: WebSocket, room_id: str):
# 从查询参数获取用户信息
user_name = websocket.query_params.get("name", "匿名")
await manager.connect(websocket, room_id, {"name": user_name})
try:
while True:
data = await asyncio.wait_for(websocket.receive_text(), timeout=120)
message = json.loads(data)
await manager.broadcast(room_id, {
"type": "message",
"user": user_name,
"content": message.get("content", ""),
})
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(room_id, {
"type": "user_left",
"user": user_name,
})
except asyncio.TimeoutError:
await manager.send_to_user(websocket, {"type": "ping"})
manager.disconnect(websocket)websockets 库:实时数据推送
import asyncio
import websockets
import json
from typing import Set
from datetime import datetime
class DataPushServer:
"""实时数据推送服务"""
def __init__(self):
self.clients: Set[websockets.WebSocketServerProtocol] = set()
self.subscriptions: dict = {} # client -> set of topics
async def register(self, websocket):
self.clients.add(websocket)
self.subscriptions[websocket] = set()
try:
async for message in websocket:
data = json.loads(message)
if data.get("action") == "subscribe":
topics = data.get("topics", [])
self.subscriptions[websocket].update(topics)
await websocket.send(json.dumps({
"type": "subscribed",
"topics": list(self.subscriptions[websocket]),
}))
elif data.get("action") == "unsubscribe":
topics = data.get("topics", [])
self.subscriptions[websocket] -= set(topics)
except websockets.exceptions.ConnectionClosed:
pass
finally:
self.clients.discard(websocket)
self.subscriptions.pop(websocket, None)
async def publish(self, topic: str, data: dict):
"""向订阅了指定主题的客户端推送数据"""
message = json.dumps({
"topic": topic,
"data": data,
"timestamp": datetime.now().isoformat(),
})
disconnected = set()
for client, topics in self.subscriptions.items():
if topic in topics:
try:
await client.send(message)
except websockets.exceptions.ConnectionClosed:
disconnected.add(client)
for client in disconnected:
self.clients.discard(client)
self.subscriptions.pop(client, None)
async def heartbeat(self):
"""定期心跳检测"""
while True:
await asyncio.sleep(30)
disconnected = set()
for client in self.clients:
try:
await client.ping()
except Exception:
disconnected.add(client)
for client in disconnected:
self.clients.discard(client)
self.subscriptions.pop(client, None)
server = DataPushServer()
async def main():
async with websockets.serve(server.register, "localhost", 8765):
await server.heartbeat()
# asyncio.run(main())消息队列模式与心跳重连
"""客户端心跳与重连机制"""
import asyncio
import websockets
import json
class ResilientWebSocketClient:
"""支持自动重连的 WebSocket 客户端"""
def __init__(self, url: str, reconnect_interval: float = 5.0):
self.url = url
self.reconnect_interval = reconnect_interval
self._ws = None
self._running = False
self._handlers = {}
def on(self, event_type: str, handler):
"""注册消息处理器"""
self._handlers[event_type] = handler
async def connect(self):
self._running = True
while self._running:
try:
async with websockets.connect(
self.url,
ping_interval=20,
ping_timeout=10,
close_timeout=5,
) as ws:
self._ws = ws
print(f"WebSocket 已连接: {self.url}")
await self._listen()
except (
websockets.exceptions.ConnectionClosed,
ConnectionRefusedError,
asyncio.TimeoutError,
) as e:
print(f"连接断开: {e},{self.reconnect_interval}s 后重连...")
self._ws = None
await asyncio.sleep(self.reconnect_interval)
async def _listen(self):
if not self._ws:
return
async for raw_message in self._ws:
try:
message = json.loads(raw_message)
msg_type = message.get("type", "unknown")
handler = self._handlers.get(msg_type)
if handler:
await handler(message)
except json.JSONDecodeError:
print(f"无效消息: {raw_message}")
async def send(self, data: dict):
if self._ws:
await self._ws.send(json.dumps(data))
async def close(self):
self._running = False
if self._ws:
await self._ws.close()
# 使用示例
async def on_message(msg):
print(f"收到消息: {msg}")
async def on_notification(msg):
print(f"收到通知: {msg}")
client = ResilientWebSocketClient("ws://localhost:8765")
client.on("message", on_message)
client.on("notification", on_notification)
# await client.connect()生产级 WebSocket 架构
"""生产级 WebSocket 架构:Redis Pub/Sub + 多实例"""
import asyncio
import json
import aioredis
from fastapi import FastAPI, WebSocket
from typing import Dict, Set
app = FastAPI()
class RedisPubSubManager:
"""基于 Redis Pub/Sub 的跨实例消息分发"""
def __init__(self, redis_url: str = "redis://localhost:6379"):
self.redis_url = redis_url
self.redis = None
self.pubsub = None
self.local_connections: Dict[str, Set[WebSocket]] = {}
async def connect(self):
self.redis = await aioredis.from_url(self.redis_url)
self.pubsub = self.redis.pubsub()
async def subscribe(self, channel: str):
await self.pubsub.subscribe(channel)
async def publish(self, channel: str, message: dict):
await self.redis.publish(channel, json.dumps(message))
async def listen(self):
"""监听 Redis 消息并分发到本地连接"""
async for message in self.pubsub.listen():
if message["type"] == "message":
channel = message["channel"]
data = json.loads(message["data"])
connections = self.local_connections.get(channel, set())
disconnected = set()
for ws in connections:
try:
await ws.send_json(data)
except Exception:
disconnected.add(ws)
connections -= disconnected
def add_connection(self, channel: str, ws: WebSocket):
if channel not in self.local_connections:
self.local_connections[channel] = set()
self.local_connections[channel].add(ws)
def remove_connection(self, channel: str, ws: WebSocket):
if channel in self.local_connections:
self.local_connections[channel].discard(ws)
# 实例化
pubsub_manager = RedisPubSubManager()
@app.on_event("startup")
async def startup():
await pubsub_manager.connect()
asyncio.create_task(pubsub_manager.listen())
@app.websocket("/ws/{channel}")
async def websocket_endpoint(websocket: WebSocket, channel: str):
await websocket.accept()
pubsub_manager.add_connection(channel, websocket)
await pubsub_manager.subscribe(channel)
try:
while True:
data = await websocket.receive_json()
await pubsub_manager.publish(channel, data)
except Exception:
pubsub_manager.remove_connection(channel, websocket)优点
缺点
总结
WebSocket 是实现实时通信的标准方案,Python 中 FastAPI 提供了简洁的 WebSocket 支持。生产环境需要关注连接管理、心跳检测、断线重连和横向扩展(Redis Pub/Sub)。对于消息可靠性要求高的场景,需要在应用层实现消息确认和重发机制。
关键知识点
- WebSocket 通过 HTTP Upgrade 握手建立,建立后切换为全双工通信协议
- 心跳机制(ping/pong)是检测连接存活的标准方式
- 多实例部署需要 Redis Pub/Sub 等机制实现跨实例消息分发
- asyncio.wait_for 可以为 WebSocket 接收设置超时,防止连接长时间空闲
项目落地视角
- WebSocket 连接必须携带认证 Token,防止未授权连接
- 配置心跳间隔(30-60 秒)和超时时间,及时清理断开的连接
- 使用 Redis Pub/Sub 支持多实例部署,确保消息不丢失
- 监控在线连接数、消息吞吐量和连接断开率
常见误区
- 忽略心跳检测,导致断开的连接占用服务器资源
- 不处理连接断开和重连,消息丢失无感知
- 将大量业务逻辑放在 WebSocket 处理中,导致消息处理变慢
- 忽略 WebSocket 连接的认证和授权
进阶路线
- 学习 Socket.IO 提供的房间、命名空间和自动重连能力
- 研究消息队列(RabbitMQ/Kafka)在实时数据推送中的应用
- 了解 Server-Sent Events(SSE)作为 WebSocket 的轻量替代方案
- 探索 WebTransport(HTTP/3)作为下一代实时通信协议
适用场景
- 即时消息、实时聊天、协同编辑等双向通信场景
- 实时数据监控、仪表盘、股票行情推送
- 在线游戏、实时协作工具等低延迟场景
落地建议
- 统一封装 WebSocket 连接管理器,处理连接、断开、心跳和重连
- 使用 Redis Pub/Sub 实现多实例消息分发
- 为 WebSocket 消息定义统一的消息格式(type、data、timestamp)
排错清单
- 检查 Nginx/反向代理是否正确配置 WebSocket 代理头
- 确认心跳机制是否正常工作,是否有僵尸连接
- 排查连接数是否超过服务器或代理的上限配置
复盘问题
- 你的 WebSocket 服务最大支持多少并发连接?连接数上限是如何确定的?
- 消息推送的延迟是否满足业务需求?是否有监控数据?
- 断线重连的用户体验如何?是否有消息补发机制?
WebSocket 认证与授权
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, Depends
from typing import Optional
import jwt
import hashlib
import hmac
app = FastAPI()
SECRET_KEY = "your-secret-key"
def verify_token(token: str) -> Optional[dict]:
"""验证 JWT Token"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
return payload
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
async def get_ws_user(
websocket: WebSocket,
token: str = Query(...)
):
"""WebSocket 认证依赖"""
user = verify_token(token)
if not user:
await websocket.close(code=4001, reason="认证失败")
raise WebSocketDisconnect(code=4001)
return user
@app.websocket("/ws/chat/{room_id}")
async def chat_endpoint(
websocket: WebSocket,
room_id: str,
user: dict = Depends(get_ws_user)
):
"""带认证的 WebSocket 端点"""
await websocket.accept()
user_name = user.get("sub", "anonymous")
# 验证房间权限
if not await check_room_permission(user["user_id"], room_id):
await websocket.close(code=4003, reason="无权访问该房间")
return
await manager.connect(websocket, room_id, {"name": user_name})
# ... 正常处理消息# WebSocket 消息签名验证(防止伪造)
def sign_message(message: dict, secret: str) -> str:
"""对消息进行 HMAC 签名"""
import json
payload = json.dumps(message, sort_keys=True)
return hmac.new(secret.encode(), payload.encode(), hashlib.sha256).hexdigest()
def verify_message(message: dict, signature: str, secret: str) -> bool:
"""验证消息签名"""
expected = sign_message(message, secret)
return hmac.compare_digest(expected, signature)WebSocket 消息协议设计
"""统一的 WebSocket 消息协议
所有消息格式:
{
"type": "消息类型",
"id": "消息唯一ID(用于 ACK)",
"data": {...}, // 消息内容
"timestamp": "ISO时间戳",
"ack": true // 是否需要确认
}
"""
import json
import uuid
from datetime import datetime
from typing import Any, Callable, Awaitable
class WebSocketMessage:
"""统一的 WebSocket 消息封装"""
def __init__(self, msg_type: str, data: Any = None,
require_ack: bool = False):
self.id = str(uuid.uuid4())
self.type = msg_type
self.data = data or {}
self.timestamp = datetime.utcnow().isoformat() + "Z"
self.require_ack = require_ack
def to_json(self) -> str:
return json.dumps({
"id": self.id,
"type": self.type,
"data": self.data,
"timestamp": self.timestamp,
"ack": self.require_ack
})
@classmethod
def from_json(cls, raw: str) -> "WebSocketMessage":
parsed = json.loads(raw)
msg = cls(parsed["type"], parsed.get("data"))
msg.id = parsed.get("id", str(uuid.uuid4()))
msg.timestamp = parsed.get("timestamp", "")
msg.require_ack = parsed.get("ack", False)
return msg
# 消息处理器注册
class MessageRouter:
"""消息路由 — 根据消息类型分发到不同的处理器"""
def __init__(self):
self._handlers: dict[str, Callable[[WebSocketMessage], Awaitable[None]]] = {}
def on(self, msg_type: str):
"""注册消息处理器"""
def decorator(handler):
self._handlers[msg_type] = handler
return handler
return decorator
async def dispatch(self, message: WebSocketMessage):
"""分发消息"""
handler = self._handlers.get(message.type)
if handler:
await handler(message)
else:
print(f"未处理的消息类型: {message.type}")
# 使用示例
router = MessageRouter()
@router.on("chat.message")
async def handle_chat_message(msg: WebSocketMessage):
"""处理聊天消息"""
print(f"收到消息: {msg.data}")
@router.on("user.typing")
async def handle_typing(msg: WebSocketMessage):
"""处理正在输入状态"""
print(f"用户正在输入: {msg.data.get('user')}")
@router.on("user.join")
async def handle_join(msg: WebSocketMessage):
"""处理用户加入"""
print(f"用户加入: {msg.data.get('user')}")WebSocket 性能监控
import time
from collections import deque
from dataclasses import dataclass, field
@dataclass
class WebSocketMetrics:
"""WebSocket 连接指标"""
total_connections: int = 0
active_connections: int = 0
total_messages_sent: int = 0
total_messages_received: int = 0
errors: int = 0
latency_samples: deque = field(default_factory=lambda: deque(maxlen=1000))
@property
def avg_latency_ms(self) -> float:
if not self.latency_samples:
return 0.0
return sum(self.latency_samples) / len(self.latency_samples)
@property
def p99_latency_ms(self) -> float:
if not self.latency_samples:
return 0.0
sorted_samples = sorted(self.latency_samples)
idx = int(len(sorted_samples) * 0.99)
return sorted_samples[idx]
# 全局指标
metrics = WebSocketMetrics()
class MetricsMiddleware:
"""WebSocket 指标中间件"""
@staticmethod
def record_connection_open():
metrics.total_connections += 1
metrics.active_connections += 1
@staticmethod
def record_connection_close():
metrics.active_connections = max(0, metrics.active_connections - 1)
@staticmethod
def record_message_sent():
metrics.total_messages_sent += 1
@staticmethod
def record_message_received():
metrics.total_messages_received += 1
@staticmethod
def record_latency(latency_ms: float):
metrics.latency_samples.append(latency_ms)
@staticmethod
def get_report() -> dict:
return {
"total_connections": metrics.total_connections,
"active_connections": metrics.active_connections,
"messages_sent": metrics.total_messages_sent,
"messages_received": metrics.total_messages_received,
"errors": metrics.errors,
"avg_latency_ms": round(metrics.avg_latency_ms, 2),
"p99_latency_ms": round(metrics.p99_latency_ms, 2),
}
# 定期输出指标报告
async def metrics_reporter(interval: int = 60):
"""定期输出 WebSocket 指标"""
while True:
await asyncio.sleep(interval)
report = MetricsMiddleware.get_report()
print(f"[Metrics] {report}")Nginx 反向代理配置
# nginx.conf — WebSocket 反向代理配置
upstream websocket_backend {
least_conn;
server 127.0.0.1:8000;
# 多实例时添加更多 server
keepalive 64;
}
server {
listen 80;
server_name ws.example.com;
location /ws/ {
proxy_pass http://websocket_backend;
proxy_http_version 1.1;
# WebSocket 必需的 Upgrade 头
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
# 传递客户端信息
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# 超时配置(WebSocket 长连接需要较长超时)
proxy_read_timeout 86400s;
proxy_send_timeout 86400s;
# 缓冲关闭(避免消息延迟)
proxy_buffering off;
}
}WebSocket 生产部署检查清单:
- [ ] Nginx/反向代理正确配置 Upgrade 头
- [ ] 超时设置足够长(至少 24 小时)
- [ ] 心跳机制正常工作
- [ ] 连接认证和授权
- [ ] 消息格式统一
- [ ] Redis Pub/Sub 支持多实例
- [ ] 连接数监控和告警
- [ ] 优雅关闭(SIGTERM 处理)