Files
crypto_trader/src/autopilot/performance_tracker.py

352 lines
14 KiB
Python
Raw Normal View History

"""Strategy performance tracker for ML training data collection."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
import pandas as pd
import numpy as np
from src.core.database import get_database, Trade, Strategy
from src.core.logger import get_logger
from .market_analyzer import MarketConditions
logger = get_logger(__name__)
class PerformanceTracker:
    """Tracks strategy performance for ML training.

    Persists one market-conditions snapshot and one performance record per
    recorded trade, and reads those rows back as pandas data for model
    training and for summary metrics.
    """

    # Columns that describe the performance record itself. Every other
    # DataFrame column is treated as a market feature.
    _METADATA_COLUMNS = (
        'strategy_name', 'symbol', 'timeframe', 'market_regime',
        'return_pct', 'sharpe_ratio', 'win_rate', 'max_drawdown',
        'trade_count', 'timestamp',
    )

    def __init__(self):
        """Initialize performance tracker."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    async def record_trade(
        self,
        strategy_name: str,
        market_conditions: "MarketConditions",
        trade_result: Dict[str, Any]
    ) -> bool:
        """Record a trade for ML training.

        Stores a snapshot of the market conditions and a performance record
        in a single commit.

        Args:
            strategy_name: Name of strategy used
            market_conditions: Market conditions at trade time
            trade_result: Trade result with performance metrics. Missing
                metrics default to 0.0 (``trade_count`` defaults to 1).

        Returns:
            True if recorded successfully, False if anything failed (the
            error is logged, never raised).
        """
        try:
            async with self.db.get_session() as session:
                try:
                    from src.core.database import MarketConditionsSnapshot
                    from src.core.database import StrategyPerformance

                    regime = market_conditions.regime.value

                    # Store market conditions snapshot
                    session.add(MarketConditionsSnapshot(
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        regime=regime,
                        features=market_conditions.features,
                        strategy_name=strategy_name,
                        timestamp=market_conditions.timestamp
                    ))

                    # Store performance record
                    # NOTE(review): datetime.utcnow() yields a naive UTC
                    # timestamp and is deprecated since Python 3.12.
                    session.add(StrategyPerformance(
                        strategy_name=strategy_name,
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        market_regime=regime,
                        return_pct=float(trade_result.get('return_pct', 0.0)),
                        sharpe_ratio=float(trade_result.get('sharpe_ratio', 0.0)),
                        win_rate=float(trade_result.get('win_rate', 0.0)),
                        max_drawdown=float(trade_result.get('max_drawdown', 0.0)),
                        trade_count=int(trade_result.get('trade_count', 1)),
                        timestamp=datetime.utcnow()
                    ))

                    await session.commit()
                    return True
                except Exception as e:
                    # Undo the partial unit of work before reporting failure.
                    await session.rollback()
                    self.logger.error(f"Failed to record trade: {e}")
                    return False
        except Exception as e:
            # Reached when the session could not be opened or closed.
            self.logger.error(f"Error recording trade: {e}")
            return False

    async def get_performance_history(
        self,
        strategy_name: Optional[str] = None,
        market_regime: Optional[str] = None,
        days: int = 30,
        max_records: int = 10000
    ) -> pd.DataFrame:
        """Get performance history for training.

        Args:
            strategy_name: Filter by strategy name
            market_regime: Filter by market regime
            days: Number of days to look back
            max_records: Upper bound on the number of performance records
                loaded; the most recent ones are kept.

        Returns:
            DataFrame with one row per performance record, extended with the
            features of a matching market snapshot when one exists. Empty
            DataFrame when there are no records or on error.
        """
        from sqlalchemy import select

        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance, MarketConditionsSnapshot

                # Query performance records
                stmt = select(StrategyPerformance)
                if strategy_name:
                    stmt = stmt.where(StrategyPerformance.strategy_name == strategy_name)
                if market_regime:
                    stmt = stmt.where(StrategyPerformance.market_regime == market_regime)

                cutoff_date = datetime.utcnow() - timedelta(days=days)
                stmt = stmt.where(StrategyPerformance.timestamp >= cutoff_date)

                # Cap the result set: newest records first, at most max_records.
                stmt = stmt.order_by(StrategyPerformance.timestamp.desc()).limit(max_records)

                result = await session.execute(stmt)
                records = result.scalars().all()
                if not records:
                    return pd.DataFrame()

                self.logger.info(f"Processing {len(records)} performance records for training data")

                data = []
                # One snapshot lookup per (strategy, symbol, timeframe, day)
                # instead of one per record.
                # NOTE(review): records arrive newest-first, so an earlier
                # record on the same day reuses the snapshot fetched for the
                # day's latest record, which may post-date it. Confirm this
                # approximation is acceptable for training data.
                snapshot_cache: Dict[Any, Any] = {}
                for record in records:
                    cache_key = (
                        record.strategy_name,
                        record.symbol,
                        record.timeframe,
                        record.timestamp.date(),
                    )
                    if cache_key not in snapshot_cache:
                        # Latest snapshot at or before this record's timestamp.
                        snapshot_stmt = select(MarketConditionsSnapshot).filter_by(
                            strategy_name=record.strategy_name,
                            symbol=record.symbol,
                            timeframe=record.timeframe
                        ).where(
                            MarketConditionsSnapshot.timestamp <= record.timestamp
                        ).order_by(
                            MarketConditionsSnapshot.timestamp.desc()
                        ).limit(1)
                        snapshot_result = await session.execute(snapshot_stmt)
                        snapshot_cache[cache_key] = snapshot_result.scalar_one_or_none()
                    snapshot = snapshot_cache[cache_key]

                    row = {
                        'strategy_name': record.strategy_name,
                        'symbol': record.symbol,
                        'timeframe': record.timeframe,
                        'market_regime': record.market_regime,
                        'return_pct': float(record.return_pct),
                        'sharpe_ratio': float(record.sharpe_ratio),
                        'win_rate': float(record.win_rate),
                        'max_drawdown': float(record.max_drawdown),
                        'trade_count': int(record.trade_count),
                        'timestamp': record.timestamp
                    }
                    # Add market features if available
                    if snapshot and snapshot.features:
                        self._merge_features(row, snapshot.features)
                    data.append(row)

                return pd.DataFrame(data)
        except Exception as e:
            self.logger.error(f"Error getting performance history: {e}")
            return pd.DataFrame()

    @staticmethod
    def _merge_features(row: Dict[str, Any], features: Any) -> None:
        """Add snapshot features to ``row`` without overwriting existing keys.

        A feature that shares its name with a record column (for example
        ``strategy_name`` or ``return_pct``) must not replace the record's
        own value, otherwise labels and metrics would be corrupted.
        """
        for name, value in dict(features).items():
            row.setdefault(name, value)

    async def calculate_metrics(
        self,
        strategy_name: str,
        period_days: int = 30
    ) -> Dict[str, Any]:
        """Calculate performance metrics for a strategy.

        Args:
            strategy_name: Strategy name
            period_days: Period to calculate metrics over

        Returns:
            Dictionary with ``total_trades``, ``win_rate``, ``avg_return``,
            ``total_return``, ``sharpe_ratio`` and ``max_drawdown``. All
            values are zero when the strategy has no records in the period;
            an empty dict is returned on error.
        """
        from sqlalchemy import select

        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=period_days)
                stmt = select(StrategyPerformance).where(
                    StrategyPerformance.strategy_name == strategy_name,
                    StrategyPerformance.timestamp >= cutoff_date
                )
                result = await session.execute(stmt)
                records = result.scalars().all()

                returns = [float(r.return_pct) for r in records]
                drawdowns = [float(r.max_drawdown) for r in records]
                return self._summarize_returns(returns, drawdowns)
        except Exception as e:
            self.logger.error(f"Error calculating metrics: {e}")
            return {}

    @staticmethod
    def _summarize_returns(
        returns: List[float],
        drawdowns: List[float],
        periods_per_year: int = 252
    ) -> Dict[str, Any]:
        """Reduce per-record returns and drawdowns to summary metrics.

        Args:
            returns: Return percentage of each record
            drawdowns: Max drawdown of each record
            periods_per_year: Annualization factor for the Sharpe ratio

        Returns:
            Metrics dictionary; every value is zero for empty input.
        """
        if not returns:
            return {
                'total_trades': 0,
                'win_rate': 0.0,
                'avg_return': 0.0,
                'total_return': 0.0,
                'sharpe_ratio': 0.0,
                'max_drawdown': 0.0
            }

        avg_return = float(np.mean(returns))
        # Population standard deviation (numpy default, ddof=0).
        volatility = float(np.std(returns))

        # Simplified Sharpe ratio: no risk-free rate; needs at least two
        # samples and non-zero dispersion.
        if len(returns) > 1 and volatility > 0:
            sharpe = avg_return / volatility * np.sqrt(periods_per_year)
        else:
            sharpe = 0.0

        winning = sum(1 for value in returns if value > 0)
        return {
            'total_trades': len(returns),
            'win_rate': float(winning / len(returns)),
            'avg_return': avg_return,
            'total_return': float(sum(returns)),
            'sharpe_ratio': float(sharpe),
            # Largest stored drawdown value across the records.
            'max_drawdown': float(max(drawdowns)) if drawdowns else 0.0
        }

    async def prepare_training_data(
        self,
        min_samples_per_strategy: int = 10
    ) -> Optional[Dict[str, Any]]:
        """Prepare training data for ML model.

        Args:
            min_samples_per_strategy: Minimum samples required per strategy

        Returns:
            Dictionary with 'X' (features), 'y' (targets), 'strategy_names',
            'feature_names' and 'training_symbols', or None when there is no
            usable data or an error occurred.
        """
        try:
            # Get all performance history - extended to 365 days for better training
            df = await self.get_performance_history(days=365)
            if df.empty:
                self.logger.warning("No performance history available for training")
                return None

            training_set = self._build_training_set(df, min_samples_per_strategy)
            if training_set is None:
                self.logger.warning(f"No strategies with at least {min_samples_per_strategy} samples")
                return None
            return training_set
        except Exception as e:
            self.logger.error(f"Error preparing training data: {e}")
            return None

    @classmethod
    def _build_training_set(
        cls,
        df: pd.DataFrame,
        min_samples_per_strategy: int
    ) -> Optional[Dict[str, Any]]:
        """Turn a performance-history frame into model inputs.

        Keeps only strategies with enough rows, coerces every non-metadata
        column to numeric (unparseable or missing values become 0.0) and
        uses the strategy name as the target.

        Returns:
            The training dictionary, or None when no strategy has at least
            ``min_samples_per_strategy`` rows.
        """
        strategy_counts = df['strategy_name'].value_counts()
        valid_strategies = strategy_counts[
            strategy_counts >= min_samples_per_strategy
        ].index.tolist()
        if not valid_strategies:
            return None

        # Work on an explicit copy so the column assignments below never
        # write into a slice of the caller's frame.
        df = df[df['strategy_name'].isin(valid_strategies)].copy()

        # Extract features (exclude metadata columns)
        feature_cols = [col for col in df.columns if col not in cls._METADATA_COLUMNS]
        for col in feature_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)

        return {
            'X': df[feature_cols],
            'y': df['strategy_name'].values,
            'strategy_names': valid_strategies,
            'feature_names': feature_cols,
            'training_symbols': sorted(df['symbol'].unique().tolist())
        }

    async def get_strategy_sample_counts(self, days: int = 365) -> Dict[str, int]:
        """Get sample counts per strategy.

        Args:
            days: Number of days to look back

        Returns:
            Dictionary mapping strategy names to sample counts; empty on
            error.
        """
        from sqlalchemy import select, func

        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=days)
                # Group by strategy name and count
                stmt = select(
                    StrategyPerformance.strategy_name,
                    func.count(StrategyPerformance.id)
                ).where(
                    StrategyPerformance.timestamp >= cutoff_date
                ).group_by(
                    StrategyPerformance.strategy_name
                )
                result = await session.execute(stmt)
                # Each result row is a (strategy_name, count) pair.
                return {name: count for name, count in result.all()}
        except Exception as e:
            self.logger.error(f"Error getting strategy sample counts: {e}")
            return {}
# Module-level singleton, created lazily on first access.
_performance_tracker: Optional[PerformanceTracker] = None


def get_performance_tracker() -> PerformanceTracker:
    """Return the shared PerformanceTracker, constructing it on first use."""
    global _performance_tracker
    if _performance_tracker is not None:
        return _performance_tracker
    _performance_tracker = PerformanceTracker()
    return _performance_tracker