Files
crypto_trader/backend/api/websocket.py

243 lines
9.5 KiB
Python

"""WebSocket endpoints for real-time updates."""
import asyncio
import json
from collections import deque
from datetime import datetime
from decimal import Decimal
from typing import List, Dict, Set, Callable, Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder

from ..core.schemas import PriceUpdate, OrderUpdate
from src.data.pricing_service import get_pricing_service
# Router that carries this module's WebSocket endpoint; NOTE(review): presumably
# included by the application under some prefix -- confirm the mount path.
router = APIRouter()
class ConnectionManager:
    """Track WebSocket clients and fan real-time updates out to them.

    Price updates arrive from the pricing service through per-symbol ticker
    callbacks, are buffered in ``_price_update_queue`` and drained by a
    background task (``_process_queue``) that broadcasts each one to every
    active connection. Order and training-progress updates are broadcast
    directly by their respective ``broadcast_*`` methods.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # Symbols that currently hold a ticker subscription with the pricing service.
        self.subscribed_symbols: Set[str] = set()
        # Outstanding subscribe_to_symbol() calls per symbol. The ticker is only
        # released when this drops to zero, so one client unsubscribing or
        # disconnecting cannot cut off price updates for the others.
        self._symbol_refcounts: Dict[str, int] = {}
        self._pricing_service = None
        self._price_callbacks: Dict[str, List[Callable]] = {}
        # Buffer between pricing-service callbacks and the async broadcaster.
        # deque.append/popleft are thread-safe, so callbacks may push from another
        # thread. NOTE(review): confirm which thread the pricing service calls
        # callbacks on.
        self._price_update_queue: deque = deque()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._processing_task: Optional[asyncio.Task] = None

    def set_event_loop(self, loop: asyncio.AbstractEventLoop):
        """Store ``loop`` on the manager (it is not read elsewhere in this class)."""
        self._loop = loop

    async def start_background_tasks(self):
        """Start the queue-draining task unless one is already running.

        Must be awaited from within a running event loop.
        """
        if self._processing_task is None or self._processing_task.done():
            self._processing_task = asyncio.create_task(self._process_queue())
            print("WebSocket manager background tasks started")

    async def _process_queue(self):
        """Drain queued price updates forever, broadcasting up to 10 per pass.

        Sleeps 10 ms between passes so the loop yields to other tasks. On an
        unexpected error it logs and backs off for one second; the update that
        failed is not retried. ``asyncio.CancelledError`` is not an
        ``Exception`` subclass (Python 3.8+), so cancelling the task stops it.
        """
        while True:
            try:
                # Bounded batch so a burst of updates cannot monopolise the loop.
                for _ in range(10):
                    if not self._price_update_queue:
                        break
                    update = self._price_update_queue.popleft()
                    await self.broadcast_price_update(
                        exchange=update["exchange"],
                        symbol=update["symbol"],
                        price=update["price"],
                    )
                await asyncio.sleep(0.01)
            except Exception as e:
                print(f"Error processing price update queue: {e}")
                await asyncio.sleep(1)

    def _initialize_pricing_service(self):
        """Lazily fetch the pricing service on first use."""
        if self._pricing_service is None:
            self._pricing_service = get_pricing_service()

    def subscribe_to_symbol(self, symbol: str):
        """Ensure price updates for ``symbol`` flow into the broadcast queue.

        The first call for a symbol registers a ticker callback with the
        pricing service. Later calls only bump a reference count, so the same
        number of ``unsubscribe_from_symbol`` calls is needed before the ticker
        is released.
        """
        self._initialize_pricing_service()
        if symbol in self.subscribed_symbols:
            # A symbol present without a count (e.g. added externally) already
            # represents one subscription.
            self._symbol_refcounts[symbol] = self._symbol_refcounts.get(symbol, 1) + 1
            return

        def price_callback(data):
            """Queue a pricing-service update for the background broadcaster."""
            # Falls back to the subscribed symbol / a zero price when the
            # payload omits them; str() first so floats convert exactly as printed.
            update = {
                "exchange": "pricing",
                "symbol": data.get('symbol', symbol),
                "price": Decimal(str(data.get('price', 0))),
            }
            self._price_update_queue.append(update)

        self._pricing_service.subscribe_ticker(symbol, price_callback)
        # Record the subscription only after the pricing service accepted it, so a
        # failed subscribe can be retried instead of being skipped forever.
        self.subscribed_symbols.add(symbol)
        self._symbol_refcounts[symbol] = 1

    async def _process_price_update(self, update: Dict):
        """Broadcast a single queued price update.

        DEPRECATED: Use _process_queue instead.
        """
        await self.broadcast_price_update(
            exchange=update["exchange"],
            symbol=update["symbol"],
            price=update["price"],
        )

    def unsubscribe_from_symbol(self, symbol: str):
        """Release one subscription to ``symbol``; drop the ticker on the last one.

        Symbols that are not subscribed are ignored.
        """
        if symbol not in self.subscribed_symbols:
            return
        remaining = self._symbol_refcounts.get(symbol, 1) - 1
        if remaining > 0:
            self._symbol_refcounts[symbol] = remaining
            return
        self._symbol_refcounts.pop(symbol, None)
        self.subscribed_symbols.remove(symbol)
        if self._pricing_service:
            self._pricing_service.unsubscribe_ticker(symbol)

    async def connect(self, websocket: WebSocket):
        """Register an accepted WebSocket and make sure the broadcaster runs.

        Note: websocket.accept() must be called before this method.
        """
        self.active_connections.append(websocket)
        await self.start_background_tasks()

    def disconnect(self, websocket: WebSocket):
        """Forget ``websocket``; a no-op if it is not registered."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast_price_update(self, exchange: str, symbol: str, price: Decimal):
        """Broadcast a ``PriceUpdate`` stamped with the current UTC time."""
        message = PriceUpdate(
            exchange=exchange,
            symbol=symbol,
            price=price,
            timestamp=datetime.utcnow(),
        )
        await self.broadcast(message.dict())

    async def broadcast_order_update(self, order_id: int, status: str, filled_quantity: Optional[Decimal] = None):
        """Broadcast an ``OrderUpdate`` stamped with the current UTC time."""
        message = OrderUpdate(
            order_id=order_id,
            status=status,
            filled_quantity=filled_quantity,
            timestamp=datetime.utcnow(),
        )
        await self.broadcast(message.dict())

    async def broadcast_training_progress(self, step: str, progress: int, total: int, message: str, details: Optional[dict] = None):
        """Broadcast a ``training_progress`` message.

        ``percent`` is ``progress / total`` as a truncated integer percentage,
        or 0 when ``total`` is not positive (avoids division by zero).
        """
        update = {
            "type": "training_progress",
            "step": step,
            "progress": progress,
            "total": total,
            "percent": int((progress / total) * 100) if total > 0 else 0,
            "message": message,
            "details": details or {},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self.broadcast(update)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every active connection.

        The message is first passed through ``jsonable_encoder`` so values such
        as ``Decimal`` and ``datetime`` become JSON-compatible. Examples are the
        prices and timestamps handed to ``PriceUpdate``/``OrderUpdate``, assuming
        the schemas keep those types. Without this, ``send_json`` would raise for
        every client and all of them would be dropped. Encoding happens once,
        before any send, so an unserialisable message raises here instead of
        disconnecting everyone. Connections whose send fails are removed.
        """
        payload = jsonable_encoder(message)
        stale = []
        # Iterate over a snapshot: other coroutines may connect/disconnect while
        # we await each send, which would otherwise mutate the list mid-iteration
        # and cause clients to be skipped.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(payload)
            except Exception:
                stale.append(connection)
        for connection in stale:
            self.disconnect(connection)
# Module-level ConnectionManager shared by every connection handled by
# websocket_endpoint (and by any other code that imports `manager`).
manager = ConnectionManager()
@router.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client's real-time update stream and its subscription requests.

    The client sends JSON text frames. ``{"type": "subscribe", "symbol": S}``
    and ``{"type": "unsubscribe", "symbol": S}`` are answered with a
    confirmation. Any other JSON object is acknowledged with
    ``{"type": "ack", "message": "received"}``. Malformed input gets a
    ``{"type": "error", ...}`` reply. Every subscription this connection made is
    released when it ends, however it ends: client close, server error, or
    task cancellation.
    """
    origin = websocket.headers.get("origin")
    allowed_origins = (
        "http://localhost:3000",
        "http://localhost:5173",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:5173",
    )
    # Connections without an Origin header (direct clients, tests) are allowed.
    # Unknown origins are deliberately only logged, to avoid disconnection issues
    # seen in some environments.
    # NOTE(review): accepting any Origin enables cross-site WebSocket hijacking if
    # clients authenticate with cookies; restore the rejection below once the
    # environment issue is understood.
    if origin and origin not in allowed_origins:
        print(f"Warning: WebSocket connection from unknown origin: {origin}")
        # await websocket.close(code=1008, reason="Origin not allowed")
        # return

    await websocket.accept()

    # Symbols this connection holds a manager subscription for; released in finally.
    subscribed_symbols: Set[str] = set()
    try:
        # Registers the socket and starts the manager's background task if needed.
        await manager.connect(websocket)
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
                if not isinstance(message, dict):
                    await websocket.send_json({"type": "error", "message": "Message must be a JSON object"})
                    continue

                msg_type = message.get("type")
                if msg_type not in ("subscribe", "unsubscribe"):
                    await websocket.send_json({"type": "ack", "message": "received"})
                    continue

                symbol = message.get("symbol")
                if not isinstance(symbol, str) or not symbol:
                    await websocket.send_json({"type": "error", "message": "Missing or invalid 'symbol'"})
                    continue

                if msg_type == "subscribe":
                    # Subscribe with the manager once per connection so its
                    # per-symbol reference count stays balanced; record the
                    # symbol only after the manager accepted it.
                    if symbol not in subscribed_symbols:
                        manager.subscribe_to_symbol(symbol)
                        subscribed_symbols.add(symbol)
                    await websocket.send_json({
                        "type": "subscription_confirmed",
                        "symbol": symbol,
                    })
                elif symbol in subscribed_symbols:
                    subscribed_symbols.remove(symbol)
                    manager.unsubscribe_from_symbol(symbol)
                    await websocket.send_json({
                        "type": "unsubscription_confirmed",
                        "symbol": symbol,
                    })
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
            except WebSocketDisconnect:
                # A send hit a closed socket: let the outer handler end the session
                # instead of trying to report the error over the same socket.
                raise
            except Exception as e:
                # NOTE(review): exposes internal error text to the client; fine for
                # debugging, consider a generic message in production.
                await websocket.send_json({"type": "error", "message": str(e)})
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in finally.
        pass
    except Exception as e:
        manager.disconnect(websocket)
        print(f"WebSocket error: {e}")
        # The socket may already be closed.
        try:
            await websocket.close(code=1011)
        except Exception:
            pass
    finally:
        # Runs on every exit path, including cancellation, so subscriptions and
        # the manager's connection list never leak.
        for symbol in subscribed_symbols:
            manager.unsubscribe_from_symbol(symbol)
        subscribed_symbols.clear()
        manager.disconnect(websocket)