"""WebSocket endpoints for real-time updates.""" from fastapi import APIRouter, WebSocket, WebSocketDisconnect from typing import List, Dict, Set, Callable, Optional import json import asyncio from datetime import datetime from decimal import Decimal from collections import deque from ..core.schemas import PriceUpdate, OrderUpdate from src.data.pricing_service import get_pricing_service router = APIRouter() class ConnectionManager: """Manages WebSocket connections.""" def __init__(self): self.active_connections: List[WebSocket] = [] self.subscribed_symbols: Set[str] = set() self._pricing_service = None self._price_callbacks: Dict[str, List[Callable]] = {} # Queue for price updates (thread-safe for async processing) self._price_update_queue: deque = deque() self._loop: Optional[asyncio.AbstractEventLoop] = None self._processing_task = None def set_event_loop(self, loop: asyncio.AbstractEventLoop): """Set the event loop for async operations.""" self._loop = loop async def start_background_tasks(self): """Start background processing tasks.""" if self._processing_task is None or self._processing_task.done(): self._processing_task = asyncio.create_task(self._process_queue()) print("WebSocket manager background tasks started") async def _process_queue(self): """Periodically process price updates from queue.""" while True: try: if self._price_update_queue: # Process up to 10 updates at a time to prevent blocking for _ in range(10): if not self._price_update_queue: break update = self._price_update_queue.popleft() await self.broadcast_price_update( exchange=update["exchange"], symbol=update["symbol"], price=update["price"] ) await asyncio.sleep(0.01) # Check queue frequently but yield except Exception as e: print(f"Error processing price update queue: {e}") await asyncio.sleep(1) def _initialize_pricing_service(self): """Initialize pricing service and subscribe to price updates.""" if self._pricing_service is None: self._pricing_service = get_pricing_service() def subscribe_to_symbol(self, symbol: str): """Subscribe to price updates for a symbol.""" self._initialize_pricing_service() if symbol not in self.subscribed_symbols: self.subscribed_symbols.add(symbol) def price_callback(data): """Callback for price updates from pricing service.""" # Store update in queue for async processing update = { "exchange": "pricing", "symbol": data.get('symbol', symbol), "price": Decimal(str(data.get('price', 0))) } self._price_update_queue.append(update) # Note: We rely on the background task to process this queue now self._pricing_service.subscribe_ticker(symbol, price_callback) async def _process_price_update(self, update: Dict): """Process a price update asynchronously. DEPRECATED: Use _process_queue instead. """ await self.broadcast_price_update( exchange=update["exchange"], symbol=update["symbol"], price=update["price"] ) def unsubscribe_from_symbol(self, symbol: str): """Unsubscribe from price updates for a symbol.""" if symbol in self.subscribed_symbols: self.subscribed_symbols.remove(symbol) if self._pricing_service: self._pricing_service.unsubscribe_ticker(symbol) async def connect(self, websocket: WebSocket): """Add WebSocket connection to active connections. Note: websocket.accept() must be called before this method. """ self.active_connections.append(websocket) # Ensure background task is running await self.start_background_tasks() def disconnect(self, websocket: WebSocket): """Remove a WebSocket connection.""" if websocket in self.active_connections: self.active_connections.remove(websocket) async def broadcast_price_update(self, exchange: str, symbol: str, price: Decimal): """Broadcast price update to all connected clients.""" message = PriceUpdate( exchange=exchange, symbol=symbol, price=price, timestamp=datetime.utcnow() ) await self.broadcast(message.dict()) async def broadcast_order_update(self, order_id: int, status: str, filled_quantity: Decimal = None): """Broadcast order update to all connected clients.""" message = OrderUpdate( order_id=order_id, status=status, filled_quantity=filled_quantity, timestamp=datetime.utcnow() ) await self.broadcast(message.dict()) async def broadcast_training_progress(self, step: str, progress: int, total: int, message: str, details: dict = None): """Broadcast training progress update to all connected clients.""" update = { "type": "training_progress", "step": step, "progress": progress, "total": total, "percent": int((progress / total) * 100) if total > 0 else 0, "message": message, "details": details or {}, "timestamp": datetime.utcnow().isoformat() } await self.broadcast(update) async def broadcast(self, message: dict): """Broadcast message to all connected clients.""" disconnected = [] for connection in self.active_connections: try: await connection.send_json(message) except Exception: disconnected.append(connection) # Remove disconnected clients for conn in disconnected: self.disconnect(conn) manager = ConnectionManager() @router.websocket("/") async def websocket_endpoint(websocket: WebSocket): """WebSocket endpoint for real-time updates.""" # Check origin for CORS before accepting origin = websocket.headers.get("origin") allowed_origins = ["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:3000", "http://127.0.0.1:5173"] # Allow connections from allowed origins or if no origin header (direct connections, testing) # Relaxed check: Log warning but allow if origin doesn't match, to prevent disconnection issues in some environments if origin and origin not in allowed_origins: print(f"Warning: WebSocket connection from unknown origin: {origin}") # We allow it for now to fix disconnection issues, but normally we might block # await websocket.close(code=1008, reason="Origin not allowed") # return # Accept the connection await websocket.accept() try: # Connect to manager (starts background tasks if needed) await manager.connect(websocket) subscribed_symbols = set() while True: # Receive messages from client (for subscriptions, etc.) data = await websocket.receive_text() try: message = json.loads(data) # Handle subscription messages if message.get("type") == "subscribe": symbol = message.get("symbol") if symbol: subscribed_symbols.add(symbol) manager.subscribe_to_symbol(symbol) await websocket.send_json({ "type": "subscription_confirmed", "symbol": symbol }) elif message.get("type") == "unsubscribe": symbol = message.get("symbol") if symbol and symbol in subscribed_symbols: subscribed_symbols.remove(symbol) manager.unsubscribe_from_symbol(symbol) await websocket.send_json({ "type": "unsubscription_confirmed", "symbol": symbol }) else: # Default acknowledgment await websocket.send_json({"type": "ack", "message": "received"}) except json.JSONDecodeError: await websocket.send_json({"type": "error", "message": "Invalid JSON"}) except Exception as e: # Don't send internal errors to client in production, but okay for debugging await websocket.send_json({"type": "error", "message": str(e)}) except WebSocketDisconnect: # Clean up subscriptions for symbol in subscribed_symbols: manager.unsubscribe_from_symbol(symbol) manager.disconnect(websocket) except Exception as e: manager.disconnect(websocket) print(f"WebSocket error: {e}") # Only close if not already closed try: await websocket.close(code=1011) except Exception: pass