"""Drawdown and loss limits."""

from decimal import Decimal
from datetime import datetime, timedelta, timezone
from typing import Dict, Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from src.core.database import get_database, PortfolioSnapshot, Trade
from src.core.config import get_config
from src.core.logger import get_logger

logger = get_logger(__name__)

# Per-asset allocation cap (percent of portfolio) used when
# "risk.max_allocation_percent" is not configured; matches the former
# hard-coded 20% limit.
DEFAULT_MAX_ALLOCATION_PERCENT = 20.0


class RiskLimitsManager:
    """Manages risk limits (drawdown, daily loss, per-asset allocation).

    Each ``check_*`` coroutine returns ``True`` when the limit is respected
    and ``False`` (after logging a warning) when it is breached.

    NOTE(review): the ``get_*`` helpers log and return ``Decimal(0)`` on any
    error, so a failing database makes every check pass (fail-open).
    Confirm this is the intended policy for a risk gate.
    """

    def __init__(self):
        """Initialize the manager with database, config and logger handles."""
        self.db = get_database()
        self.config = get_config()
        self.logger = get_logger(__name__)

    def _limit_fraction(self, key: str, default_percent: float) -> Decimal:
        """Read a percentage limit from config and return it as a fraction.

        The value goes through ``str()`` so float config entries such as
        ``5.0`` become an exact ``Decimal("5.0")`` instead of a binary float
        expansion.
        """
        return Decimal(str(self.config.get(key, default_percent))) / 100

    async def check_daily_loss_limit(self) -> bool:
        """Check whether today's loss stays within the configured limit.

        The limit is ``risk.daily_loss_limit_percent`` (default 5.0).

        Returns:
            ``False`` if today's loss, as a fraction of the current portfolio
            value, exceeds the limit. ``True`` otherwise, including when the
            portfolio value is zero and no ratio can be formed.
        """
        daily_loss_limit = self._limit_fraction("risk.daily_loss_limit_percent", 5.0)

        daily_pnl = await self.get_daily_pnl()
        portfolio_value = await self.get_portfolio_value()

        if portfolio_value == 0:
            return True

        # Only losses count against the limit; a positive P&L maps to 0%.
        daily_loss_percent = abs(daily_pnl) / portfolio_value if daily_pnl < 0 else Decimal(0)

        if daily_loss_percent > daily_loss_limit:
            self.logger.warning(f"Daily loss limit exceeded: {daily_loss_percent:.2%} > {daily_loss_limit:.2%}")
            return False

        return True

    async def check_max_drawdown(self) -> bool:
        """Check whether the current drawdown stays within the configured limit.

        The limit is ``risk.max_drawdown_percent`` (default 20.0).

        Returns:
            ``False`` if the drawdown exceeds the limit, ``True`` otherwise.
        """
        max_drawdown_limit = self._limit_fraction("risk.max_drawdown_percent", 20.0)

        current_drawdown = await self.get_current_drawdown()

        if current_drawdown > max_drawdown_limit:
            self.logger.warning(f"Max drawdown limit exceeded: {current_drawdown:.2%} > {max_drawdown_limit:.2%}")
            return False

        return True

    async def check_portfolio_allocation(self, symbol: str, position_value: Decimal) -> bool:
        """Check that a single position does not exceed the per-asset allocation cap.

        The cap is ``risk.max_allocation_percent`` (default 20.0). It is
        config-driven like the other limits instead of being hard-coded.

        Args:
            symbol: Asset symbol; used only in the warning message.
            position_value: Value of the position to check.

        Returns:
            ``False`` if ``position_value / portfolio value`` exceeds the cap.
            ``True`` otherwise, including when the portfolio value is zero.
        """
        max_allocation_percent = self._limit_fraction(
            "risk.max_allocation_percent", DEFAULT_MAX_ALLOCATION_PERCENT
        )

        portfolio_value = await self.get_portfolio_value()

        if portfolio_value == 0:
            return True

        allocation = position_value / portfolio_value

        if allocation > max_allocation_percent:
            self.logger.warning(f"Portfolio allocation exceeded for {symbol}: {allocation:.2%}")
            return False

        return True

    async def get_daily_pnl(self) -> Decimal:
        """Get today's (UTC) mark-to-market P&L from portfolio snapshots.

        P&L is the latest snapshot's ``total_value`` minus a start-of-day
        baseline. The baseline is the last snapshot taken before 00:00 UTC
        today or, if none exists, the first snapshot taken today.

        Trade cash flow (sells minus buys) is deliberately not used. It would
        count every newly opened position as a loss of its full value, and
        every close of an older position as a gain of its full value.

        NOTE(review): assumes snapshot timestamps are naive UTC, like the
        ``datetime.utcnow()`` comparison this replaced. Also assumes that
        deposits and withdrawals are not reflected in ``total_value``;
        otherwise they would appear as P&L. Confirm both.

        Returns:
            Today's P&L, or ``Decimal(0)`` if there are no snapshots or the
            query fails.
        """
        async with self.db.get_session() as session:
            try:
                # Naive UTC "now", equivalent to the deprecated datetime.utcnow().
                now = datetime.now(timezone.utc).replace(tzinfo=None)
                start_of_day = datetime.combine(now.date(), datetime.min.time())

                stmt_latest = select(PortfolioSnapshot).order_by(
                    PortfolioSnapshot.timestamp.desc()
                ).limit(1)
                latest = (await session.execute(stmt_latest)).scalar_one_or_none()
                if latest is None:
                    return Decimal(0)

                # Baseline: value at the close of the previous day.
                stmt_prior = (
                    select(PortfolioSnapshot)
                    .where(PortfolioSnapshot.timestamp < start_of_day)
                    .order_by(PortfolioSnapshot.timestamp.desc())
                    .limit(1)
                )
                baseline = (await session.execute(stmt_prior)).scalar_one_or_none()

                if baseline is None:
                    # No history before today: measure from the first snapshot of today.
                    stmt_first = (
                        select(PortfolioSnapshot)
                        .where(PortfolioSnapshot.timestamp >= start_of_day)
                        .order_by(PortfolioSnapshot.timestamp.asc())
                        .limit(1)
                    )
                    baseline = (await session.execute(stmt_first)).scalar_one_or_none()

                if baseline is None:
                    return Decimal(0)

                return latest.total_value - baseline.total_value

            except Exception as e:
                self.logger.error(f"Error calculating daily P&L: {e}")
                return Decimal(0)

    async def get_current_drawdown(self) -> Decimal:
        """Get the current drawdown as a fraction of the all-time peak.

        Drawdown is ``(peak - current) / peak``. Here ``peak`` is the highest
        ``total_value`` across all snapshots and ``current`` is the most
        recent snapshot.

        Returns:
            The drawdown fraction, or ``Decimal(0)`` in any of these cases:
            - there are no snapshots;
            - current value is at or above the peak;
            - the peak is non-positive;
            - the query fails.
        """
        async with self.db.get_session() as session:
            try:
                # Get peak portfolio value
                stmt_peak = select(PortfolioSnapshot).order_by(
                    PortfolioSnapshot.total_value.desc()
                ).limit(1)
                result_peak = await session.execute(stmt_peak)
                peak = result_peak.scalar_one_or_none()

                # A non-positive peak cannot be a divisor for a meaningful ratio.
                if not peak or peak.total_value <= 0:
                    return Decimal(0)

                # Get current value
                stmt_current = select(PortfolioSnapshot).order_by(
                    PortfolioSnapshot.timestamp.desc()
                ).limit(1)
                result_current = await session.execute(stmt_current)
                current = result_current.scalar_one_or_none()

                if not current or current.total_value >= peak.total_value:
                    return Decimal(0)

                return (peak.total_value - current.total_value) / peak.total_value

            except Exception as e:
                self.logger.error(f"Error calculating drawdown: {e}")
                return Decimal(0)

    async def get_portfolio_value(self) -> Decimal:
        """Get the ``total_value`` of the most recent portfolio snapshot.

        Returns:
            The latest value, or ``Decimal(0)`` if there are no snapshots or
            the query fails.
        """
        async with self.db.get_session() as session:
            try:
                stmt = select(PortfolioSnapshot).order_by(
                    PortfolioSnapshot.timestamp.desc()
                ).limit(1)
                result = await session.execute(stmt)
                latest = result.scalar_one_or_none()

                if latest:
                    return latest.total_value
                return Decimal(0)

            except Exception as e:
                self.logger.error(f"Error getting portfolio value: {e}")
                return Decimal(0)

    def get_all_limits(self) -> Dict[str, Any]:
        """Get all risk limits configuration.

        Values are the raw configured percentages, or the defaults used by
        the checks when a key is not configured.

        Returns:
            Dictionary of limit configurations
        """
        return {
            'max_drawdown_percent': self.config.get("risk.max_drawdown_percent", 20.0),
            'daily_loss_limit_percent': self.config.get("risk.daily_loss_limit_percent", 5.0),
            'position_size_percent': self.config.get("risk.position_size_percent", 2.0),
            'max_allocation_percent': self.config.get(
                "risk.max_allocation_percent", DEFAULT_MAX_ALLOCATION_PERCENT
            ),
        }