243 lines
9.5 KiB
Python
243 lines
9.5 KiB
Python
"""WebSocket endpoints for real-time updates."""
|
|
|
|
import asyncio
import json
from collections import deque
from datetime import datetime
from decimal import Decimal, InvalidOperation
from typing import List, Dict, Set, Callable, Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder

from ..core.schemas import PriceUpdate, OrderUpdate
from src.data.pricing_service import get_pricing_service
|
|
|
|
# Router holding the WebSocket endpoint defined below; presumably included by
# the application with its own prefix (not visible here — confirm).
router = APIRouter()
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections, price subscriptions and broadcasts.

    Price updates arrive through callbacks registered with the pricing
    service. Each callback appends the update to an in-memory deque, and a
    background asyncio task (``_process_queue``) drains that deque and
    broadcasts every update to all active connections.
    """

    def __init__(self):
        # Sockets that receive broadcasts; maintained by connect()/disconnect().
        self.active_connections: List[WebSocket] = []
        # Symbols with at least one outstanding subscription. The per-symbol
        # counts in _subscription_counts are what decide when to unsubscribe.
        self.subscribed_symbols: Set[str] = set()
        # Number of outstanding subscribe_to_symbol() calls per symbol, so one
        # client releasing a symbol does not cut off the feed for the others.
        self._subscription_counts: Dict[str, int] = {}
        # Created lazily by _initialize_pricing_service().
        self._pricing_service = None
        # Callback registered with the pricing service for each symbol.
        self._price_callbacks: Dict[str, List[Callable]] = {}
        # Pending price updates. deque appends/pops are thread-safe, which
        # matters if the pricing service invokes callbacks from another
        # thread (not visible here — TODO confirm).
        self._price_update_queue: deque = deque()
        # Stored by set_event_loop(); not otherwise read in this class.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Task running _process_queue(); created by start_background_tasks().
        self._processing_task: Optional[asyncio.Task] = None

    def set_event_loop(self, loop: asyncio.AbstractEventLoop):
        """Remember the event loop for async operations."""
        self._loop = loop

    async def start_background_tasks(self):
        """Start the queue-draining task unless it is already running."""
        if self._processing_task is None or self._processing_task.done():
            self._processing_task = asyncio.create_task(self._process_queue())
            print("WebSocket manager background tasks started")

    async def _process_queue(self):
        """Drain queued price updates and broadcast them, forever.

        Broadcasts at most 10 updates per pass, then sleeps briefly so other
        coroutines can run. Errors are printed, and the loop backs off for one
        second instead of terminating.
        """
        while True:
            try:
                # Bounded batch so a burst of updates cannot monopolise the loop.
                for _ in range(10):
                    if not self._price_update_queue:
                        break
                    update = self._price_update_queue.popleft()
                    await self.broadcast_price_update(
                        exchange=update["exchange"],
                        symbol=update["symbol"],
                        price=update["price"],
                    )
                # Yield to the event loop even when the queue is empty.
                await asyncio.sleep(0.01)
            except Exception as e:
                print(f"Error processing price update queue: {e}")
                await asyncio.sleep(1)

    def _initialize_pricing_service(self):
        """Fetch the pricing service on first use."""
        if self._pricing_service is None:
            self._pricing_service = get_pricing_service()

    def subscribe_to_symbol(self, symbol: str):
        """Register one subscriber's interest in price updates for ``symbol``.

        Subscriptions are reference-counted. The pricing-service ticker is
        subscribed only on the first call for a symbol, and every call must
        be balanced by exactly one ``unsubscribe_from_symbol`` call.
        """
        self._initialize_pricing_service()

        count = self._subscription_counts.get(symbol, 0)
        if count == 0:

            def price_callback(data):
                """Queue a price update received from the pricing service."""
                try:
                    price = Decimal(str(data.get('price', 0)))
                except InvalidOperation:
                    # e.g. a None or non-numeric price. Raising here would
                    # propagate into the pricing service, whose error handling
                    # is not visible from this module.
                    print(f"Ignoring price update with invalid price for {symbol}: {data!r}")
                    return
                # The background task (_process_queue) broadcasts queued updates.
                self._price_update_queue.append({
                    "exchange": "pricing",
                    "symbol": data.get('symbol', symbol),
                    "price": price,
                })

            # Register first so a failure leaves no bookkeeping behind.
            self._pricing_service.subscribe_ticker(symbol, price_callback)
            self._price_callbacks[symbol] = [price_callback]
            self.subscribed_symbols.add(symbol)

        self._subscription_counts[symbol] = count + 1

    async def _process_price_update(self, update: Dict):
        """Broadcast a single queued price update.

        DEPRECATED: Use _process_queue instead.
        """
        await self.broadcast_price_update(
            exchange=update["exchange"],
            symbol=update["symbol"],
            price=update["price"],
        )

    def unsubscribe_from_symbol(self, symbol: str):
        """Release one subscription to ``symbol``.

        The pricing-service ticker is unsubscribed only when the last
        subscriber releases it. Symbols with no outstanding subscription are
        ignored.
        """
        count = self._subscription_counts.get(symbol, 0)
        if count <= 0:
            return
        if count > 1:
            # Other subscribers still need this feed.
            self._subscription_counts[symbol] = count - 1
            return

        del self._subscription_counts[symbol]
        self.subscribed_symbols.discard(symbol)
        self._price_callbacks.pop(symbol, None)
        if self._pricing_service:
            self._pricing_service.unsubscribe_ticker(symbol)

    async def connect(self, websocket: WebSocket):
        """Add a WebSocket connection to the active connections.

        Note: websocket.accept() must be called before this method. Also
        ensures the queue-draining background task is running.
        """
        self.active_connections.append(websocket)
        await self.start_background_tasks()

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast_price_update(self, exchange: str, symbol: str, price: Decimal):
        """Broadcast a PriceUpdate (timestamped now, UTC) to all clients."""
        message = PriceUpdate(
            exchange=exchange,
            symbol=symbol,
            price=price,
            timestamp=datetime.utcnow(),
        )
        await self.broadcast(message.dict())

    async def broadcast_order_update(self, order_id: int, status: str, filled_quantity: Decimal = None):
        """Broadcast an OrderUpdate (timestamped now, UTC) to all clients."""
        message = OrderUpdate(
            order_id=order_id,
            status=status,
            filled_quantity=filled_quantity,
            timestamp=datetime.utcnow(),
        )
        await self.broadcast(message.dict())

    async def broadcast_training_progress(self, step: str, progress: int, total: int, message: str, details: dict = None):
        """Broadcast a training-progress update to all connected clients.

        ``percent`` is ``progress / total`` as an integer percentage, or 0
        when ``total`` is not positive.
        """
        update = {
            "type": "training_progress",
            "step": step,
            "progress": progress,
            "total": total,
            "percent": int((progress / total) * 100) if total > 0 else 0,
            "message": message,
            "details": details or {},
            "timestamp": datetime.utcnow().isoformat(),
        }
        await self.broadcast(update)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every connected client.

        The payload is first converted with ``jsonable_encoder``, because
        Starlette's ``send_json`` serialises with plain ``json.dumps``. That
        rejects the ``datetime`` timestamps and ``Decimal`` values that the
        ``.dict()`` of the update models presumably contains. Without the
        conversion, every send would fail and every client would be dropped.
        Clients whose send fails are removed from ``active_connections``.
        """
        payload = jsonable_encoder(message)
        disconnected = []
        # Iterate over a snapshot: connect()/disconnect() may run while this
        # coroutine is suspended in send_json().
        for connection in list(self.active_connections):
            try:
                await connection.send_json(payload)
            except Exception:
                disconnected.append(connection)

        for conn in disconnected:
            self.disconnect(conn)
|
|
|
|
|
|
# Module-level singleton used by websocket_endpoint for every connection, so
# connections, symbol subscriptions and the price-update queue are shared by
# all clients served by this process.
manager = ConnectionManager()
|
|
|
|
|
|
@router.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates.

    Accepts the connection, registers it with the shared ``manager``, then
    handles client messages until the client goes away:

    * ``{"type": "subscribe", "symbol": S}``: subscribe this connection to
      price updates for S and reply ``subscription_confirmed``;
    * ``{"type": "unsubscribe", "symbol": S}``: release a subscription made
      on this connection and reply ``unsubscription_confirmed``;
    * any other JSON object: reply ``{"type": "ack", "message": "received"}``.

    Malformed input gets a ``{"type": "error", ...}`` reply. However the
    connection ends, it is removed from the manager and every symbol it
    subscribed to is released.
    """
    # Check the Origin header before accepting.
    origin = websocket.headers.get("origin")
    allowed_origins = {
        "http://localhost:3000",
        "http://localhost:5173",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:5173",
    }

    # Connections without an Origin header (direct clients, tests) are allowed.
    # Unknown origins are deliberately only logged, to avoid disconnection
    # issues in some environments.
    # NOTE(review): browsers do not apply CORS to WebSockets, so this check is
    # the only guard against cross-site WebSocket hijacking. To enforce it:
    #     await websocket.close(code=1008, reason="Origin not allowed")
    #     return
    if origin and origin not in allowed_origins:
        print(f"Warning: WebSocket connection from unknown origin: {origin}")

    await websocket.accept()

    # Symbols subscribed through this connection; released in `finally`.
    # Defined before `try` so the cleanup code can always reference it.
    subscribed_symbols: Set[str] = set()

    try:
        # Connect to manager (starts background tasks if needed).
        await manager.connect(websocket)

        while True:
            # Receive messages from the client (subscriptions, etc.).
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
                if not isinstance(message, dict):
                    await websocket.send_json({"type": "error", "message": "Expected a JSON object"})
                    continue

                msg_type = message.get("type")
                symbol = message.get("symbol")

                if msg_type == "subscribe":
                    if symbol:
                        # Subscribe with the manager once per connection, so
                        # every manager subscription is matched by exactly one
                        # release in `finally` or on unsubscribe.
                        if symbol not in subscribed_symbols:
                            manager.subscribe_to_symbol(symbol)
                            subscribed_symbols.add(symbol)
                        await websocket.send_json({
                            "type": "subscription_confirmed",
                            "symbol": symbol
                        })

                elif msg_type == "unsubscribe":
                    if symbol and symbol in subscribed_symbols:
                        subscribed_symbols.remove(symbol)
                        manager.unsubscribe_from_symbol(symbol)
                        await websocket.send_json({
                            "type": "unsubscription_confirmed",
                            "symbol": symbol
                        })

                else:
                    # Default acknowledgment
                    await websocket.send_json({"type": "ack", "message": "received"})

            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
            except WebSocketDisconnect:
                # Let the outer handler end the session; replying would fail.
                raise
            except Exception as e:
                # Don't send internal errors to client in production, but okay for debugging
                await websocket.send_json({"type": "error", "message": str(e)})

    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in `finally`.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
        # Only close if not already closed
        try:
            await websocket.close(code=1011)
        except Exception:
            pass
    finally:
        # Runs on every exit path, so neither the socket nor its subscriptions leak.
        manager.disconnect(websocket)
        for symbol in subscribed_symbols:
            try:
                manager.unsubscribe_from_symbol(symbol)
            except Exception as e:
                print(f"Error unsubscribing from {symbol}: {e}")
|
|
|