352 lines
14 KiB
Python
352 lines
14 KiB
Python
|
|
"""Strategy performance tracker for ML training data collection."""
|
||
|
|
|
||
|
|
from decimal import Decimal
|
||
|
|
from typing import Dict, Any, Optional, List
|
||
|
|
from datetime import datetime, timedelta
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from src.core.database import get_database, Trade, Strategy
|
||
|
|
from src.core.logger import get_logger
|
||
|
|
from .market_analyzer import MarketConditions
|
||
|
|
|
||
|
|
# Module-level logger, requested from the project's logging helper under this
# module's name.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class PerformanceTracker:
    """Tracks strategy performance for ML training.

    Each recorded trade is stored as two rows: a ``MarketConditionsSnapshot``
    holding the market features at trade time and a ``StrategyPerformance``
    holding the outcome.  The query methods read those rows back and build
    pandas structures for model training.
    """

    # Columns of a history frame that describe the trade itself; every other
    # column is treated as a market feature.
    _METADATA_COLUMNS = (
        'strategy_name', 'symbol', 'timeframe', 'market_regime',
        'return_pct', 'sharpe_ratio', 'win_rate', 'max_drawdown',
        'trade_count', 'timestamp',
    )

    def __init__(self):
        """Initialize performance tracker."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    @staticmethod
    def _latest_at_or_before(timestamps, snapshots, moment):
        """Return the newest snapshot taken at or before ``moment``.

        Args:
            timestamps: Snapshot timestamps sorted oldest first.
            snapshots: Snapshots aligned index-for-index with ``timestamps``.
            moment: Inclusive upper bound for the snapshot timestamp.

        Returns:
            The matching snapshot, or None when every snapshot is newer
            (or there are no snapshots at all).
        """
        from bisect import bisect_right

        position = bisect_right(timestamps, moment)
        return snapshots[position - 1] if position else None

    async def record_trade(
        self,
        strategy_name: str,
        market_conditions: "MarketConditions",
        trade_result: Dict[str, Any]
    ) -> bool:
        """Record a trade for ML training.

        Args:
            strategy_name: Name of strategy used
            market_conditions: Market conditions at trade time
            trade_result: Trade result with performance metrics. Missing
                metrics default to 0.0 (``trade_count`` defaults to 1).

        Returns:
            True if both rows were committed; False on any failure (the
            error is logged, never raised).
        """
        try:
            async with self.db.get_session() as session:
                try:
                    from src.core.database import (
                        MarketConditionsSnapshot,
                        StrategyPerformance,
                    )

                    # Store market conditions snapshot
                    snapshot = MarketConditionsSnapshot(
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        regime=market_conditions.regime.value,
                        features=market_conditions.features,
                        strategy_name=strategy_name,
                        timestamp=market_conditions.timestamp
                    )
                    session.add(snapshot)

                    # Store performance record
                    performance = StrategyPerformance(
                        strategy_name=strategy_name,
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        market_regime=market_conditions.regime.value,
                        return_pct=float(trade_result.get('return_pct', 0.0)),
                        sharpe_ratio=float(trade_result.get('sharpe_ratio', 0.0)),
                        win_rate=float(trade_result.get('win_rate', 0.0)),
                        max_drawdown=float(trade_result.get('max_drawdown', 0.0)),
                        trade_count=int(trade_result.get('trade_count', 1)),
                        # NOTE(review): naive UTC; the look-back filters in
                        # this class compare against naive UTC as well.
                        timestamp=datetime.utcnow()
                    )
                    session.add(performance)

                    # Both rows are committed together or not at all.
                    await session.commit()
                    return True

                except Exception as e:
                    await session.rollback()
                    self.logger.error(f"Failed to record trade: {e}")
                    return False

        except Exception as e:
            # Reached when the session itself could not be opened or closed.
            self.logger.error(f"Error recording trade: {e}")
            return False

    async def get_performance_history(
        self,
        strategy_name: Optional[str] = None,
        market_regime: Optional[str] = None,
        days: int = 30
    ) -> pd.DataFrame:
        """Get performance history for training.

        Each performance record is paired with the most recent market
        snapshot of the same strategy/symbol/timeframe taken at or before
        the record's own timestamp, so a row never carries features from
        a later point in time.

        Args:
            strategy_name: Filter by strategy name
            market_regime: Filter by market regime
            days: Number of days to look back

        Returns:
            DataFrame with one row per performance record (at most the
            10000 most recent), holding the record's metrics plus the
            matched snapshot's features as extra columns.  Empty when
            nothing matches or an error occurs (the error is logged).
        """
        from sqlalchemy import select
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance, MarketConditionsSnapshot

                # Query performance records
                stmt = select(StrategyPerformance)

                if strategy_name:
                    stmt = stmt.where(StrategyPerformance.strategy_name == strategy_name)
                if market_regime:
                    stmt = stmt.where(StrategyPerformance.market_regime == market_regime)

                cutoff_date = datetime.utcnow() - timedelta(days=days)
                stmt = stmt.where(StrategyPerformance.timestamp >= cutoff_date)
                # Cap the result set; only the most recent rows are kept
                # (this is a truncation, not a sample).
                stmt = stmt.order_by(StrategyPerformance.timestamp.desc()).limit(10000)

                result = await session.execute(stmt)
                records = result.scalars().all()

                if not records:
                    return pd.DataFrame()

                self.logger.info(f"Processing {len(records)} performance records for training data")

                # Time span covered by the records of every
                # (strategy, symbol, timeframe) series.  It bounds the
                # snapshot query issued for that series below.
                spans = {}
                for record in records:
                    key = (record.strategy_name, record.symbol, record.timeframe)
                    oldest, newest = spans.get(key, (record.timestamp, record.timestamp))
                    spans[key] = (min(oldest, record.timestamp), max(newest, record.timestamp))

                # key -> (timestamps, snapshots) inside the span, oldest first
                in_span = {}
                # key -> newest snapshot older than the span (or None)
                before_span = {}

                data = []
                for record in records:
                    key = (record.strategy_name, record.symbol, record.timeframe)

                    if key not in in_span:
                        # One ranged query per series instead of one query
                        # per record.
                        oldest, newest = spans[key]
                        span_stmt = select(MarketConditionsSnapshot).filter_by(
                            strategy_name=record.strategy_name,
                            symbol=record.symbol,
                            timeframe=record.timeframe
                        ).where(
                            MarketConditionsSnapshot.timestamp >= oldest,
                            MarketConditionsSnapshot.timestamp <= newest
                        ).order_by(
                            MarketConditionsSnapshot.timestamp
                        )
                        span_result = await session.execute(span_stmt)
                        span_rows = span_result.scalars().all()
                        in_span[key] = ([row.timestamp for row in span_rows], span_rows)

                    # NOTE(review): compares record and snapshot timestamps
                    # in Python; assumes both columns yield the same kind of
                    # datetime (both naive or both aware) - confirm.
                    timestamps, snapshots = in_span[key]
                    snapshot = self._latest_at_or_before(timestamps, snapshots, record.timestamp)

                    if snapshot is None:
                        # Nothing inside the span precedes this record, so
                        # the only candidate is the newest snapshot taken
                        # before the span starts.  Looked up once per series.
                        if key not in before_span:
                            oldest, _ = spans[key]
                            older_stmt = select(MarketConditionsSnapshot).filter_by(
                                strategy_name=record.strategy_name,
                                symbol=record.symbol,
                                timeframe=record.timeframe
                            ).where(
                                MarketConditionsSnapshot.timestamp < oldest
                            ).order_by(
                                MarketConditionsSnapshot.timestamp.desc()
                            ).limit(1)
                            older_result = await session.execute(older_stmt)
                            before_span[key] = older_result.scalar_one_or_none()
                        snapshot = before_span[key]

                    row = {
                        'strategy_name': record.strategy_name,
                        'symbol': record.symbol,
                        'timeframe': record.timeframe,
                        'market_regime': record.market_regime,
                        'return_pct': float(record.return_pct),
                        'sharpe_ratio': float(record.sharpe_ratio),
                        'win_rate': float(record.win_rate),
                        'max_drawdown': float(record.max_drawdown),
                        'trade_count': int(record.trade_count),
                        'timestamp': record.timestamp
                    }

                    # Add market features if available.  setdefault keeps the
                    # record's own columns when a feature shares their name.
                    if snapshot and snapshot.features:
                        for feature_name, feature_value in snapshot.features.items():
                            row.setdefault(feature_name, feature_value)

                    data.append(row)

                return pd.DataFrame(data)

        except Exception as e:
            self.logger.error(f"Error getting performance history: {e}")
            return pd.DataFrame()

    async def calculate_metrics(
        self,
        strategy_name: str,
        period_days: int = 30
    ) -> Dict[str, Any]:
        """Calculate performance metrics for a strategy.

        Args:
            strategy_name: Strategy name
            period_days: Period to calculate metrics over

        Returns:
            Dictionary with ``total_trades``, ``win_rate`` (share of records
            with a positive return), ``avg_return``, ``total_return``,
            ``sharpe_ratio`` (mean / population std of the returns, scaled
            by sqrt(252); 0.0 with fewer than two records or zero spread)
            and ``max_drawdown`` (largest stored drawdown value).  All
            values are zero when no records exist; an empty dict is
            returned if an error occurs (the error is logged).
        """
        from sqlalchemy import select
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=period_days)

                stmt = select(StrategyPerformance).where(
                    StrategyPerformance.strategy_name == strategy_name,
                    StrategyPerformance.timestamp >= cutoff_date
                )

                result = await session.execute(stmt)
                records = result.scalars().all()

                if not records:
                    return {
                        'total_trades': 0,
                        'win_rate': 0.0,
                        'avg_return': 0.0,
                        'total_return': 0.0,
                        'sharpe_ratio': 0.0,
                        'max_drawdown': 0.0
                    }

                returns = [float(r.return_pct) for r in records]
                winning = sum(1 for value in returns if value > 0)

                total_return = sum(returns)
                avg_return = np.mean(returns)
                win_rate = winning / len(returns)

                # Calculate Sharpe ratio (simplified): no risk-free rate, and
                # the spread is computed once and reused.
                sharpe = 0.0
                if len(returns) > 1:
                    spread = np.std(returns)
                    if spread > 0:
                        sharpe = (avg_return / spread) * np.sqrt(252)

                # Max drawdown: the largest value stored on any record.
                max_dd = max(float(r.max_drawdown) for r in records)

                return {
                    'total_trades': len(records),
                    'win_rate': float(win_rate),
                    'avg_return': float(avg_return),
                    'total_return': float(total_return),
                    'sharpe_ratio': float(sharpe),
                    'max_drawdown': float(max_dd)
                }

        except Exception as e:
            self.logger.error(f"Error calculating metrics: {e}")
            return {}

    async def prepare_training_data(
        self,
        min_samples_per_strategy: int = 10
    ) -> Optional[Dict[str, Any]]:
        """Prepare training data for ML model.

        Args:
            min_samples_per_strategy: Minimum samples required per strategy

        Returns:
            None when there is no history, no strategy reaches the minimum
            sample count, or an error occurs.  Otherwise a dictionary with:
            'X' (DataFrame of numeric feature columns, missing or
            non-numeric values replaced by 0.0), 'y' (array of strategy
            names, one per row of 'X'), 'strategy_names' (strategies that
            met the minimum), 'feature_names' (columns of 'X') and
            'training_symbols' (sorted unique symbols in the data).
        """
        try:
            # Get all performance history - extended to 365 days for better training
            history = await self.get_performance_history(days=365)

            if history.empty:
                self.logger.warning("No performance history available for training")
                return None

            # Filter strategies with enough samples
            sample_counts = history['strategy_name'].value_counts()
            valid_strategies = sample_counts[sample_counts >= min_samples_per_strategy].index.tolist()

            if not valid_strategies:
                self.logger.warning(f"No strategies with at least {min_samples_per_strategy} samples")
                return None

            # Work on an independent copy so the column assignments below
            # never write into a slice of the original frame.
            history = history[history['strategy_name'].isin(valid_strategies)].copy()

            # Extract features (exclude metadata columns)
            feature_cols = [
                col for col in history.columns if col not in self._METADATA_COLUMNS
            ]

            # Coerce every feature to a number; anything missing or
            # unparsable becomes 0.0.
            for col in feature_cols:
                history[col] = pd.to_numeric(history[col], errors='coerce').fillna(0.0)

            return {
                'X': history[feature_cols],
                'y': history['strategy_name'].values,
                'strategy_names': valid_strategies,
                'feature_names': feature_cols,
                'training_symbols': sorted(history['symbol'].unique().tolist())
            }

        except Exception as e:
            self.logger.error(f"Error preparing training data: {e}")
            return None

    async def get_strategy_sample_counts(self, days: int = 365) -> Dict[str, int]:
        """Get sample counts per strategy.

        Args:
            days: Number of days to look back

        Returns:
            Dictionary mapping strategy names to sample counts; empty if an
            error occurs (the error is logged).
        """
        from sqlalchemy import select, func
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=days)

                # Group by strategy name and count
                stmt = select(
                    StrategyPerformance.strategy_name,
                    func.count(StrategyPerformance.id)
                ).where(
                    StrategyPerformance.timestamp >= cutoff_date
                ).group_by(
                    StrategyPerformance.strategy_name
                )

                result = await session.execute(stmt)
                # Each result row is a (strategy_name, count) pair.
                return {name: count for name, count in result.all()}

        except Exception as e:
            self.logger.error(f"Error getting strategy sample counts: {e}")
            return {}
|
||
|
|
|
||
|
|
|
||
|
|
# Process-wide tracker, created on the first call to get_performance_tracker().
_performance_tracker: Optional[PerformanceTracker] = None


def get_performance_tracker() -> PerformanceTracker:
    """Return the shared PerformanceTracker, building it on first use."""
    global _performance_tracker
    tracker = _performance_tracker
    if tracker is None:
        tracker = PerformanceTracker()
        _performance_tracker = tracker
    return tracker
|
||
|
|
|