Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
564
backend/api/autopilot.py
Normal file
564
backend/api/autopilot.py
Normal file
@@ -0,0 +1,564 @@
|
||||
"""AutoPilot API endpoints."""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..core.dependencies import get_database
|
||||
from ..core.schemas import OrderSide
|
||||
from src.core.database import get_database as get_db
|
||||
# Import autopilot - path should be set up in main.py
|
||||
from src.autopilot import (
|
||||
stop_all_autopilots,
|
||||
get_intelligent_autopilot,
|
||||
get_strategy_selector,
|
||||
get_performance_tracker,
|
||||
get_performance_tracker,
|
||||
get_autopilot_mode_info,
|
||||
)
|
||||
from src.worker.tasks import train_model_task
|
||||
from src.core.config import get_config
|
||||
from celery.result import AsyncResult
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BootstrapConfig(BaseModel):
|
||||
"""Bootstrap training data configuration."""
|
||||
days: int = 90
|
||||
timeframe: str = "1h"
|
||||
min_samples_per_strategy: int = 10
|
||||
symbols: List[str] = ["BTC/USD", "ETH/USD"]
|
||||
|
||||
|
||||
class MultiSymbolAutopilotConfig(BaseModel):
|
||||
"""Multi-symbol autopilot configuration."""
|
||||
symbols: List[str]
|
||||
mode: str = "intelligent"
|
||||
auto_execute: bool = False
|
||||
timeframe: str = "1h"
|
||||
exchange_id: int = 1
|
||||
paper_trading: bool = True
|
||||
interval: float = 60.0
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Intelligent Autopilot Endpoints
|
||||
# =============================================================================
|
||||
|
||||
class IntelligentAutopilotConfig(BaseModel):
|
||||
symbol: str
|
||||
exchange_id: int = 1
|
||||
timeframe: str = "1h"
|
||||
interval: float = 60.0
|
||||
paper_trading: bool = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unified Autopilot Endpoints
|
||||
# =============================================================================
|
||||
|
||||
class UnifiedAutopilotConfig(BaseModel):
|
||||
"""Unified autopilot configuration (Inteligent Mode)."""
|
||||
symbol: str
|
||||
mode: str = "intelligent" # Kept for compatibility but only "intelligent" is supported
|
||||
auto_execute: bool = False
|
||||
interval: float = 60.0
|
||||
exchange_id: int = 1
|
||||
timeframe: str = "1h"
|
||||
paper_trading: bool = True
|
||||
|
||||
|
||||
@router.post("/intelligent/start", deprecated=True)
|
||||
async def start_intelligent_autopilot(
|
||||
config: IntelligentAutopilotConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""Start the Intelligent Autopilot engine.
|
||||
|
||||
.. deprecated:: Use /start-unified instead with mode='intelligent'
|
||||
"""
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(
|
||||
symbol=config.symbol,
|
||||
exchange_id=config.exchange_id,
|
||||
timeframe=config.timeframe,
|
||||
interval=config.interval,
|
||||
paper_trading=config.paper_trading
|
||||
)
|
||||
|
||||
if not autopilot.is_running:
|
||||
background_tasks.add_task(autopilot.start)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"symbol": config.symbol,
|
||||
"timeframe": config.timeframe
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/intelligent/stop")
|
||||
async def stop_intelligent_autopilot(symbol: str, timeframe: str = "1h"):
|
||||
"""Stop the Intelligent Autopilot engine."""
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
autopilot.stop()
|
||||
return {"status": "stopped", "symbol": symbol}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/intelligent/status/{symbol:path}")
|
||||
async def get_intelligent_status(symbol: str, timeframe: str = "1h"):
|
||||
"""Get Intelligent Autopilot status."""
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
return autopilot.get_status()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/intelligent/performance")
|
||||
async def get_intelligent_performance(
|
||||
strategy_name: Optional[str] = None,
|
||||
days: int = 30
|
||||
):
|
||||
"""Get strategy performance metrics."""
|
||||
try:
|
||||
tracker = get_performance_tracker()
|
||||
if strategy_name:
|
||||
metrics = tracker.calculate_metrics(strategy_name, period_days=days)
|
||||
return {"strategy": strategy_name, "metrics": metrics}
|
||||
else:
|
||||
# Get all strategies
|
||||
history = tracker.get_performance_history(days=days)
|
||||
if history.empty:
|
||||
return {"strategies": []}
|
||||
|
||||
strategies = history['strategy_name'].unique()
|
||||
results = {}
|
||||
for strat in strategies:
|
||||
results[strat] = tracker.calculate_metrics(strat, period_days=days)
|
||||
|
||||
return {"strategies": results}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/intelligent/training-stats")
|
||||
async def get_training_stats(days: int = 365):
|
||||
"""Get statistics about available training data.
|
||||
|
||||
Returns:
|
||||
Dictionary with total samples and per-strategy counts
|
||||
"""
|
||||
try:
|
||||
tracker = get_performance_tracker()
|
||||
counts = await tracker.get_strategy_sample_counts(days=days)
|
||||
|
||||
return {
|
||||
"total_samples": sum(counts.values()),
|
||||
"strategy_counts": counts
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/intelligent/retrain")
|
||||
async def retrain_model(force: bool = False, bootstrap: bool = True):
|
||||
"""Manually trigger model retraining (Background Task).
|
||||
|
||||
Offloads training to Celery worker.
|
||||
"""
|
||||
try:
|
||||
# Get all bootstrap config to pass to worker
|
||||
config = get_config()
|
||||
symbols = config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"])
|
||||
days = config.get("autopilot.intelligent.bootstrap.days", 90)
|
||||
timeframe = config.get("autopilot.intelligent.bootstrap.timeframe", "1h")
|
||||
min_samples = config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10)
|
||||
|
||||
# Submit to Celery with all configured parameters
|
||||
task = train_model_task.delay(
|
||||
force_retrain=force,
|
||||
bootstrap=bootstrap,
|
||||
symbols=symbols,
|
||||
days=days,
|
||||
timeframe=timeframe,
|
||||
min_samples_per_strategy=min_samples
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
"message": "Model retraining started in background",
|
||||
"task_id": task.id
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/intelligent/model-info")
|
||||
async def get_model_info():
|
||||
"""Get ML model information."""
|
||||
try:
|
||||
selector = get_strategy_selector()
|
||||
return selector.get_model_info()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/intelligent/reset")
|
||||
async def reset_model():
|
||||
"""Reset/delete all saved ML models and training data.
|
||||
|
||||
This clears all persisted model files AND training data from database,
|
||||
allowing for a fresh start with new features.
|
||||
"""
|
||||
try:
|
||||
from pathlib import Path
|
||||
from src.core.database import get_database, MarketConditionsSnapshot, StrategyPerformance
|
||||
from sqlalchemy import delete
|
||||
|
||||
# Get model directory
|
||||
model_dir = Path.home() / ".local" / "share" / "crypto_trader" / "models"
|
||||
|
||||
deleted_count = 0
|
||||
if model_dir.exists():
|
||||
# Delete all strategy selector model files
|
||||
for model_file in model_dir.glob("strategy_selector_*.joblib"):
|
||||
model_file.unlink()
|
||||
deleted_count += 1
|
||||
|
||||
# Clear training data from database
|
||||
db = get_database()
|
||||
db_cleared = 0
|
||||
try:
|
||||
async with db.get_session() as session:
|
||||
# Delete all market conditions snapshots
|
||||
result1 = await session.execute(delete(MarketConditionsSnapshot))
|
||||
# Delete all strategy performance records
|
||||
result2 = await session.execute(delete(StrategyPerformance))
|
||||
await session.commit()
|
||||
db_cleared = result1.rowcount + result2.rowcount
|
||||
except Exception as e:
|
||||
# Database clearing is optional - continue even if it fails
|
||||
pass
|
||||
|
||||
# Reset the in-memory model state
|
||||
selector = get_strategy_selector()
|
||||
from src.autopilot.models import StrategySelectorModel
|
||||
selector.model = StrategySelectorModel(model_type="classifier")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Deleted {deleted_count} model file(s) and {db_cleared} training records. Model reset to untrained state.",
|
||||
"deleted_count": deleted_count,
|
||||
"db_records_cleared": db_cleared
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Multi-Symbol Autopilot Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/multi-symbol/start")
|
||||
async def start_multi_symbol_autopilot(
|
||||
config: MultiSymbolAutopilotConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""Start autopilot for multiple symbols simultaneously.
|
||||
|
||||
Args:
|
||||
config: Multi-symbol autopilot configuration
|
||||
background_tasks: FastAPI background tasks
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
for symbol in config.symbols:
|
||||
# Always use intelligent mode
|
||||
autopilot = get_intelligent_autopilot(
|
||||
symbol=symbol,
|
||||
exchange_id=config.exchange_id,
|
||||
timeframe=config.timeframe,
|
||||
interval=config.interval,
|
||||
paper_trading=config.paper_trading
|
||||
)
|
||||
|
||||
autopilot.enable_auto_execution = config.auto_execute
|
||||
|
||||
if not autopilot.is_running:
|
||||
# Set running flag synchronously before scheduling background task
|
||||
autopilot._running = True
|
||||
background_tasks.add_task(autopilot.start)
|
||||
results.append({"symbol": symbol, "status": "started"})
|
||||
else:
|
||||
results.append({"symbol": symbol, "status": "already_running"})
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"mode": "intelligent",
|
||||
"symbols": results
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/multi-symbol/stop")
|
||||
async def stop_multi_symbol_autopilot(
|
||||
symbols: List[str],
|
||||
mode: str = "intelligent",
|
||||
timeframe: str = "1h"
|
||||
):
|
||||
"""Stop autopilot for multiple symbols.
|
||||
|
||||
Args:
|
||||
symbols: List of symbols to stop
|
||||
mode: Autopilot mode (pattern or intelligent)
|
||||
timeframe: Timeframe for intelligent mode
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
# Always use intelligent mode
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
autopilot.stop()
|
||||
results.append({"symbol": symbol, "status": "stopped"})
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"symbols": results
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/multi-symbol/status")
|
||||
async def get_multi_symbol_status(
|
||||
symbols: str = "", # Comma-separated list
|
||||
mode: str = "intelligent",
|
||||
timeframe: str = "1h"
|
||||
):
|
||||
"""Get status for multiple symbols.
|
||||
|
||||
Args:
|
||||
symbols: Comma-separated list of symbols (empty = all running)
|
||||
mode: Autopilot mode
|
||||
timeframe: Timeframe for intelligent mode
|
||||
"""
|
||||
from src.autopilot.intelligent_autopilot import _intelligent_autopilots
|
||||
|
||||
try:
|
||||
results = []
|
||||
|
||||
if symbols:
|
||||
symbol_list = [s.strip() for s in symbols.split(",")]
|
||||
else:
|
||||
# Get all running autopilots (intelligent only)
|
||||
symbol_list = [key.split(":")[0] for key in _intelligent_autopilots.keys()]
|
||||
|
||||
for symbol in symbol_list:
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
status = autopilot.get_status()
|
||||
status["symbol"] = symbol
|
||||
status["mode"] = "intelligent"
|
||||
results.append(status)
|
||||
except Exception:
|
||||
results.append({
|
||||
"symbol": symbol,
|
||||
"mode": "intelligent",
|
||||
"running": False,
|
||||
"error": "Not found"
|
||||
})
|
||||
|
||||
return {
|
||||
"mode": mode,
|
||||
"symbols": results,
|
||||
"total_running": sum(1 for r in results if r.get("running", False))
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unified Autopilot Endpoints (New)
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/modes")
|
||||
async def get_autopilot_modes():
|
||||
"""Get information about available autopilot modes.
|
||||
|
||||
Returns mode descriptions, capabilities, tradeoffs, and comparison data.
|
||||
"""
|
||||
try:
|
||||
return get_autopilot_mode_info()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/start-unified")
|
||||
async def start_unified_autopilot(
|
||||
config: UnifiedAutopilotConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""Start autopilot with unified interface (Intelligent Mode only)."""
|
||||
try:
|
||||
# Validate mode (for backward compatibility of API clients sending mode)
|
||||
if config.mode and config.mode != "intelligent":
|
||||
# We allow it but will treat it as intelligent if possible, or raise error if critical
|
||||
pass
|
||||
|
||||
# Start ML-based autopilot
|
||||
autopilot = get_intelligent_autopilot(
|
||||
symbol=config.symbol,
|
||||
exchange_id=config.exchange_id,
|
||||
timeframe=config.timeframe,
|
||||
interval=config.interval,
|
||||
paper_trading=config.paper_trading
|
||||
)
|
||||
|
||||
# Set auto-execution if enabled
|
||||
if config.auto_execute:
|
||||
autopilot.enable_auto_execution = True
|
||||
|
||||
if not autopilot.is_running:
|
||||
# Schedule background task (state management handled by autopilot.start via Redis)
|
||||
background_tasks.add_task(autopilot.start)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"mode": "intelligent",
|
||||
"symbol": config.symbol,
|
||||
"timeframe": config.timeframe,
|
||||
"auto_execute": config.auto_execute
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/stop-unified")
|
||||
async def stop_unified_autopilot(symbol: str, mode: str, timeframe: str = "1h"):
|
||||
"""Stop autopilot for a symbol."""
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
autopilot.stop()
|
||||
|
||||
return {"status": "stopped", "symbol": symbol, "mode": "intelligent"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status-unified/{symbol:path}")
|
||||
async def get_unified_status(symbol: str, mode: str, timeframe: str = "1h"):
|
||||
"""Get autopilot status for a symbol."""
|
||||
try:
|
||||
autopilot = get_intelligent_autopilot(symbol=symbol, timeframe=timeframe)
|
||||
# Use distributed status check (Redis)
|
||||
status = await autopilot.get_distributed_status()
|
||||
status["mode"] = "intelligent"
|
||||
return status
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unified status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/bootstrap-config", response_model=BootstrapConfig)
|
||||
async def get_bootstrap_config():
|
||||
"""Get bootstrap training data configuration."""
|
||||
from src.core.config import get_config
|
||||
config = get_config()
|
||||
|
||||
return BootstrapConfig(
|
||||
days=config.get("autopilot.intelligent.bootstrap.days", 90),
|
||||
timeframe=config.get("autopilot.intelligent.bootstrap.timeframe", "1h"),
|
||||
min_samples_per_strategy=config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10),
|
||||
symbols=config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"]),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/bootstrap-config")
|
||||
async def update_bootstrap_config(settings: BootstrapConfig):
|
||||
"""Update bootstrap training data configuration."""
|
||||
from src.core.config import get_config
|
||||
config = get_config()
|
||||
|
||||
try:
|
||||
config.set("autopilot.intelligent.bootstrap.days", settings.days)
|
||||
config.set("autopilot.intelligent.bootstrap.timeframe", settings.timeframe)
|
||||
config.set("autopilot.intelligent.bootstrap.min_samples_per_strategy", settings.min_samples_per_strategy)
|
||||
config.set("autopilot.intelligent.bootstrap.symbols", settings.symbols)
|
||||
|
||||
# Also update the strategy selector instance if it exists
|
||||
selector = get_strategy_selector()
|
||||
selector.bootstrap_days = settings.days
|
||||
selector.bootstrap_timeframe = settings.timeframe
|
||||
selector.min_samples_per_strategy = settings.min_samples_per_strategy
|
||||
selector.bootstrap_symbols = settings.symbols
|
||||
|
||||
return {"status": "success", "message": "Bootstrap configuration updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
"""Get status of a background task."""
|
||||
try:
|
||||
task_result = AsyncResult(task_id)
|
||||
|
||||
try:
|
||||
# Accessing status or result might raise an exception if deserialization fails
|
||||
status = task_result.status
|
||||
result_data = task_result.result if task_result.ready() else None
|
||||
meta_data = task_result.info if status == 'PROGRESS' else None
|
||||
|
||||
# serialized exception handling
|
||||
if isinstance(result_data, Exception):
|
||||
result_data = {
|
||||
"error": str(result_data),
|
||||
"type": type(result_data).__name__,
|
||||
"detail": str(result_data)
|
||||
}
|
||||
elif status == "FAILURE" and (not result_data or result_data == {}):
|
||||
# If failure but empty result, try to get traceback or use a default message
|
||||
tb = getattr(task_result, 'traceback', None)
|
||||
if tb:
|
||||
result_data = {"error": "Task failed", "detail": str(tb)}
|
||||
else:
|
||||
result_data = {"error": "Task failed with no error info", "detail": "Check worker logs for details"}
|
||||
|
||||
result = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"result": result_data
|
||||
}
|
||||
|
||||
if meta_data:
|
||||
result["meta"] = meta_data
|
||||
|
||||
return result
|
||||
|
||||
except Exception as inner_e:
|
||||
# If Celery fails to get status/result (e.g. serialization error), return FAILURE
|
||||
# This prevents 500 errors in the API when the task itself failed badly
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "FAILURE",
|
||||
"result": {"error": str(inner_e), "detail": "Failed to retrieve task status"},
|
||||
"meta": {"error": "Task retrieval failed"}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user