Files
crypto_trader/backend/api/backtesting.py
kfox 7bd6be64a4
Some checks are pending
Documentation / build-docs (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.13) (push) Waiting to run
Tests / test (macos-latest, 3.14) (push) Waiting to run
Tests / test (ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (ubuntu-latest, 3.13) (push) Waiting to run
Tests / test (ubuntu-latest, 3.14) (push) Waiting to run
feat: Add core trading modules for risk management, backtesting, and execution algorithms, alongside a new ML transparency widget and related frontend dependencies.
2025-12-31 21:25:06 -05:00

185 lines
7.3 KiB
Python

"""Backtesting API endpoints."""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Dict, Any
from sqlalchemy import select
import uuid
from ..core.dependencies import get_backtesting_engine, get_strategy_registry
from ..core.schemas import BacktestRequest, BacktestResponse, WalkForwardRequest, MonteCarloRequest
from src.core.database import Strategy, get_database
from src.backtesting.walk_forward import WalkForwardAnalyzer
from src.backtesting.monte_carlo import MonteCarloSimulator
# FastAPI router collecting the backtesting endpoints defined in this module.
router = APIRouter()
# In-memory, per-process store of backtest records keyed by backtest ID; read by
# GET /results/{backtest_id}. Being a plain module-level dict, its contents are
# lost on restart and are not shared between worker processes.
# NOTE(review): nothing in this file writes to this dict, so the results
# endpoint will always return 404 unless another module populates it — confirm.
_backtests: Dict[str, Dict[str, Any]] = {}
@router.post("/run", response_model=BacktestResponse)
async def run_backtest(
    backtest_data: BacktestRequest,
    background_tasks: BackgroundTasks,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run a backtest and return its results in the response.

    Looks up the stored strategy by ``backtest_data.strategy_id``, builds an
    instance through the strategy registry, and runs it on the backtesting
    engine with the request's symbol/exchange/timeframe, date range, initial
    capital, slippage and fee rate.

    Args:
        backtest_data: Backtest request payload.
        background_tasks: Injected by FastAPI but not used by this handler;
            kept so the endpoint signature is unchanged.
        backtest_engine: Engine provided by ``get_backtesting_engine``.

    Returns:
        ``BacktestResponse`` carrying the engine's results with status
        ``"completed"``.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if the strategy
            instance cannot be created or the engine's results contain an
            ``"error"`` key; 500 for any other unexpected exception.
    """
    # NOTE(review): results are not stored in ``_backtests``, so they cannot be
    # fetched later via GET /results/{backtest_id} — confirm whether that is
    # intended.
    try:
        db = get_database()
        # Only the strategy lookup needs the database. Copy the fields used
        # below into locals so the session (and its connection) is released
        # before the backtest runs, and no ORM attribute is read after close.
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == backtest_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if not strategy_db:
                raise HTTPException(status_code=404, detail="Strategy not found")
            strategy_id = strategy_db.id
            class_name = strategy_db.class_name
            parameters = strategy_db.parameters
            timeframes = strategy_db.timeframes

        # Create strategy instance; fall back to the requested timeframe when
        # the stored strategy has none configured.
        registry = get_strategy_registry()
        strategy_instance = registry.create_instance(
            strategy_id=strategy_id,
            name=class_name,
            parameters=parameters,
            timeframes=timeframes or [backtest_data.timeframe],
        )
        if not strategy_instance:
            raise HTTPException(status_code=400, detail="Failed to create strategy instance")

        # NOTE(review): this engine call is synchronous inside an async handler;
        # if it is CPU-heavy it blocks the event loop. Consider offloading it
        # once the engine's thread-safety is confirmed.
        results = backtest_engine.run_backtest(
            strategy=strategy_instance,
            symbol=backtest_data.symbol,
            exchange=backtest_data.exchange,
            timeframe=backtest_data.timeframe,
            start_date=backtest_data.start_date,
            end_date=backtest_data.end_date,
            initial_capital=backtest_data.initial_capital,
            slippage=backtest_data.slippage,
            fee_rate=backtest_data.fee_rate,
        )
        # The engine reports failures in-band via an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return BacktestResponse(
            results=results,
            status="completed",
        )
    except HTTPException:
        raise
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/results/{backtest_id}")
async def get_backtest_results(backtest_id: str):
    """Return the stored backtest record for ``backtest_id``.

    Raises:
        HTTPException: 404 when no record is stored under that ID.
    """
    try:
        return _backtests[backtest_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Backtest not found") from None
@router.post("/walk-forward")
async def run_walk_forward(
    walk_forward_data: WalkForwardRequest,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run walk-forward analysis for robust parameter optimization.

    Resolves the stored strategy's class through the strategy registry and
    runs ``WalkForwardAnalyzer`` over the requested date range. The analyzer
    receives the request's train/test/step window sizes, initial capital,
    parameter grid and optimization metric.

    Args:
        walk_forward_data: Walk-forward request payload.
        backtest_engine: Engine provided by ``get_backtesting_engine``; handed
            to the analyzer.

    Returns:
        The analyzer's results, returned unchanged.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if its class is
            not registered or the results contain an ``"error"`` key; 500 for
            any other unexpected exception.
    """
    try:
        db = get_database()
        # Hold the session only for the lookup: the analysis may be long-running
        # and does not need the database connection.
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == walk_forward_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if not strategy_db:
                raise HTTPException(status_code=404, detail="Strategy not found")
            # Copy the one field needed so no ORM attribute is read after close.
            class_name = strategy_db.class_name

        registry = get_strategy_registry()
        strategy_class = registry.get_strategy_class(class_name)
        if not strategy_class:
            raise HTTPException(status_code=400, detail=f"Strategy class {class_name} not found")

        analyzer = WalkForwardAnalyzer(backtest_engine)
        results = await analyzer.run_walk_forward(
            strategy_class=strategy_class,
            symbol=walk_forward_data.symbol,
            exchange=walk_forward_data.exchange,
            timeframe=walk_forward_data.timeframe,
            start_date=walk_forward_data.start_date,
            end_date=walk_forward_data.end_date,
            train_period_days=walk_forward_data.train_period_days,
            test_period_days=walk_forward_data.test_period_days,
            step_days=walk_forward_data.step_days,
            initial_capital=walk_forward_data.initial_capital,
            parameter_grid=walk_forward_data.parameter_grid,
            optimization_metric=walk_forward_data.optimization_metric,
        )
        # The analyzer reports failures in-band via an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return results
    except HTTPException:
        raise
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/monte-carlo")
async def run_monte_carlo(
    monte_carlo_data: MonteCarloRequest,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run Monte Carlo simulation for risk analysis.

    Resolves the stored strategy's class through the strategy registry and
    runs ``MonteCarloSimulator`` over the requested date range, using the
    request's initial capital, number of simulations, optional parameter
    ranges and optional random seed.

    Args:
        monte_carlo_data: Monte Carlo request payload.
        backtest_engine: Engine provided by ``get_backtesting_engine``; handed
            to the simulator.

    Returns:
        The simulator's results, returned unchanged.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if its class is
            not registered or the results contain an ``"error"`` key; 500 for
            any other unexpected exception.
    """
    try:
        db = get_database()
        # Hold the session only for the lookup: the simulation may be
        # long-running and does not need the database connection.
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == monte_carlo_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if not strategy_db:
                raise HTTPException(status_code=404, detail="Strategy not found")
            # Copy the one field needed so no ORM attribute is read after close.
            class_name = strategy_db.class_name

        registry = get_strategy_registry()
        strategy_class = registry.get_strategy_class(class_name)
        if not strategy_class:
            raise HTTPException(status_code=400, detail=f"Strategy class {class_name} not found")

        # Convert each parameter range to a (low, high) tuple from its first two
        # values. Entries with fewer than two values are skipped and extra
        # values are ignored.
        param_ranges = None
        if monte_carlo_data.parameter_ranges:
            param_ranges = {
                k: (v[0], v[1])
                for k, v in monte_carlo_data.parameter_ranges.items()
                if len(v) >= 2
            }

        simulator = MonteCarloSimulator(backtest_engine)
        results = await simulator.run_monte_carlo(
            strategy_class=strategy_class,
            symbol=monte_carlo_data.symbol,
            exchange=monte_carlo_data.exchange,
            timeframe=monte_carlo_data.timeframe,
            start_date=monte_carlo_data.start_date,
            end_date=monte_carlo_data.end_date,
            initial_capital=monte_carlo_data.initial_capital,
            num_simulations=monte_carlo_data.num_simulations,
            parameter_ranges=param_ranges,
            random_seed=monte_carlo_data.random_seed,
        )
        # The simulator reports failures in-band via an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return results
    except HTTPException:
        raise
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e