Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

1
backend/api/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""API routers package."""

144
backend/api/alerts.py Normal file
View File

@@ -0,0 +1,144 @@
"""Alerts API endpoints."""
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from datetime import datetime
from sqlalchemy import select
from src.core.database import Alert, get_database
router = APIRouter()
def get_alert_manager():
"""Get alert manager instance."""
from src.alerts.manager import get_alert_manager as _get_alert_manager
return _get_alert_manager()
class AlertCreate(BaseModel):
"""Create alert request."""
name: str
alert_type: str # price, indicator, risk, system
condition: dict
class AlertUpdate(BaseModel):
"""Update alert request."""
name: Optional[str] = None
condition: Optional[dict] = None
enabled: Optional[bool] = None
class AlertResponse(BaseModel):
"""Alert response."""
id: int
name: str
alert_type: str
condition: dict
enabled: bool
triggered: bool
triggered_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
@router.get("/", response_model=List[AlertResponse])
async def list_alerts(
enabled_only: bool = False,
manager=Depends(get_alert_manager)
):
"""List all alerts."""
try:
alerts = await manager.list_alerts(enabled_only=enabled_only)
return [AlertResponse.model_validate(alert) for alert in alerts]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/", response_model=AlertResponse)
async def create_alert(
alert_data: AlertCreate,
manager=Depends(get_alert_manager)
):
"""Create a new alert."""
try:
alert = await manager.create_alert(
name=alert_data.name,
alert_type=alert_data.alert_type,
condition=alert_data.condition
)
return AlertResponse.model_validate(alert)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{alert_id}", response_model=AlertResponse)
async def get_alert(alert_id: int):
"""Get alert by ID."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Alert).where(Alert.id == alert_id)
result = await session.execute(stmt)
alert = result.scalar_one_or_none()
if not alert:
raise HTTPException(status_code=404, detail="Alert not found")
return AlertResponse.model_validate(alert)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{alert_id}", response_model=AlertResponse)
async def update_alert(alert_id: int, alert_data: AlertUpdate):
"""Update an alert."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Alert).where(Alert.id == alert_id)
result = await session.execute(stmt)
alert = result.scalar_one_or_none()
if not alert:
raise HTTPException(status_code=404, detail="Alert not found")
if alert_data.name is not None:
alert.name = alert_data.name
if alert_data.condition is not None:
alert.condition = alert_data.condition
if alert_data.enabled is not None:
alert.enabled = alert_data.enabled
await session.commit()
await session.refresh(alert)
return AlertResponse.model_validate(alert)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{alert_id}")
async def delete_alert(alert_id: int):
"""Delete an alert."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Alert).where(Alert.id == alert_id)
result = await session.execute(stmt)
alert = result.scalar_one_or_none()
if not alert:
raise HTTPException(status_code=404, detail="Alert not found")
await session.delete(alert)
await session.commit()
return {"status": "deleted", "alert_id": alert_id}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

564
backend/api/autopilot.py Normal file
View File

@@ -0,0 +1,564 @@
"""AutoPilot API endpoints."""
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
from ..core.dependencies import get_database
from ..core.schemas import OrderSide
from src.core.database import get_database as get_db
# Import autopilot - path should be set up in main.py
from src.autopilot import (
stop_all_autopilots,
get_intelligent_autopilot,
get_strategy_selector,
get_performance_tracker,
get_performance_tracker,
get_autopilot_mode_info,
)
from src.worker.tasks import train_model_task
from src.core.config import get_config
from celery.result import AsyncResult
router = APIRouter()
class BootstrapConfig(BaseModel):
"""Bootstrap training data configuration."""
days: int = 90
timeframe: str = "1h"
min_samples_per_strategy: int = 10
symbols: List[str] = ["BTC/USD", "ETH/USD"]
class MultiSymbolAutopilotConfig(BaseModel):
"""Multi-symbol autopilot configuration."""
symbols: List[str]
mode: str = "intelligent"
auto_execute: bool = False
timeframe: str = "1h"
exchange_id: int = 1
paper_trading: bool = True
interval: float = 60.0
# =============================================================================
# Intelligent Autopilot Endpoints
# =============================================================================
class IntelligentAutopilotConfig(BaseModel):
symbol: str
exchange_id: int = 1
timeframe: str = "1h"
interval: float = 60.0
paper_trading: bool = True
# =============================================================================
# Unified Autopilot Endpoints
# =============================================================================
class UnifiedAutopilotConfig(BaseModel):
"""Unified autopilot configuration (Inteligent Mode)."""
symbol: str
mode: str = "intelligent" # Kept for compatibility but only "intelligent" is supported
auto_execute: bool = False
interval: float = 60.0
exchange_id: int = 1
timeframe: str = "1h"
paper_trading: bool = True
@router.post("/intelligent/start", deprecated=True)
async def start_intelligent_autopilot(
config: IntelligentAutopilotConfig,
background_tasks: BackgroundTasks
):
"""Start the Intelligent Autopilot engine.
.. deprecated:: Use /start-unified instead with mode='intelligent'
"""
try:
autopilot = get_intelligent_autopilot(
symbol=config.symbol,
exchange_id=config.exchange_id,
timeframe=config.timeframe,
interval=config.interval,
paper_trading=config.paper_trading
)
if not autopilot.is_running:
background_tasks.add_task(autopilot.start)
return {
"status": "started",
"symbol": config.symbol,
"timeframe": config.timeframe
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/intelligent/stop")
async def stop_intelligent_autopilot(symbol: str, timeframe: str = "1h"):
"""Stop the Intelligent Autopilot engine."""
try:
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
autopilot.stop()
return {"status": "stopped", "symbol": symbol}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/intelligent/status/{symbol:path}")
async def get_intelligent_status(symbol: str, timeframe: str = "1h"):
"""Get Intelligent Autopilot status."""
try:
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
return autopilot.get_status()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/intelligent/performance")
async def get_intelligent_performance(
strategy_name: Optional[str] = None,
days: int = 30
):
"""Get strategy performance metrics."""
try:
tracker = get_performance_tracker()
if strategy_name:
metrics = tracker.calculate_metrics(strategy_name, period_days=days)
return {"strategy": strategy_name, "metrics": metrics}
else:
# Get all strategies
history = tracker.get_performance_history(days=days)
if history.empty:
return {"strategies": []}
strategies = history['strategy_name'].unique()
results = {}
for strat in strategies:
results[strat] = tracker.calculate_metrics(strat, period_days=days)
return {"strategies": results}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/intelligent/training-stats")
async def get_training_stats(days: int = 365):
"""Get statistics about available training data.
Returns:
Dictionary with total samples and per-strategy counts
"""
try:
tracker = get_performance_tracker()
counts = await tracker.get_strategy_sample_counts(days=days)
return {
"total_samples": sum(counts.values()),
"strategy_counts": counts
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/intelligent/retrain")
async def retrain_model(force: bool = False, bootstrap: bool = True):
"""Manually trigger model retraining (Background Task).
Offloads training to Celery worker.
"""
try:
# Get all bootstrap config to pass to worker
config = get_config()
symbols = config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"])
days = config.get("autopilot.intelligent.bootstrap.days", 90)
timeframe = config.get("autopilot.intelligent.bootstrap.timeframe", "1h")
min_samples = config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10)
# Submit to Celery with all configured parameters
task = train_model_task.delay(
force_retrain=force,
bootstrap=bootstrap,
symbols=symbols,
days=days,
timeframe=timeframe,
min_samples_per_strategy=min_samples
)
return {
"status": "queued",
"message": "Model retraining started in background",
"task_id": task.id
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/intelligent/model-info")
async def get_model_info():
"""Get ML model information."""
try:
selector = get_strategy_selector()
return selector.get_model_info()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/intelligent/reset")
async def reset_model():
"""Reset/delete all saved ML models and training data.
This clears all persisted model files AND training data from database,
allowing for a fresh start with new features.
"""
try:
from pathlib import Path
from src.core.database import get_database, MarketConditionsSnapshot, StrategyPerformance
from sqlalchemy import delete
# Get model directory
model_dir = Path.home() / ".local" / "share" / "crypto_trader" / "models"
deleted_count = 0
if model_dir.exists():
# Delete all strategy selector model files
for model_file in model_dir.glob("strategy_selector_*.joblib"):
model_file.unlink()
deleted_count += 1
# Clear training data from database
db = get_database()
db_cleared = 0
try:
async with db.get_session() as session:
# Delete all market conditions snapshots
result1 = await session.execute(delete(MarketConditionsSnapshot))
# Delete all strategy performance records
result2 = await session.execute(delete(StrategyPerformance))
await session.commit()
db_cleared = result1.rowcount + result2.rowcount
except Exception as e:
# Database clearing is optional - continue even if it fails
pass
# Reset the in-memory model state
selector = get_strategy_selector()
from src.autopilot.models import StrategySelectorModel
selector.model = StrategySelectorModel(model_type="classifier")
return {
"status": "success",
"message": f"Deleted {deleted_count} model file(s) and {db_cleared} training records. Model reset to untrained state.",
"deleted_count": deleted_count,
"db_records_cleared": db_cleared
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Multi-Symbol Autopilot Endpoints
# =============================================================================
@router.post("/multi-symbol/start")
async def start_multi_symbol_autopilot(
config: MultiSymbolAutopilotConfig,
background_tasks: BackgroundTasks
):
"""Start autopilot for multiple symbols simultaneously.
Args:
config: Multi-symbol autopilot configuration
background_tasks: FastAPI background tasks
"""
try:
results = []
for symbol in config.symbols:
# Always use intelligent mode
autopilot = get_intelligent_autopilot(
symbol=symbol,
exchange_id=config.exchange_id,
timeframe=config.timeframe,
interval=config.interval,
paper_trading=config.paper_trading
)
autopilot.enable_auto_execution = config.auto_execute
if not autopilot.is_running:
# Set running flag synchronously before scheduling background task
autopilot._running = True
background_tasks.add_task(autopilot.start)
results.append({"symbol": symbol, "status": "started"})
else:
results.append({"symbol": symbol, "status": "already_running"})
return {
"status": "success",
"mode": "intelligent",
"symbols": results
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/multi-symbol/stop")
async def stop_multi_symbol_autopilot(
symbols: List[str],
mode: str = "intelligent",
timeframe: str = "1h"
):
"""Stop autopilot for multiple symbols.
Args:
symbols: List of symbols to stop
mode: Autopilot mode (pattern or intelligent)
timeframe: Timeframe for intelligent mode
"""
try:
results = []
for symbol in symbols:
# Always use intelligent mode
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
autopilot.stop()
results.append({"symbol": symbol, "status": "stopped"})
return {
"status": "success",
"symbols": results
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/multi-symbol/status")
async def get_multi_symbol_status(
symbols: str = "", # Comma-separated list
mode: str = "intelligent",
timeframe: str = "1h"
):
"""Get status for multiple symbols.
Args:
symbols: Comma-separated list of symbols (empty = all running)
mode: Autopilot mode
timeframe: Timeframe for intelligent mode
"""
from src.autopilot.intelligent_autopilot import _intelligent_autopilots
try:
results = []
if symbols:
symbol_list = [s.strip() for s in symbols.split(",")]
else:
# Get all running autopilots (intelligent only)
symbol_list = [key.split(":")[0] for key in _intelligent_autopilots.keys()]
for symbol in symbol_list:
try:
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
status = autopilot.get_status()
status["symbol"] = symbol
status["mode"] = "intelligent"
results.append(status)
except Exception:
results.append({
"symbol": symbol,
"mode": "intelligent",
"running": False,
"error": "Not found"
})
return {
"mode": mode,
"symbols": results,
"total_running": sum(1 for r in results if r.get("running", False))
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Unified Autopilot Endpoints (New)
# =============================================================================
@router.get("/modes")
async def get_autopilot_modes():
"""Get information about available autopilot modes.
Returns mode descriptions, capabilities, tradeoffs, and comparison data.
"""
try:
return get_autopilot_mode_info()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/start-unified")
async def start_unified_autopilot(
config: UnifiedAutopilotConfig,
background_tasks: BackgroundTasks
):
"""Start autopilot with unified interface (Intelligent Mode only)."""
try:
# Validate mode (for backward compatibility of API clients sending mode)
if config.mode and config.mode != "intelligent":
# We allow it but will treat it as intelligent if possible, or raise error if critical
pass
# Start ML-based autopilot
autopilot = get_intelligent_autopilot(
symbol=config.symbol,
exchange_id=config.exchange_id,
timeframe=config.timeframe,
interval=config.interval,
paper_trading=config.paper_trading
)
# Set auto-execution if enabled
if config.auto_execute:
autopilot.enable_auto_execution = True
if not autopilot.is_running:
# Schedule background task (state management handled by autopilot.start via Redis)
background_tasks.add_task(autopilot.start)
return {
"status": "started",
"mode": "intelligent",
"symbol": config.symbol,
"timeframe": config.timeframe,
"auto_execute": config.auto_execute
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/stop-unified")
async def stop_unified_autopilot(symbol: str, mode: str, timeframe: str = "1h"):
"""Stop autopilot for a symbol."""
try:
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
autopilot.stop()
return {"status": "stopped", "symbol": symbol, "mode": "intelligent"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/status-unified/{symbol:path}")
async def get_unified_status(symbol: str, mode: str, timeframe: str = "1h"):
"""Get autopilot status for a symbol."""
try:
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
# Use distributed status check (Redis)
status = await autopilot.get_distributed_status()
status["mode"] = "intelligent"
return status
except Exception as e:
logger.error(f"Error getting unified status: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/bootstrap-config", response_model=BootstrapConfig)
async def get_bootstrap_config():
"""Get bootstrap training data configuration."""
from src.core.config import get_config
config = get_config()
return BootstrapConfig(
days=config.get("autopilot.intelligent.bootstrap.days", 90),
timeframe=config.get("autopilot.intelligent.bootstrap.timeframe", "1h"),
min_samples_per_strategy=config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10),
symbols=config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"]),
)
@router.put("/bootstrap-config")
async def update_bootstrap_config(settings: BootstrapConfig):
"""Update bootstrap training data configuration."""
from src.core.config import get_config
config = get_config()
try:
config.set("autopilot.intelligent.bootstrap.days", settings.days)
config.set("autopilot.intelligent.bootstrap.timeframe", settings.timeframe)
config.set("autopilot.intelligent.bootstrap.min_samples_per_strategy", settings.min_samples_per_strategy)
config.set("autopilot.intelligent.bootstrap.symbols", settings.symbols)
# Also update the strategy selector instance if it exists
selector = get_strategy_selector()
selector.bootstrap_days = settings.days
selector.bootstrap_timeframe = settings.timeframe
selector.min_samples_per_strategy = settings.min_samples_per_strategy
selector.bootstrap_symbols = settings.symbols
return {"status": "success", "message": "Bootstrap configuration updated"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
"""Get status of a background task."""
try:
task_result = AsyncResult(task_id)
try:
# Accessing status or result might raise an exception if deserialization fails
status = task_result.status
result_data = task_result.result if task_result.ready() else None
meta_data = task_result.info if status == 'PROGRESS' else None
# serialized exception handling
if isinstance(result_data, Exception):
result_data = {
"error": str(result_data),
"type": type(result_data).__name__,
"detail": str(result_data)
}
elif status == "FAILURE" and (not result_data or result_data == {}):
# If failure but empty result, try to get traceback or use a default message
tb = getattr(task_result, 'traceback', None)
if tb:
result_data = {"error": "Task failed", "detail": str(tb)}
else:
result_data = {"error": "Task failed with no error info", "detail": "Check worker logs for details"}
result = {
"task_id": task_id,
"status": status,
"result": result_data
}
if meta_data:
result["meta"] = meta_data
return result
except Exception as inner_e:
# If Celery fails to get status/result (e.g. serialization error), return FAILURE
# This prevents 500 errors in the API when the task itself failed badly
return {
"task_id": task_id,
"status": "FAILURE",
"result": {"error": str(inner_e), "detail": "Failed to retrieve task status"},
"meta": {"error": "Task retrieval failed"}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,78 @@
"""Backtesting API endpoints."""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Dict, Any
from sqlalchemy import select
import uuid
from ..core.dependencies import get_backtesting_engine, get_strategy_registry
from ..core.schemas import BacktestRequest, BacktestResponse
from src.core.database import Strategy, get_database
router = APIRouter()
# Store running backtests
_backtests: Dict[str, Dict[str, Any]] = {}
@router.post("/run", response_model=BacktestResponse)
async def run_backtest(
backtest_data: BacktestRequest,
background_tasks: BackgroundTasks,
backtest_engine=Depends(get_backtesting_engine)
):
"""Run a backtest."""
try:
db = get_database()
async with db.get_session() as session:
# Get strategy
stmt = select(Strategy).where(Strategy.id == backtest_data.strategy_id)
result = await session.execute(stmt)
strategy_db = result.scalar_one_or_none()
if not strategy_db:
raise HTTPException(status_code=404, detail="Strategy not found")
# Create strategy instance
registry = get_strategy_registry()
strategy_instance = registry.create_instance(
strategy_id=strategy_db.id,
name=strategy_db.class_name,
parameters=strategy_db.parameters,
timeframes=strategy_db.timeframes or [backtest_data.timeframe]
)
if not strategy_instance:
raise HTTPException(status_code=400, detail="Failed to create strategy instance")
# Run backtest
results = backtest_engine.run_backtest(
strategy=strategy_instance,
symbol=backtest_data.symbol,
exchange=backtest_data.exchange,
timeframe=backtest_data.timeframe,
start_date=backtest_data.start_date,
end_date=backtest_data.end_date,
initial_capital=backtest_data.initial_capital,
slippage=backtest_data.slippage,
fee_rate=backtest_data.fee_rate
)
if "error" in results:
raise HTTPException(status_code=400, detail=results["error"])
return BacktestResponse(
results=results,
status="completed"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/results/{backtest_id}")
async def get_backtest_results(backtest_id: str):
"""Get backtest results by ID."""
if backtest_id not in _backtests:
raise HTTPException(status_code=404, detail="Backtest not found")
return _backtests[backtest_id]

42
backend/api/exchanges.py Normal file
View File

@@ -0,0 +1,42 @@
"""Exchange API endpoints."""
from typing import List
from fastapi import APIRouter, HTTPException
from sqlalchemy import select
from ..core.schemas import ExchangeResponse
from src.core.database import Exchange, get_database
router = APIRouter()
@router.get("/", response_model=List[ExchangeResponse])
async def list_exchanges():
"""List all exchanges."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Exchange).order_by(Exchange.name)
result = await session.execute(stmt)
exchanges = result.scalars().all()
return [ExchangeResponse.model_validate(e) for e in exchanges]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{exchange_id}", response_model=ExchangeResponse)
async def get_exchange(exchange_id: int):
"""Get exchange by ID."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Exchange).where(Exchange.id == exchange_id)
result = await session.execute(stmt)
exchange = result.scalar_one_or_none()
if not exchange:
raise HTTPException(status_code=404, detail="Exchange not found")
return ExchangeResponse.model_validate(exchange)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

280
backend/api/market_data.py Normal file
View File

@@ -0,0 +1,280 @@
"""Market Data API endpoints."""
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, HTTPException, Query, Body
from pydantic import BaseModel
import pandas as pd
from src.core.database import MarketData, get_database
from src.data.pricing_service import get_pricing_service
from src.core.config import get_config
router = APIRouter()
@router.get("/ohlcv/{symbol:path}")
async def get_ohlcv(
symbol: str,
timeframe: str = "1h",
limit: int = 100,
exchange: str = "coinbase" # Default exchange
):
"""Get OHLCV data for a symbol."""
from sqlalchemy import select
try:
# Try database first
try:
db = get_database()
async with db.get_session() as session:
# Use select() for async compatibility
stmt = select(MarketData).filter_by(
symbol=symbol,
timeframe=timeframe,
exchange=exchange
).order_by(MarketData.timestamp.desc()).limit(limit)
result = await session.execute(stmt)
data = result.scalars().all()
if data:
return [
{
"time": int(d.timestamp.timestamp()),
"open": float(d.open),
"high": float(d.high),
"low": float(d.low),
"close": float(d.close),
"volume": float(d.volume)
}
for d in reversed(data)
]
except Exception as db_error:
import sys
print(f"Database query failed, falling back to live data: {db_error}", file=sys.stderr)
# If no data in DB or DB error, fetch live from pricing service
try:
pricing_service = get_pricing_service()
# pricing_service.get_ohlcv is currently sync in its implementation but we call it from our async handler
ohlcv_data = pricing_service.get_ohlcv(
symbol=symbol,
timeframe=timeframe,
limit=limit
)
if ohlcv_data:
# Convert to frontend format: [timestamp, open, high, low, close, volume] -> {time, open, high, low, close, volume}
return [
{
"time": int(candle[0] / 1000), # Convert ms to seconds
"open": float(candle[1]),
"high": float(candle[2]),
"low": float(candle[3]),
"close": float(candle[4]),
"volume": float(candle[5])
}
for candle in ohlcv_data
]
except Exception as fetch_error:
import sys
print(f"Failed to fetch live data: {fetch_error}", file=sys.stderr)
# If all else fails, return empty list
return []
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/ticker/{symbol:path}")
async def get_ticker(symbol: str):
"""Get current ticker data for a symbol.
Returns ticker data with provider information.
"""
try:
pricing_service = get_pricing_service()
ticker_data = pricing_service.get_ticker(symbol)
if not ticker_data:
raise HTTPException(status_code=404, detail=f"Ticker data not available for {symbol}")
active_provider = pricing_service.get_active_provider()
return {
"symbol": symbol,
"bid": float(ticker_data.get('bid', 0)),
"ask": float(ticker_data.get('ask', 0)),
"last": float(ticker_data.get('last', 0)),
"high": float(ticker_data.get('high', 0)),
"low": float(ticker_data.get('low', 0)),
"volume": float(ticker_data.get('volume', 0)),
"timestamp": ticker_data.get('timestamp'),
"provider": active_provider,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/providers/health")
async def get_provider_health(provider: Optional[str] = Query(None, description="Specific provider name")):
"""Get health status for pricing providers.
Args:
provider: Optional provider name to get health for specific provider
"""
try:
pricing_service = get_pricing_service()
health_data = pricing_service.get_provider_health(provider)
return {
"active_provider": pricing_service.get_active_provider(),
"health": health_data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/providers/status")
async def get_provider_status():
"""Get detailed status for all pricing providers."""
try:
pricing_service = get_pricing_service()
health_data = pricing_service.get_provider_health()
cache_stats = pricing_service.get_cache_stats()
return {
"active_provider": pricing_service.get_active_provider(),
"providers": health_data,
"cache": cache_stats,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/providers/config")
async def get_provider_config():
"""Get provider configuration."""
try:
config = get_config()
provider_config = config.get("data_providers", {})
return provider_config
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
class ProviderConfigUpdate(BaseModel):
"""Provider configuration update model."""
primary: Optional[List[Dict[str, Any]]] = None
fallback: Optional[Dict[str, Any]] = None
caching: Optional[Dict[str, Any]] = None
websocket: Optional[Dict[str, Any]] = None
@router.put("/providers/config")
async def update_provider_config(config_update: ProviderConfigUpdate = Body(...)):
"""Update provider configuration."""
try:
config = get_config()
current_config = config.get("data_providers", {})
# Update configuration
if config_update.primary is not None:
current_config["primary"] = config_update.primary
if config_update.fallback is not None:
current_config["fallback"] = {**current_config.get("fallback", {}), **config_update.fallback}
if config_update.caching is not None:
current_config["caching"] = {**current_config.get("caching", {}), **config_update.caching}
if config_update.websocket is not None:
current_config["websocket"] = {**current_config.get("websocket", {}), **config_update.websocket}
# Save configuration
config.set("data_providers", current_config)
config.save()
return {"message": "Configuration updated successfully", "config": current_config}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/spread")
async def get_spread_data(
primary_symbol: str = Query(..., description="Primary symbol (e.g., SOL/USD)"),
secondary_symbol: str = Query(..., description="Secondary symbol (e.g., AVAX/USD)"),
timeframe: str = Query("1h", description="Timeframe"),
lookback: int = Query(50, description="Number of candles to fetch"),
):
"""Get spread and Z-Score data for pairs trading visualization.
Returns spread ratio and Z-Score time series for the given symbol pair.
"""
try:
pricing_service = get_pricing_service()
# Fetch OHLCV for both symbols
ohlcv_a = pricing_service.get_ohlcv(
symbol=primary_symbol,
timeframe=timeframe,
limit=lookback
)
ohlcv_b = pricing_service.get_ohlcv(
symbol=secondary_symbol,
timeframe=timeframe,
limit=lookback
)
if not ohlcv_a or not ohlcv_b:
raise HTTPException(status_code=404, detail="Could not fetch data for one or both symbols")
# Convert to DataFrames
df_a = pd.DataFrame(ohlcv_a, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
df_b = pd.DataFrame(ohlcv_b, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
# Align by length
min_len = min(len(df_a), len(df_b))
df_a = df_a.tail(min_len).reset_index(drop=True)
df_b = df_b.tail(min_len).reset_index(drop=True)
# Calculate spread (ratio)
closes_a = df_a['close'].astype(float)
closes_b = df_b['close'].astype(float)
spread = closes_a / closes_b
# Calculate Z-Score with rolling window
lookback_window = min(20, min_len - 1)
rolling_mean = spread.rolling(window=lookback_window).mean()
rolling_std = spread.rolling(window=lookback_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Build response
result = []
for i in range(min_len):
result.append({
"timestamp": int(df_a['timestamp'].iloc[i]),
"spread": float(spread.iloc[i]) if not pd.isna(spread.iloc[i]) else None,
"zScore": float(z_score.iloc[i]) if not pd.isna(z_score.iloc[i]) else None,
"priceA": float(closes_a.iloc[i]),
"priceB": float(closes_b.iloc[i]),
})
# Filter out entries with null Z-Score (during warmup period)
result = [r for r in result if r["zScore"] is not None]
return {
"primarySymbol": primary_symbol,
"secondarySymbol": secondary_symbol,
"timeframe": timeframe,
"lookbackWindow": lookback_window,
"data": result,
"currentSpread": result[-1]["spread"] if result else None,
"currentZScore": result[-1]["zScore"] if result else None,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

84
backend/api/portfolio.py Normal file
View File

@@ -0,0 +1,84 @@
"""Portfolio API endpoints."""
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Query
from datetime import datetime, timedelta
from ..core.dependencies import get_portfolio_tracker
from ..core.schemas import PortfolioResponse, PortfolioHistoryResponse
router = APIRouter()
# Import portfolio analytics
def get_portfolio_analytics():
"""Get portfolio analytics instance."""
from src.portfolio.analytics import get_portfolio_analytics as _get_analytics
return _get_analytics()
@router.get("/current", response_model=PortfolioResponse)
async def get_current_portfolio(
paper_trading: bool = True,
tracker=Depends(get_portfolio_tracker)
):
"""Get current portfolio state."""
try:
portfolio = await tracker.get_current_portfolio(paper_trading=paper_trading)
return PortfolioResponse(**portfolio)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/history", response_model=PortfolioHistoryResponse)
async def get_portfolio_history(
days: int = Query(30, ge=1, le=365),
paper_trading: bool = True,
tracker=Depends(get_portfolio_tracker)
):
"""Get portfolio history."""
try:
history = await tracker.get_portfolio_history(days=days, paper_trading=paper_trading)
dates = [h['timestamp'] if isinstance(h['timestamp'], str) else h['timestamp'].isoformat()
for h in history]
values = [float(h['total_value']) for h in history]
pnl = [float(h.get('total_pnl', 0)) for h in history]
return PortfolioHistoryResponse(
dates=dates,
values=values,
pnl=pnl
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/positions/update-prices")
async def update_positions_prices(
prices: dict,
paper_trading: bool = True,
tracker=Depends(get_portfolio_tracker)
):
"""Update current prices for positions."""
try:
from decimal import Decimal
price_dict = {k: Decimal(str(v)) for k, v in prices.items()}
await tracker.update_positions_prices(price_dict, paper_trading=paper_trading)
return {"status": "updated"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/risk-metrics")
async def get_risk_metrics(
days: int = Query(30, ge=1, le=365),
paper_trading: bool = True,
analytics=Depends(get_portfolio_analytics)
):
"""Get portfolio risk metrics."""
try:
metrics = await analytics.get_performance_metrics(days=days, paper_trading=paper_trading)
return metrics
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

272
backend/api/reporting.py Normal file
View File

@@ -0,0 +1,272 @@
"""Reporting API endpoints for CSV and PDF export."""
from fastapi import APIRouter, HTTPException, Query, Body
from fastapi.responses import StreamingResponse
from typing import Optional, Dict, Any
from datetime import datetime
from sqlalchemy import select
import io
import csv
import tempfile
from pathlib import Path
from src.core.database import Trade, get_database
router = APIRouter()
def get_csv_exporter():
"""Get CSV exporter instance."""
from src.reporting.csv_exporter import get_csv_exporter as _get_csv_exporter
return _get_csv_exporter()
def get_pdf_generator():
"""Get PDF generator instance."""
from src.reporting.pdf_generator import get_pdf_generator as _get_pdf_generator
return _get_pdf_generator()
@router.post("/backtest/csv")
async def export_backtest_csv(
results: Dict[str, Any] = Body(...),
):
"""Export backtest results as CSV."""
try:
output = io.StringIO()
writer = csv.writer(output)
# Write header
writer.writerow(['Metric', 'Value'])
# Write metrics
writer.writerow(['Total Return', f"{(results.get('total_return', 0) * 100):.2f}%"])
writer.writerow(['Sharpe Ratio', f"{results.get('sharpe_ratio', 0):.2f}"])
writer.writerow(['Sortino Ratio', f"{results.get('sortino_ratio', 0):.2f}"])
writer.writerow(['Max Drawdown', f"{(results.get('max_drawdown', 0) * 100):.2f}%"])
writer.writerow(['Win Rate', f"{(results.get('win_rate', 0) * 100):.2f}%"])
writer.writerow(['Total Trades', results.get('total_trades', 0)])
writer.writerow(['Final Value', f"${results.get('final_value', 0):.2f}"])
writer.writerow(['Initial Capital', f"${results.get('initial_capital', 0):.2f}"])
# Write trades if available
if results.get('trades'):
writer.writerow([])
writer.writerow(['Trades'])
writer.writerow(['Timestamp', 'Side', 'Price', 'Quantity', 'Value'])
for trade in results['trades']:
writer.writerow([
trade.get('timestamp', ''),
trade.get('side', ''),
f"${trade.get('price', 0):.2f}",
trade.get('quantity', 0),
f"${(trade.get('price', 0) * trade.get('quantity', 0)):.2f}",
])
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=backtest_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/backtest/pdf")
async def export_backtest_pdf(
results: Dict[str, Any] = Body(...),
):
"""Export backtest results as PDF."""
try:
pdf_generator = get_pdf_generator()
# Create temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
tmp_path = Path(tmp_file.name)
# Convert results to metrics format expected by PDF generator
metrics = {
'total_return_percent': (results.get('total_return', 0) * 100),
'sharpe_ratio': results.get('sharpe_ratio', 0),
'sortino_ratio': results.get('sortino_ratio', 0),
'max_drawdown': results.get('max_drawdown', 0),
'win_rate': results.get('win_rate', 0),
}
# Generate PDF
success = pdf_generator.generate_performance_report(
tmp_path,
metrics,
"Backtest Report"
)
if not success:
raise HTTPException(status_code=500, detail="Failed to generate PDF")
# Read PDF and return as stream
with open(tmp_path, 'rb') as f:
pdf_content = f.read()
# Clean up
tmp_path.unlink()
return StreamingResponse(
io.BytesIO(pdf_content),
media_type="application/pdf",
headers={"Content-Disposition": f"attachment; filename=backtest_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/trades/csv")
async def export_trades_csv(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
paper_trading: bool = True,
):
"""Export trades as CSV."""
try:
csv_exporter = get_csv_exporter()
# Parse dates if provided
start = datetime.fromisoformat(start_date) if start_date else None
end = datetime.fromisoformat(end_date) if end_date else None
# Create temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as tmp_file:
tmp_path = Path(tmp_file.name)
# Export to file
success = csv_exporter.export_trades(
filepath=tmp_path,
paper_trading=paper_trading,
start_date=start,
end_date=end
)
if not success:
raise HTTPException(status_code=500, detail="Failed to export trades")
# Read and return
with open(tmp_path, 'r') as f:
csv_content = f.read()
tmp_path.unlink()
return StreamingResponse(
iter([csv_content]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=trades_{datetime.now().strftime('%Y%m%d')}.csv"}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/portfolio/csv")
async def export_portfolio_csv(
paper_trading: bool = True,
):
"""Export portfolio as CSV."""
try:
csv_exporter = get_csv_exporter()
# Create temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as tmp_file:
tmp_path = Path(tmp_file.name)
# Export to file
success = csv_exporter.export_portfolio(filepath=tmp_path)
if not success:
raise HTTPException(status_code=500, detail="Failed to export portfolio")
# Read and return
with open(tmp_path, 'r') as f:
csv_content = f.read()
tmp_path.unlink()
return StreamingResponse(
iter([csv_content]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=portfolio_{datetime.now().strftime('%Y%m%d')}.csv"}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/tax/{method}")
async def generate_tax_report(
method: str, # fifo, lifo, specific_id
symbol: Optional[str] = Query(None),
year: Optional[int] = Query(None),
paper_trading: bool = Query(True),
):
"""Generate tax report using specified method."""
try:
if year is None:
year = datetime.now().year
tax_reporter = get_tax_reporter()
if method == "fifo":
if symbol:
events = tax_reporter.generate_fifo_report(symbol, year, paper_trading)
else:
# Generate for all symbols
events = []
# Get all symbols from trades
db = get_database()
async with db.get_session() as session:
stmt = select(Trade.symbol).distinct()
result = await session.execute(stmt)
symbols = result.scalars().all()
for sym in symbols:
events.extend(tax_reporter.generate_fifo_report(sym, year, paper_trading))
elif method == "lifo":
if symbol:
events = tax_reporter.generate_lifo_report(symbol, year, paper_trading)
else:
events = []
db = get_database()
async with db.get_session() as session:
stmt = select(Trade.symbol).distinct()
result = await session.execute(stmt)
symbols = result.scalars().all()
for sym in symbols:
events.extend(tax_reporter.generate_lifo_report(sym, year, paper_trading))
else:
raise HTTPException(status_code=400, detail=f"Unsupported tax method: {method}")
# Generate CSV
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(['Date', 'Symbol', 'Quantity', 'Cost Basis', 'Proceeds', 'Gain/Loss', 'Buy Date'])
for event in events:
writer.writerow([
event.get('date', ''),
event.get('symbol', ''),
event.get('quantity', 0),
event.get('cost_basis', 0),
event.get('proceeds', 0),
event.get('gain_loss', 0),
event.get('buy_date', ''),
])
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=tax_report_{method}_{year}_{datetime.now().strftime('%Y%m%d')}.csv"}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def get_tax_reporter():
"""Get tax reporter instance."""
from src.reporting.tax_reporter import get_tax_reporter as _get_tax_reporter
return _get_tax_reporter()

155
backend/api/reports.py Normal file
View File

@@ -0,0 +1,155 @@
"""Reports API endpoints for background report generation."""
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter()
class ReportRequest(BaseModel):
"""Request model for report generation."""
report_type: str # 'performance', 'trades', 'tax', 'backtest'
format: str = "pdf" # 'pdf' or 'csv'
year: Optional[int] = None # For tax reports
method: Optional[str] = "fifo" # For tax reports: 'fifo', 'lifo'
class ExportRequest(BaseModel):
"""Request model for data export."""
export_type: str # 'orders', 'positions'
@router.post("/generate")
async def generate_report(request: ReportRequest):
"""Generate a report in the background.
This endpoint queues a report generation task and returns immediately.
Use /api/tasks/{task_id} to monitor progress.
Supported report types:
- performance: Portfolio performance report
- trades: Trade history export
- tax: Tax report with capital gains calculation
Returns:
Task ID for monitoring
"""
try:
from src.worker.tasks import generate_report_task
params = {
"format": request.format,
}
if request.report_type == "tax":
from datetime import datetime
params["year"] = request.year or datetime.now().year
params["method"] = request.method
task = generate_report_task.delay(
report_type=request.report_type,
params=params
)
return {
"task_id": task.id,
"status": "queued",
"report_type": request.report_type,
"format": request.format
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/export")
async def export_data(request: ExportRequest):
"""Export data in the background.
This endpoint queues a data export task and returns immediately.
Use /api/tasks/{task_id} to monitor progress.
Supported export types:
- orders: Order history
- positions: Current/historical positions
Returns:
Task ID for monitoring
"""
try:
from src.worker.tasks import export_data_task
task = export_data_task.delay(
export_type=request.export_type,
params={}
)
return {
"task_id": task.id,
"status": "queued",
"export_type": request.export_type
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/list")
async def list_reports():
"""List available reports in the reports directory."""
try:
from pathlib import Path
import os
reports_dir = Path(os.path.expanduser("~/.local/share/crypto_trader/reports"))
if not reports_dir.exists():
return {"reports": []}
reports = []
for f in reports_dir.iterdir():
if f.is_file():
stat = f.stat()
reports.append({
"name": f.name,
"path": str(f),
"size": stat.st_size,
"created": stat.st_mtime,
"type": f.suffix.lstrip(".")
})
# Sort by creation time, newest first
reports.sort(key=lambda x: x["created"], reverse=True)
return {"reports": reports}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{filename}")
async def delete_report(filename: str):
"""Delete a generated report."""
try:
from pathlib import Path
import os
reports_dir = Path(os.path.expanduser("~/.local/share/crypto_trader/reports"))
filepath = reports_dir / filename
if not filepath.exists():
raise HTTPException(status_code=404, detail="Report not found")
# Security check: ensure the file is actually in reports dir
if not str(filepath.resolve()).startswith(str(reports_dir.resolve())):
raise HTTPException(status_code=403, detail="Access denied")
filepath.unlink()
return {"status": "deleted", "filename": filename}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

359
backend/api/settings.py Normal file
View File

@@ -0,0 +1,359 @@
from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from src.core.database import Exchange, get_database
from src.core.config import get_config
from src.security.key_manager import get_key_manager
from src.trading.paper_trading import get_paper_trading
router = APIRouter()
class RiskSettings(BaseModel):
max_drawdown_percent: float
daily_loss_limit_percent: float
position_size_percent: float
class PaperTradingSettings(BaseModel):
initial_capital: float
fee_exchange: str = "coinbase" # Which exchange's fee model to use
class LoggingSettings(BaseModel):
level: str
dir: str
retention_days: int
class GeneralSettings(BaseModel):
timezone: str = "UTC"
theme: str = "dark"
currency: str = "USD"
class ExchangeCreate(BaseModel):
name: str
api_key: Optional[str] = None
api_secret: Optional[str] = None
sandbox: bool = False
read_only: bool = True
enabled: bool = True
class ExchangeUpdate(BaseModel):
api_key: Optional[str] = None
api_secret: Optional[str] = None
sandbox: Optional[bool] = None
read_only: Optional[bool] = None
enabled: Optional[bool] = None
@router.get("/risk")
async def get_risk_settings():
"""Get risk management settings."""
try:
config = get_config()
return {
"max_drawdown_percent": config.get("risk.max_drawdown_percent", 20.0),
"daily_loss_limit_percent": config.get("risk.daily_loss_limit_percent", 5.0),
"position_size_percent": config.get("risk.position_size_percent", 2.0),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/risk")
async def update_risk_settings(settings: RiskSettings):
"""Update risk management settings."""
try:
config = get_config()
config.set("risk.max_drawdown_percent", settings.max_drawdown_percent)
config.set("risk.daily_loss_limit_percent", settings.daily_loss_limit_percent)
config.set("risk.position_size_percent", settings.position_size_percent)
config.save()
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/paper-trading")
async def get_paper_trading_settings():
"""Get paper trading settings."""
try:
config = get_config()
fee_exchange = config.get("paper_trading.fee_exchange", "coinbase")
# Get fee rates for current exchange
fee_rates = config.get(f"trading.exchanges.{fee_exchange}.fees",
config.get("trading.default_fees", {"maker": 0.001, "taker": 0.001}))
return {
"initial_capital": config.get("paper_trading.default_capital", 100.0),
"fee_exchange": fee_exchange,
"fee_rates": fee_rates,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/paper-trading")
async def update_paper_trading_settings(settings: PaperTradingSettings):
"""Update paper trading settings."""
try:
config = get_config()
config.set("paper_trading.default_capital", settings.initial_capital)
config.set("paper_trading.fee_exchange", settings.fee_exchange)
config.save()
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/paper-trading/fee-exchanges")
async def get_available_fee_exchanges():
"""Get available exchange fee models for paper trading."""
try:
config = get_config()
exchanges_config = config.get("trading.exchanges", {})
default_fees = config.get("trading.default_fees", {"maker": 0.001, "taker": 0.001})
current = config.get("paper_trading.fee_exchange", "coinbase")
exchanges = [{"name": "default", "fees": default_fees}]
for name, data in exchanges_config.items():
if "fees" in data:
exchanges.append({
"name": name,
"fees": data["fees"]
})
return {
"exchanges": exchanges,
"current": current,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/paper-trading/reset")
async def reset_paper_account():
"""Reset paper trading account."""
try:
paper_trading = get_paper_trading()
# Reset to capital from config
success = await paper_trading.reset()
if success:
return {"status": "success", "message": "Paper account reset successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to reset paper account")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/logging")
async def get_logging_settings():
"""Get logging settings."""
try:
config = get_config()
return {
"level": config.get("logging.level", "INFO"),
"dir": config.get("logging.dir", ""),
"retention_days": config.get("logging.retention_days", 30),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/logging")
async def update_logging_settings(settings: LoggingSettings):
"""Update logging settings."""
try:
config = get_config()
config.set("logging.level", settings.level)
config.set("logging.dir", settings.dir)
config.set("logging.retention_days", settings.retention_days)
config.save()
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/general")
async def get_general_settings():
"""Get general settings."""
try:
config = get_config()
return {
"timezone": config.get("general.timezone", "UTC"),
"theme": config.get("general.theme", "dark"),
"currency": config.get("general.currency", "USD"),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/general")
async def update_general_settings(settings: GeneralSettings):
"""Update general settings."""
try:
config = get_config()
config.set("general.timezone", settings.timezone)
config.set("general.theme", settings.theme)
config.set("general.currency", settings.currency)
config.save()
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/exchanges")
async def create_exchange(exchange: ExchangeCreate):
"""Create a new exchange."""
try:
db = get_database()
async with db.get_session() as session:
from src.exchanges.factory import ExchangeFactory
from src.exchanges.public_data import PublicDataAdapter
# Check if this is a public data exchange
adapter_class = None
try:
if hasattr(ExchangeFactory, '_adapters'):
adapter_class = ExchangeFactory._adapters.get(exchange.name.lower())
except:
pass
is_public_data = adapter_class == PublicDataAdapter if adapter_class else False
# Only require API keys for non-public-data exchanges
if not is_public_data:
if not exchange.api_key or not exchange.api_secret:
raise HTTPException(status_code=400, detail="API key and secret are required")
new_exchange = Exchange(
name=exchange.name,
enabled=exchange.enabled
)
session.add(new_exchange)
await session.flush()
# Save credentials
key_manager = get_key_manager()
key_manager.update_exchange(
new_exchange.id,
api_key=exchange.api_key or "",
api_secret=exchange.api_secret or "",
read_only=exchange.read_only if not is_public_data else True,
sandbox=exchange.sandbox if not is_public_data else False
)
await session.commit()
return {"id": new_exchange.id, "name": new_exchange.name, "status": "created"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/exchanges/{exchange_id}")
async def update_exchange(exchange_id: int, exchange: ExchangeUpdate):
"""Update an existing exchange."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Exchange).where(Exchange.id == exchange_id)
result = await session.execute(stmt)
exchange_obj = result.scalar_one_or_none()
if not exchange_obj:
raise HTTPException(status_code=404, detail="Exchange not found")
key_manager = get_key_manager()
credentials = key_manager.get_exchange_credentials(exchange_id)
# Update credentials if provided
if exchange.api_key is not None or exchange.api_secret is not None:
key_manager.update_exchange(
exchange_id,
api_key=exchange.api_key or credentials.get('api_key', ''),
api_secret=exchange.api_secret or credentials.get('api_secret', ''),
read_only=exchange.read_only if exchange.read_only is not None else credentials.get('read_only', True),
sandbox=exchange.sandbox if exchange.sandbox is not None else credentials.get('sandbox', False)
)
if exchange.enabled is not None:
exchange_obj.enabled = exchange.enabled
await session.commit()
return {"status": "success"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/exchanges/{exchange_id}")
async def delete_exchange(exchange_id: int):
"""Delete an exchange."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Exchange).where(Exchange.id == exchange_id)
result = await session.execute(stmt)
exchange = result.scalar_one_or_none()
if not exchange:
raise HTTPException(status_code=404, detail="Exchange not found")
await session.delete(exchange)
await session.commit()
return {"status": "success"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/exchanges/{exchange_id}/test")
async def test_exchange_connection(exchange_id: int):
"""Test exchange connection."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Exchange).where(Exchange.id == exchange_id)
result = await session.execute(stmt)
exchange = result.scalar_one_or_none()
if not exchange:
raise HTTPException(status_code=404, detail="Exchange not found")
from src.exchanges.factory import ExchangeFactory
from src.exchanges.public_data import PublicDataAdapter
# Get adapter class safely
adapter_class = None
try:
if hasattr(ExchangeFactory, '_adapters'):
adapter_class = ExchangeFactory._adapters.get(exchange.name.lower())
except:
pass
is_public_data = adapter_class == PublicDataAdapter if adapter_class else False
key_manager = get_key_manager()
if not is_public_data and not key_manager.get_exchange_credentials(exchange_id):
raise HTTPException(status_code=400, detail="No credentials found for this exchange")
adapter = ExchangeFactory.create(exchange_id)
if adapter and adapter.connect():
try:
if is_public_data:
adapter.get_ticker("BTC/USDT")
else:
adapter.get_balance()
return {"status": "success", "message": "Connection successful"}
except Exception as e:
return {"status": "error", "message": f"Connected but failed to fetch data: {str(e)}"}
else:
return {"status": "error", "message": "Failed to connect"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

310
backend/api/strategies.py Normal file
View File

@@ -0,0 +1,310 @@
"""Strategy API endpoints."""
from typing import List
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy import select
from ..core.dependencies import get_strategy_registry
from ..core.schemas import (
StrategyCreate, StrategyUpdate, StrategyResponse
)
from src.core.database import Strategy, get_database
from src.strategies.scheduler import get_scheduler as _get_scheduler
def get_strategy_scheduler():
"""Get strategy scheduler instance."""
return _get_scheduler()
router = APIRouter()
@router.get("/", response_model=List[StrategyResponse])
async def list_strategies():
"""List all strategies."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).order_by(Strategy.created_at.desc())
result = await session.execute(stmt)
strategies = result.scalars().all()
return [StrategyResponse.model_validate(s) for s in strategies]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/available")
async def list_available_strategies(
registry=Depends(get_strategy_registry)
):
"""List available strategy types."""
try:
return {"strategies": registry.list_available()}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/", response_model=StrategyResponse)
async def create_strategy(strategy_data: StrategyCreate):
"""Create a new strategy."""
try:
db = get_database()
async with db.get_session() as session:
strategy = Strategy(
name=strategy_data.name,
description=strategy_data.description,
strategy_type=strategy_data.strategy_type,
class_name=strategy_data.class_name,
parameters=strategy_data.parameters,
timeframes=strategy_data.timeframes,
paper_trading=strategy_data.paper_trading,
schedule=strategy_data.schedule,
enabled=False
)
session.add(strategy)
await session.commit()
await session.refresh(strategy)
return StrategyResponse.model_validate(strategy)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{strategy_id}", response_model=StrategyResponse)
async def get_strategy(strategy_id: int):
"""Get strategy by ID."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
return StrategyResponse.model_validate(strategy)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{strategy_id}", response_model=StrategyResponse)
async def update_strategy(strategy_id: int, strategy_data: StrategyUpdate):
"""Update a strategy."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
if strategy_data.name is not None:
strategy.name = strategy_data.name
if strategy_data.description is not None:
strategy.description = strategy_data.description
if strategy_data.parameters is not None:
strategy.parameters = strategy_data.parameters
if strategy_data.timeframes is not None:
strategy.timeframes = strategy_data.timeframes
if strategy_data.enabled is not None:
strategy.enabled = strategy_data.enabled
if strategy_data.schedule is not None:
strategy.schedule = strategy_data.schedule
await session.commit()
await session.refresh(strategy)
return StrategyResponse.model_validate(strategy)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{strategy_id}")
async def delete_strategy(strategy_id: int):
"""Delete a strategy."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
await session.delete(strategy)
await session.commit()
return {"status": "deleted", "strategy_id": strategy_id}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{strategy_id}/start")
async def start_strategy(strategy_id: int):
"""Start a strategy manually (bypasses Autopilot)."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
# Start strategy via scheduler
scheduler = get_strategy_scheduler()
scheduler.start_strategy(strategy_id)
strategy.running = True # Only set running, not enabled
await session.commit()
return {"status": "started", "strategy_id": strategy_id}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{strategy_id}/stop")
async def stop_strategy(strategy_id: int):
"""Stop a manually running strategy."""
try:
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
# Stop strategy via scheduler
scheduler = get_strategy_scheduler()
scheduler.stop_strategy(strategy_id)
strategy.running = False # Only set running, not enabled
await session.commit()
return {"status": "stopped", "strategy_id": strategy_id}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{strategy_type}/optimize")
async def optimize_strategy(
strategy_type: str,
symbol: str = "BTC/USD",
method: str = "genetic",
population_size: int = 50,
generations: int = 100
):
"""Optimize strategy parameters using genetic algorithm.
This endpoint queues an optimization task and returns immediately.
Use /api/tasks/{task_id} to monitor progress.
Args:
strategy_type: Type of strategy to optimize (e.g., 'rsi', 'macd')
symbol: Trading symbol for backtesting
method: Optimization method ('genetic', 'grid')
population_size: Population size for genetic algorithm
generations: Number of generations
Returns:
Task ID for monitoring
"""
try:
from src.worker.tasks import optimize_strategy_task
# Get parameter ranges for the strategy type
registry = get_strategy_registry()
strategy_class = registry.get(strategy_type)
if not strategy_class:
raise HTTPException(status_code=404, detail=f"Strategy type '{strategy_type}' not found")
# Default parameter ranges based on strategy type
param_ranges = {
"rsi": {"period": (5, 50), "overbought": (60, 90), "oversold": (10, 40)},
"macd": {"fast_period": (5, 20), "slow_period": (15, 40), "signal_period": (5, 15)},
"moving_average": {"short_period": (5, 30), "long_period": (20, 100)},
"bollinger_mean_reversion": {"period": (10, 50), "std_dev": (1.5, 3.0)},
}.get(strategy_type.lower(), {"period": (10, 50)})
# Queue the optimization task
task = optimize_strategy_task.delay(
strategy_type=strategy_type,
symbol=symbol,
param_ranges=param_ranges,
method=method,
population_size=population_size,
generations=generations
)
return {
"task_id": task.id,
"status": "queued",
"strategy_type": strategy_type,
"symbol": symbol,
"method": method
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{strategy_id}/status")
async def get_strategy_status(strategy_id: int):
"""Get real-time status of a running strategy.
Returns execution info including last tick time, last signal, and stats.
"""
try:
scheduler = get_strategy_scheduler()
status = scheduler.get_strategy_status(strategy_id)
if not status:
# Check if strategy exists but isn't running
db = get_database()
async with db.get_session() as session:
stmt = select(Strategy).where(Strategy.id == strategy_id)
result = await session.execute(stmt)
strategy = result.scalar_one_or_none()
if not strategy:
raise HTTPException(status_code=404, detail="Strategy not found")
return {
"strategy_id": strategy_id,
"name": strategy.name,
"type": strategy.strategy_type,
"symbol": strategy.parameters.get('symbol') if strategy.parameters else None,
"running": False,
"enabled": strategy.enabled,
}
return status
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/running/all")
async def get_all_running_strategies():
"""Get status of all currently running strategies."""
try:
scheduler = get_strategy_scheduler()
active = scheduler.get_all_active_strategies()
return {
"total_running": len(active),
"strategies": active
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

206
backend/api/trading.py Normal file
View File

@@ -0,0 +1,206 @@
"""Trading API endpoints."""
from decimal import Decimal
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.dependencies import get_trading_engine, get_db_session
from ..core.schemas import OrderCreate, OrderResponse, PositionResponse
from src.core.database import Order, OrderSide, OrderType, OrderStatus
from src.core.repositories import OrderRepository, PositionRepository
from src.core.logger import get_logger
from src.trading.paper_trading import get_paper_trading
router = APIRouter()
logger = get_logger(__name__)
@router.post("/orders", response_model=OrderResponse)
async def create_order(
order_data: OrderCreate,
trading_engine=Depends(get_trading_engine)
):
"""Create and execute a trading order."""
try:
# Convert string enums to actual enums
side = OrderSide(order_data.side.value)
order_type = OrderType(order_data.order_type.value)
order = await trading_engine.execute_order(
exchange_id=order_data.exchange_id,
strategy_id=order_data.strategy_id,
symbol=order_data.symbol,
side=side,
order_type=order_type,
quantity=order_data.quantity,
price=order_data.price,
paper_trading=order_data.paper_trading
)
if not order:
raise HTTPException(status_code=400, detail="Order execution failed")
return OrderResponse.model_validate(order)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/orders", response_model=List[OrderResponse])
async def get_orders(
paper_trading: bool = True,
limit: int = 100,
db: AsyncSession = Depends(get_db_session)
):
"""Get order history."""
try:
repo = OrderRepository(db)
orders = await repo.get_all(limit=limit)
# Filter by paper_trading in memory or add to repo method (repo returns all by default currently sorted by date)
# Let's verify repo method. It has limit/offset but not filtering.
# We should filter here or improve repo.
# For this refactor, let's filter in python for simplicity or assume get_all needs an update.
# Ideally, update repo. But strictly following "get_all" contract.
filtered = [o for o in orders if o.paper_trading == paper_trading]
return [OrderResponse.model_validate(order) for order in filtered]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/orders/{order_id}", response_model=OrderResponse)
async def get_order(
order_id: int,
db: AsyncSession = Depends(get_db_session)
):
"""Get order by ID."""
try:
repo = OrderRepository(db)
order = await repo.get_by_id(order_id)
if not order:
raise HTTPException(status_code=404, detail="Order not found")
return OrderResponse.model_validate(order)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/orders/{order_id}/cancel")
async def cancel_order(
order_id: int,
trading_engine=Depends(get_trading_engine)
):
try:
# We can use trading_engine's cancel which handles DB and Exchange
success = await trading_engine.cancel_order(order_id)
if not success:
raise HTTPException(status_code=400, detail="Failed to cancel order or order not found")
return {"status": "cancelled", "order_id": order_id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/orders/cancel-all")
async def cancel_all_orders(
paper_trading: bool = True,
trading_engine=Depends(get_trading_engine),
db: AsyncSession = Depends(get_db_session)
):
"""Cancel all open orders."""
try:
repo = OrderRepository(db)
open_orders = await repo.get_open_orders(paper_trading=paper_trading)
if not open_orders:
return {"status": "no_orders", "cancelled_count": 0}
cancelled_count = 0
failed_count = 0
for order in open_orders:
try:
if await trading_engine.cancel_order(order.id):
cancelled_count += 1
else:
failed_count += 1
except Exception as e:
logger.error(f"Failed to cancel order {order.id}: {e}")
failed_count += 1
return {
"status": "completed",
"cancelled_count": cancelled_count,
"failed_count": failed_count,
"total_orders": len(open_orders)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/positions", response_model=List[PositionResponse])
async def get_positions(
paper_trading: bool = True,
db: AsyncSession = Depends(get_db_session)
):
"""Get current positions."""
try:
if paper_trading:
paper_trading_sim = get_paper_trading()
positions = paper_trading_sim.get_positions()
# positions is a List[Position], convert to PositionResponse list
position_list = []
for pos in positions:
# pos is a Position database object
current_price = pos.current_price if pos.current_price else pos.entry_price
unrealized_pnl = pos.unrealized_pnl if pos.unrealized_pnl else Decimal(0)
realized_pnl = pos.realized_pnl if pos.realized_pnl else Decimal(0)
position_list.append(
PositionResponse(
symbol=pos.symbol,
quantity=pos.quantity,
entry_price=pos.entry_price,
current_price=current_price,
unrealized_pnl=unrealized_pnl,
realized_pnl=realized_pnl
)
)
return position_list
else:
# Live trading positions from database
repo = PositionRepository(db)
positions = await repo.get_all(paper_trading=False)
return [
PositionResponse(
symbol=pos.symbol,
quantity=pos.quantity,
entry_price=pos.entry_price,
current_price=pos.current_price or pos.entry_price,
unrealized_pnl=pos.unrealized_pnl,
realized_pnl=pos.realized_pnl
)
for pos in positions
]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/balance")
async def get_balance(paper_trading: bool = True):
"""Get account balance."""
try:
paper_trading_sim = get_paper_trading()
if paper_trading:
balance = paper_trading_sim.get_balance()
performance = paper_trading_sim.get_performance()
return {
"balance": float(balance),
"performance": performance
}
else:
# Live trading balance from exchange
return {"balance": 0.0, "performance": {}}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

242
backend/api/websocket.py Normal file
View File

@@ -0,0 +1,242 @@
"""WebSocket endpoints for real-time updates."""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import List, Dict, Set, Callable, Optional
import json
import asyncio
from datetime import datetime
from decimal import Decimal
from collections import deque
from ..core.schemas import PriceUpdate, OrderUpdate
from src.data.pricing_service import get_pricing_service
router = APIRouter()
class ConnectionManager:
"""Manages WebSocket connections."""
def __init__(self):
self.active_connections: List[WebSocket] = []
self.subscribed_symbols: Set[str] = set()
self._pricing_service = None
self._price_callbacks: Dict[str, List[Callable]] = {}
# Queue for price updates (thread-safe for async processing)
self._price_update_queue: deque = deque()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._processing_task = None
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
"""Set the event loop for async operations."""
self._loop = loop
async def start_background_tasks(self):
"""Start background processing tasks."""
if self._processing_task is None or self._processing_task.done():
self._processing_task = asyncio.create_task(self._process_queue())
print("WebSocket manager background tasks started")
async def _process_queue(self):
"""Periodically process price updates from queue."""
while True:
try:
if self._price_update_queue:
# Process up to 10 updates at a time to prevent blocking
for _ in range(10):
if not self._price_update_queue:
break
update = self._price_update_queue.popleft()
await self.broadcast_price_update(
exchange=update["exchange"],
symbol=update["symbol"],
price=update["price"]
)
await asyncio.sleep(0.01) # Check queue frequently but yield
except Exception as e:
print(f"Error processing price update queue: {e}")
await asyncio.sleep(1)
def _initialize_pricing_service(self):
"""Initialize pricing service and subscribe to price updates."""
if self._pricing_service is None:
self._pricing_service = get_pricing_service()
def subscribe_to_symbol(self, symbol: str):
"""Subscribe to price updates for a symbol."""
self._initialize_pricing_service()
if symbol not in self.subscribed_symbols:
self.subscribed_symbols.add(symbol)
def price_callback(data):
"""Callback for price updates from pricing service."""
# Store update in queue for async processing
update = {
"exchange": "pricing",
"symbol": data.get('symbol', symbol),
"price": Decimal(str(data.get('price', 0)))
}
self._price_update_queue.append(update)
# Note: We rely on the background task to process this queue now
self._pricing_service.subscribe_ticker(symbol, price_callback)
async def _process_price_update(self, update: Dict):
"""Process a price update asynchronously.
DEPRECATED: Use _process_queue instead.
"""
await self.broadcast_price_update(
exchange=update["exchange"],
symbol=update["symbol"],
price=update["price"]
)
def unsubscribe_from_symbol(self, symbol: str):
"""Unsubscribe from price updates for a symbol."""
if symbol in self.subscribed_symbols:
self.subscribed_symbols.remove(symbol)
if self._pricing_service:
self._pricing_service.unsubscribe_ticker(symbol)
async def connect(self, websocket: WebSocket):
"""Add WebSocket connection to active connections.
Note: websocket.accept() must be called before this method.
"""
self.active_connections.append(websocket)
# Ensure background task is running
await self.start_background_tasks()
def disconnect(self, websocket: WebSocket):
"""Remove a WebSocket connection."""
if websocket in self.active_connections:
self.active_connections.remove(websocket)
async def broadcast_price_update(self, exchange: str, symbol: str, price: Decimal):
"""Broadcast price update to all connected clients."""
message = PriceUpdate(
exchange=exchange,
symbol=symbol,
price=price,
timestamp=datetime.utcnow()
)
await self.broadcast(message.dict())
async def broadcast_order_update(self, order_id: int, status: str, filled_quantity: Decimal = None):
"""Broadcast order update to all connected clients."""
message = OrderUpdate(
order_id=order_id,
status=status,
filled_quantity=filled_quantity,
timestamp=datetime.utcnow()
)
await self.broadcast(message.dict())
async def broadcast_training_progress(self, step: str, progress: int, total: int, message: str, details: dict = None):
"""Broadcast training progress update to all connected clients."""
update = {
"type": "training_progress",
"step": step,
"progress": progress,
"total": total,
"percent": int((progress / total) * 100) if total > 0 else 0,
"message": message,
"details": details or {},
"timestamp": datetime.utcnow().isoformat()
}
await self.broadcast(update)
async def broadcast(self, message: dict):
"""Broadcast message to all connected clients."""
disconnected = []
for connection in self.active_connections:
try:
await connection.send_json(message)
except Exception:
disconnected.append(connection)
# Remove disconnected clients
for conn in disconnected:
self.disconnect(conn)
manager = ConnectionManager()
@router.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for real-time updates."""
# Check origin for CORS before accepting
origin = websocket.headers.get("origin")
allowed_origins = ["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:3000", "http://127.0.0.1:5173"]
# Allow connections from allowed origins or if no origin header (direct connections, testing)
# Relaxed check: Log warning but allow if origin doesn't match, to prevent disconnection issues in some environments
if origin and origin not in allowed_origins:
print(f"Warning: WebSocket connection from unknown origin: {origin}")
# We allow it for now to fix disconnection issues, but normally we might block
# await websocket.close(code=1008, reason="Origin not allowed")
# return
# Accept the connection
await websocket.accept()
try:
# Connect to manager (starts background tasks if needed)
await manager.connect(websocket)
subscribed_symbols = set()
while True:
# Receive messages from client (for subscriptions, etc.)
data = await websocket.receive_text()
try:
message = json.loads(data)
# Handle subscription messages
if message.get("type") == "subscribe":
symbol = message.get("symbol")
if symbol:
subscribed_symbols.add(symbol)
manager.subscribe_to_symbol(symbol)
await websocket.send_json({
"type": "subscription_confirmed",
"symbol": symbol
})
elif message.get("type") == "unsubscribe":
symbol = message.get("symbol")
if symbol and symbol in subscribed_symbols:
subscribed_symbols.remove(symbol)
manager.unsubscribe_from_symbol(symbol)
await websocket.send_json({
"type": "unsubscription_confirmed",
"symbol": symbol
})
else:
# Default acknowledgment
await websocket.send_json({"type": "ack", "message": "received"})
except json.JSONDecodeError:
await websocket.send_json({"type": "error", "message": "Invalid JSON"})
except Exception as e:
# Don't send internal errors to client in production, but okay for debugging
await websocket.send_json({"type": "error", "message": str(e)})
except WebSocketDisconnect:
# Clean up subscriptions
for symbol in subscribed_symbols:
manager.unsubscribe_from_symbol(symbol)
manager.disconnect(websocket)
except Exception as e:
manager.disconnect(websocket)
print(f"WebSocket error: {e}")
# Only close if not already closed
try:
await websocket.close(code=1011)
except Exception:
pass