"""Strategy performance tracker for ML training data collection.""" from decimal import Decimal from typing import Dict, Any, Optional, List from datetime import datetime, timedelta from sqlalchemy.orm import Session import pandas as pd import numpy as np from src.core.database import get_database, Trade, Strategy from src.core.logger import get_logger from .market_analyzer import MarketConditions logger = get_logger(__name__) class PerformanceTracker: """Tracks strategy performance for ML training.""" def __init__(self): """Initialize performance tracker.""" self.db = get_database() self.logger = get_logger(__name__) async def record_trade( self, strategy_name: str, market_conditions: MarketConditions, trade_result: Dict[str, Any] ) -> bool: """Record a trade for ML training. Args: strategy_name: Name of strategy used market_conditions: Market conditions at trade time trade_result: Trade result with performance metrics Returns: True if recorded successfully """ try: async with self.db.get_session() as session: try: # Store market conditions snapshot from src.core.database import MarketConditionsSnapshot snapshot = MarketConditionsSnapshot( symbol=market_conditions.symbol, timeframe=market_conditions.timeframe, regime=market_conditions.regime.value, features=market_conditions.features, strategy_name=strategy_name, timestamp=market_conditions.timestamp ) session.add(snapshot) # Store performance record from src.core.database import StrategyPerformance performance = StrategyPerformance( strategy_name=strategy_name, symbol=market_conditions.symbol, timeframe=market_conditions.timeframe, market_regime=market_conditions.regime.value, return_pct=float(trade_result.get('return_pct', 0.0)), sharpe_ratio=float(trade_result.get('sharpe_ratio', 0.0)), win_rate=float(trade_result.get('win_rate', 0.0)), max_drawdown=float(trade_result.get('max_drawdown', 0.0)), trade_count=int(trade_result.get('trade_count', 1)), timestamp=datetime.utcnow() ) session.add(performance) await session.commit() return True except Exception as e: await session.rollback() self.logger.error(f"Failed to record trade: {e}") return False except Exception as e: self.logger.error(f"Error recording trade: {e}") return False async def get_performance_history( self, strategy_name: Optional[str] = None, market_regime: Optional[str] = None, days: int = 30 ) -> pd.DataFrame: """Get performance history for training. Args: strategy_name: Filter by strategy name market_regime: Filter by market regime days: Number of days to look back Returns: DataFrame with performance history """ from sqlalchemy import select try: async with self.db.get_session() as session: from src.core.database import StrategyPerformance, MarketConditionsSnapshot # Query performance records stmt = select(StrategyPerformance) if strategy_name: stmt = stmt.where(StrategyPerformance.strategy_name == strategy_name) if market_regime: stmt = stmt.where(StrategyPerformance.market_regime == market_regime) cutoff_date = datetime.utcnow() - timedelta(days=days) stmt = stmt.where(StrategyPerformance.timestamp >= cutoff_date) # Limit to prevent excessive queries - if we have lots of data, sample it stmt = stmt.order_by(StrategyPerformance.timestamp.desc()).limit(10000) result = await session.execute(stmt) records = result.scalars().all() if not records: return pd.DataFrame() self.logger.info(f"Processing {len(records)} performance records for training data") # Convert to DataFrame - optimize by batching snapshot queries data = [] # Batch snapshot lookups to reduce N+1 query problem snapshot_cache = {} for record in records: cache_key = f"{record.strategy_name}:{record.symbol}:{record.timeframe}:{record.timestamp.date()}" if cache_key not in snapshot_cache: # Get corresponding market conditions (only once per day per strategy) snapshot_stmt = select(MarketConditionsSnapshot).filter_by( strategy_name=record.strategy_name, symbol=record.symbol, timeframe=record.timeframe ).where( MarketConditionsSnapshot.timestamp <= record.timestamp ).order_by( MarketConditionsSnapshot.timestamp.desc() ).limit(1) snapshot_result = await session.execute(snapshot_stmt) snapshot = snapshot_result.scalar_one_or_none() snapshot_cache[cache_key] = snapshot else: snapshot = snapshot_cache[cache_key] row = { 'strategy_name': record.strategy_name, 'symbol': record.symbol, 'timeframe': record.timeframe, 'market_regime': record.market_regime, 'return_pct': float(record.return_pct), 'sharpe_ratio': float(record.sharpe_ratio), 'win_rate': float(record.win_rate), 'max_drawdown': float(record.max_drawdown), 'trade_count': int(record.trade_count), 'timestamp': record.timestamp } # Add market features if available if snapshot and snapshot.features: row.update(snapshot.features) data.append(row) return pd.DataFrame(data) except Exception as e: self.logger.error(f"Error getting performance history: {e}") return pd.DataFrame() async def calculate_metrics( self, strategy_name: str, period_days: int = 30 ) -> Dict[str, Any]: """Calculate performance metrics for a strategy. Args: strategy_name: Strategy name period_days: Period to calculate metrics over Returns: Dictionary of performance metrics """ from sqlalchemy import select try: async with self.db.get_session() as session: from src.core.database import StrategyPerformance cutoff_date = datetime.utcnow() - timedelta(days=period_days) stmt = select(StrategyPerformance).where( StrategyPerformance.strategy_name == strategy_name, StrategyPerformance.timestamp >= cutoff_date ) result = await session.execute(stmt) records = result.scalars().all() if not records: return { 'total_trades': 0, 'win_rate': 0.0, 'avg_return': 0.0, 'total_return': 0.0, 'sharpe_ratio': 0.0, 'max_drawdown': 0.0 } returns = [float(r.return_pct) for r in records] wins = [r for r in returns if r > 0] total_return = sum(returns) avg_return = np.mean(returns) if returns else 0.0 win_rate = len(wins) / len(returns) if returns else 0.0 # Calculate Sharpe ratio (simplified) if len(returns) > 1: sharpe = (avg_return / np.std(returns)) * np.sqrt(252) if np.std(returns) > 0 else 0.0 else: sharpe = 0.0 # Max drawdown max_dd = max([float(r.max_drawdown) for r in records]) if records else 0.0 return { 'total_trades': len(records), 'win_rate': float(win_rate), 'avg_return': float(avg_return), 'total_return': float(total_return), 'sharpe_ratio': float(sharpe), 'max_drawdown': float(max_dd) } except Exception as e: self.logger.error(f"Error calculating metrics: {e}") return {} async def prepare_training_data( self, min_samples_per_strategy: int = 10 ) -> Optional[Dict[str, Any]]: """Prepare training data for ML model. Args: min_samples_per_strategy: Minimum samples required per strategy Returns: Dictionary with 'X' (features), 'y' (targets), and 'strategy_names' """ try: # Get all performance history - extended to 365 days for better training df = await self.get_performance_history(days=365) if df.empty: self.logger.warning("No performance history available for training") return None # Filter strategies with enough samples strategy_counts = df['strategy_name'].value_counts() valid_strategies = strategy_counts[strategy_counts >= min_samples_per_strategy].index.tolist() if not valid_strategies: self.logger.warning(f"No strategies with at least {min_samples_per_strategy} samples") return None df = df[df['strategy_name'].isin(valid_strategies)] # Extract features (exclude metadata columns) feature_cols = [col for col in df.columns if col not in [ 'strategy_name', 'symbol', 'timeframe', 'market_regime', 'return_pct', 'sharpe_ratio', 'win_rate', 'max_drawdown', 'trade_count', 'timestamp' ]] # Fill missing features with 0 for col in feature_cols: if col not in df.columns: df[col] = 0.0 df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0) X = df[feature_cols] y = df['strategy_name'].values return { 'X': X, 'y': y, 'strategy_names': valid_strategies, 'feature_names': feature_cols, 'training_symbols': sorted(df['symbol'].unique().tolist()) } except Exception as e: self.logger.error(f"Error preparing training data: {e}") return None async def get_strategy_sample_counts(self, days: int = 365) -> Dict[str, int]: """Get sample counts per strategy. Args: days: Number of days to look back Returns: Dictionary mapping strategy names to sample counts """ from sqlalchemy import select, func try: async with self.db.get_session() as session: from src.core.database import StrategyPerformance cutoff_date = datetime.utcnow() - timedelta(days=days) # Group by strategy name and count stmt = select( StrategyPerformance.strategy_name, func.count(StrategyPerformance.id) ).where( StrategyPerformance.timestamp >= cutoff_date ).group_by( StrategyPerformance.strategy_name ) result = await session.execute(stmt) counts = dict(result.all()) return counts except Exception as e: self.logger.error(f"Error getting strategy sample counts: {e}") return {} # Global instance _performance_tracker: Optional[PerformanceTracker] = None def get_performance_tracker() -> PerformanceTracker: """Get global performance tracker instance.""" global _performance_tracker if _performance_tracker is None: _performance_tracker = PerformanceTracker() return _performance_tracker