Files
crypto_trader/backend/api/backtesting.py

185 lines
7.3 KiB
Python
Raw Normal View History

"""Backtesting API endpoints."""
import logging
import uuid
from typing import Dict, Any

from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy import select

from ..core.dependencies import get_backtesting_engine, get_strategy_registry
from ..core.schemas import BacktestRequest, BacktestResponse, WalkForwardRequest, MonteCarloRequest
from src.core.database import Strategy, get_database
from src.backtesting.walk_forward import WalkForwardAnalyzer
from src.backtesting.monte_carlo import MonteCarloSimulator
router = APIRouter()

# Process-local, in-memory mapping of backtest id -> result payload, looked up
# by the /results/{backtest_id} endpoint. It is a plain dict with no eviction,
# so anything stored here lives until the process exits.
_backtests: Dict[str, Dict[str, Any]] = dict()
@router.post("/run", response_model=BacktestResponse)
async def run_backtest(
    backtest_data: BacktestRequest,
    background_tasks: BackgroundTasks,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run a single backtest for a stored strategy and return its results.

    The strategy row is loaded from the database and instantiated through the
    strategy registry. It is then run on the injected backtesting engine with
    the market, date range, capital and cost settings from the request.

    Args:
        backtest_data: Request payload (strategy id, symbol, exchange,
            timeframe, date range, initial capital, slippage, fee rate).
        background_tasks: Injected by FastAPI; not used by this handler.
        backtest_engine: Backtesting engine injected via
            ``get_backtesting_engine``.

    Returns:
        ``BacktestResponse`` wrapping the engine's results with
        ``status="completed"``.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if the strategy
            instance cannot be created or the engine result contains an
            ``"error"`` key; 500 (with the exception text as detail) for any
            other failure.
    """
    try:
        db = get_database()
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == backtest_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if strategy_db is None:
                raise HTTPException(status_code=404, detail="Strategy not found")
            # Copy the columns we need while the session is open. The backtest
            # below can run for a long time, and it should neither hold a
            # pooled DB connection nor touch an ORM instance after the session
            # has ended.
            strategy_id = strategy_db.id
            class_name = strategy_db.class_name
            parameters = strategy_db.parameters
            timeframes = strategy_db.timeframes

        registry = get_strategy_registry()
        strategy_instance = registry.create_instance(
            strategy_id=strategy_id,
            name=class_name,
            parameters=parameters,
            # Fall back to the requested timeframe when the strategy has none stored.
            timeframes=timeframes or [backtest_data.timeframe],
        )
        if not strategy_instance:
            raise HTTPException(status_code=400, detail="Failed to create strategy instance")

        # NOTE(review): this is a synchronous call inside an async handler, so
        # it blocks the event loop for the whole backtest. Consider offloading
        # it once the engine is confirmed safe to call from worker threads.
        results = backtest_engine.run_backtest(
            strategy=strategy_instance,
            symbol=backtest_data.symbol,
            exchange=backtest_data.exchange,
            timeframe=backtest_data.timeframe,
            start_date=backtest_data.start_date,
            end_date=backtest_data.end_date,
            initial_capital=backtest_data.initial_capital,
            slippage=backtest_data.slippage,
            fee_rate=backtest_data.fee_rate,
        )
        # The engine reports failures in-band through an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])

        return BacktestResponse(results=results, status="completed")
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server logs; the client only gets str(e).
        logging.getLogger(__name__).exception("Backtest failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/results/{backtest_id}")
async def get_backtest_results(backtest_id: str):
    """Look up a stored backtest entry by its id.

    Args:
        backtest_id: Key into the module-level ``_backtests`` store.

    Returns:
        The stored entry for ``backtest_id``, unchanged.

    Raises:
        HTTPException: 404 when ``backtest_id`` has no stored entry.
    """
    # NOTE(review): no code visible in this file writes to _backtests, so this
    # lookup appears to always miss. Confirm whether results are meant to be
    # stored by the /run endpoint.
    try:
        return _backtests[backtest_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Backtest not found") from None
@router.post("/walk-forward")
async def run_walk_forward(
    walk_forward_data: WalkForwardRequest,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run walk-forward analysis for robust parameter optimization.

    The stored strategy's class is resolved through the strategy registry and
    passed to ``WalkForwardAnalyzer.run_walk_forward``. The request supplies
    the train/test/step windowing, parameter grid and optimization metric.

    Args:
        walk_forward_data: Request payload (strategy id, market, date range,
            window sizes in days, initial capital, parameter grid,
            optimization metric).
        backtest_engine: Backtesting engine injected via
            ``get_backtesting_engine`` and handed to the analyzer.

    Returns:
        The analyzer's result mapping, unchanged.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if its class
            is not registered or the result contains an ``"error"`` key;
            500 (with the exception text as detail) for any other failure.
    """
    try:
        db = get_database()
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == walk_forward_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if strategy_db is None:
                raise HTTPException(status_code=404, detail="Strategy not found")
            # Only the class name is needed. Read it now so the long analysis
            # below runs without holding the DB session open.
            class_name = strategy_db.class_name

        registry = get_strategy_registry()
        strategy_class = registry.get_strategy_class(class_name)
        if not strategy_class:
            raise HTTPException(status_code=400, detail=f"Strategy class {class_name} not found")

        analyzer = WalkForwardAnalyzer(backtest_engine)
        results = await analyzer.run_walk_forward(
            strategy_class=strategy_class,
            symbol=walk_forward_data.symbol,
            exchange=walk_forward_data.exchange,
            timeframe=walk_forward_data.timeframe,
            start_date=walk_forward_data.start_date,
            end_date=walk_forward_data.end_date,
            train_period_days=walk_forward_data.train_period_days,
            test_period_days=walk_forward_data.test_period_days,
            step_days=walk_forward_data.step_days,
            initial_capital=walk_forward_data.initial_capital,
            parameter_grid=walk_forward_data.parameter_grid,
            optimization_metric=walk_forward_data.optimization_metric,
        )
        # The analyzer reports failures in-band through an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return results
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server logs; the client only gets str(e).
        logging.getLogger(__name__).exception("Walk-forward analysis failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/monte-carlo")
async def run_monte_carlo(
    monte_carlo_data: MonteCarloRequest,
    backtest_engine=Depends(get_backtesting_engine)
):
    """Run Monte Carlo simulation for risk analysis.

    The stored strategy's class is resolved through the strategy registry and
    passed to ``MonteCarloSimulator.run_monte_carlo``. The request supplies
    the simulation count, optional parameter ranges and random seed.

    Args:
        monte_carlo_data: Request payload (strategy id, market, date range,
            initial capital, number of simulations, optional
            ``parameter_ranges`` mapping, optional random seed).
        backtest_engine: Backtesting engine injected via
            ``get_backtesting_engine`` and handed to the simulator.

    Returns:
        The simulator's result mapping, unchanged.

    Raises:
        HTTPException: 404 if the strategy does not exist; 400 if its class
            is not registered or the result contains an ``"error"`` key;
            500 (with the exception text as detail) for any other failure.
    """
    try:
        db = get_database()
        async with db.get_session() as session:
            stmt = select(Strategy).where(Strategy.id == monte_carlo_data.strategy_id)
            result = await session.execute(stmt)
            strategy_db = result.scalar_one_or_none()
            if strategy_db is None:
                raise HTTPException(status_code=404, detail="Strategy not found")
            # Only the class name is needed. Read it now so the simulation
            # below runs without holding the DB session open.
            class_name = strategy_db.class_name

        registry = get_strategy_registry()
        strategy_class = registry.get_strategy_class(class_name)
        if not strategy_class:
            raise HTTPException(status_code=400, detail=f"Strategy class {class_name} not found")

        # Convert each parameter range to a (low, high) tuple built from its
        # first two values. Entries with fewer than two values are dropped, and
        # values beyond the second are ignored. A missing or empty mapping is
        # passed on as None.
        param_ranges = None
        if monte_carlo_data.parameter_ranges:
            param_ranges = {
                k: (v[0], v[1])
                for k, v in monte_carlo_data.parameter_ranges.items()
                if len(v) >= 2
            }

        simulator = MonteCarloSimulator(backtest_engine)
        results = await simulator.run_monte_carlo(
            strategy_class=strategy_class,
            symbol=monte_carlo_data.symbol,
            exchange=monte_carlo_data.exchange,
            timeframe=monte_carlo_data.timeframe,
            start_date=monte_carlo_data.start_date,
            end_date=monte_carlo_data.end_date,
            initial_capital=monte_carlo_data.initial_capital,
            num_simulations=monte_carlo_data.num_simulations,
            parameter_ranges=param_ranges,
            random_seed=monte_carlo_data.random_seed,
        )
        # The simulator reports failures in-band through an "error" key.
        if "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return results
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server logs; the client only gets str(e).
        logging.getLogger(__name__).exception("Monte Carlo simulation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e