167 lines
6.2 KiB
Python
167 lines
6.2 KiB
Python
|
|
"""Drawdown and loss limits."""
|
||
|
|
|
||
|
|
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, Dict, Iterable

from sqlalchemy import select
from sqlalchemy.orm import Session

from src.core.config import get_config
from src.core.database import get_database, PortfolioSnapshot, Trade
from src.core.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger. NOTE(review): RiskLimitsManager.__init__ creates its own
# self.logger via get_logger(__name__) instead of reusing this one.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class RiskLimitsManager:
    """Manages risk limits (drawdown, daily loss, per-asset allocation).

    Each ``check_*`` coroutine returns ``True`` when the limit is respected and
    ``False`` (after logging a warning) when it is exceeded.

    NOTE(review): the checks fail *open*. When there is no portfolio data, or a
    database query raises, the ``get_*`` helpers log the error and return
    ``Decimal(0)``, so the corresponding check passes. Confirm that this is the
    intended policy for a risk gate.
    """

    # Per-asset allocation cap as a fraction of portfolio value (0.20 == 20%),
    # used by check_portfolio_allocation when the caller does not pass one.
    DEFAULT_MAX_ALLOCATION = Decimal("0.20")

    def __init__(self) -> None:
        """Initialize with the shared database handle, config and a logger."""
        self.db = get_database()
        self.config = get_config()
        self.logger = get_logger(__name__)

    async def check_daily_loss_limit(self) -> bool:
        """Check today's loss against ``risk.daily_loss_limit_percent`` (default 5).

        The loss is the negative part of ``get_daily_pnl()`` divided by the
        latest portfolio value from ``get_portfolio_value()``.

        Returns:
            False if the loss fraction exceeds the limit; True otherwise,
            including when the portfolio value is zero or unavailable.
        """
        daily_loss_limit = self._limit_fraction("risk.daily_loss_limit_percent", 5.0)

        # Normalise to Decimal: ``Decimal / float`` raises TypeError, and this
        # division is not protected by any try block.
        daily_pnl = self._to_decimal(await self.get_daily_pnl())
        portfolio_value = self._to_decimal(await self.get_portfolio_value())

        if portfolio_value <= 0:
            return True

        # Only a losing day counts against the limit.
        daily_loss_percent = abs(daily_pnl) / portfolio_value if daily_pnl < 0 else Decimal(0)

        if daily_loss_percent > daily_loss_limit:
            self.logger.warning(f"Daily loss limit exceeded: {daily_loss_percent:.2%} > {daily_loss_limit:.2%}")
            return False

        return True

    async def check_max_drawdown(self) -> bool:
        """Check the current drawdown against ``risk.max_drawdown_percent`` (default 20).

        Returns:
            False if ``get_current_drawdown()`` exceeds the limit, else True.
        """
        max_drawdown_limit = self._limit_fraction("risk.max_drawdown_percent", 20.0)
        current_drawdown = self._to_decimal(await self.get_current_drawdown())

        if current_drawdown > max_drawdown_limit:
            self.logger.warning(f"Max drawdown limit exceeded: {current_drawdown:.2%} > {max_drawdown_limit:.2%}")
            return False

        return True

    async def check_portfolio_allocation(
        self,
        symbol: str,
        position_value: Decimal,
        max_allocation: Decimal = DEFAULT_MAX_ALLOCATION,
    ) -> bool:
        """Check that a position does not exceed its share of the portfolio.

        Args:
            symbol: Asset symbol; used only in the warning message.
            position_value: Value of the position being checked.
            max_allocation: Maximum allowed fraction of portfolio value
                (0.20 == 20%). Defaults to ``DEFAULT_MAX_ALLOCATION``.

        Returns:
            False if ``position_value / portfolio value`` exceeds
            ``max_allocation``; True otherwise, including when the portfolio
            value is zero or unavailable.
        """
        portfolio_value = self._to_decimal(await self.get_portfolio_value())
        if portfolio_value <= 0:
            return True

        allocation = self._to_decimal(position_value) / portfolio_value

        if allocation > self._to_decimal(max_allocation):
            self.logger.warning(f"Portfolio allocation exceeded for {symbol}: {allocation:.2%}")
            return False

        return True

    async def get_daily_pnl(self, paper_trading: bool = True) -> Decimal:
        """Return the net cash flow of trades since 00:00 UTC today.

        Args:
            paper_trading: Which trades to include, matched against
                ``Trade.paper_trading``. Defaults to True (paper trades only),
                as before.

        Returns:
            Sum over today's trades of sell proceeds minus buy costs, net of
            fees; ``Decimal(0)`` if the query fails (the error is logged).

        NOTE(review): this is cash flow, not realized P&L. A buy subtracts its
        full notional with no cost basis, so buying assets reads as a "loss"
        in check_daily_loss_limit. Confirm this is intended.
        """
        # Naive UTC midnight, matching the previous datetime.utcnow()-based
        # value (utcnow is deprecated since Python 3.12).
        today = datetime.now(timezone.utc).date()
        start_of_day = datetime.combine(today, datetime.min.time())

        async with self.db.get_session() as session:
            try:
                stmt = select(Trade).where(
                    Trade.timestamp >= start_of_day,
                    Trade.paper_trading == paper_trading,
                )
                result = await session.execute(stmt)
                return self._trade_cash_flow(result.scalars().all())
            except Exception as e:
                self.logger.error(f"Error calculating daily P&L: {e}")
                return Decimal(0)

    async def get_current_drawdown(self) -> Decimal:
        """Return the current drawdown as a fraction of the all-time peak.

        The peak is the snapshot with the largest non-NULL ``total_value``.
        "Current" is the most recent snapshot by ``timestamp``. Returns
        ``Decimal(0)`` when there are no snapshots, when the latest value is
        at or above the peak, or when the query fails (the error is logged).
        """
        async with self.db.get_session() as session:
            try:
                # Exclude NULLs: some backends (e.g. PostgreSQL) sort NULLs
                # first under DESC, which would hide the real peak.
                stmt_peak = (
                    select(PortfolioSnapshot)
                    .where(PortfolioSnapshot.total_value.is_not(None))
                    .order_by(PortfolioSnapshot.total_value.desc())
                    .limit(1)
                )
                peak = (await session.execute(stmt_peak)).scalar_one_or_none()
                if peak is None:
                    return Decimal(0)

                stmt_current = (
                    select(PortfolioSnapshot)
                    .order_by(PortfolioSnapshot.timestamp.desc())
                    .limit(1)
                )
                current = (await session.execute(stmt_current)).scalar_one_or_none()
                if current is None:
                    return Decimal(0)

                return self._drawdown_fraction(peak.total_value, current.total_value)
            except Exception as e:
                self.logger.error(f"Error calculating drawdown: {e}")
                return Decimal(0)

    async def get_portfolio_value(self) -> Decimal:
        """Return ``total_value`` of the most recent portfolio snapshot as a Decimal.

        Returns ``Decimal(0)`` when there is no snapshot, when its value is
        NULL, or when the query fails (the error is logged).
        """
        async with self.db.get_session() as session:
            try:
                stmt = (
                    select(PortfolioSnapshot)
                    .order_by(PortfolioSnapshot.timestamp.desc())
                    .limit(1)
                )
                latest = (await session.execute(stmt)).scalar_one_or_none()
                if latest is None:
                    return Decimal(0)
                return self._to_decimal(latest.total_value)
            except Exception as e:
                self.logger.error(f"Error getting portfolio value: {e}")
                return Decimal(0)

    def get_all_limits(self) -> Dict[str, Any]:
        """Get all risk limits configuration.

        Returns:
            Dictionary of limit configurations, as raw config values in
            percent, falling back to the same defaults the checks use.
        """
        return {
            'max_drawdown_percent': self.config.get("risk.max_drawdown_percent", 20.0),
            'daily_loss_limit_percent': self.config.get("risk.daily_loss_limit_percent", 5.0),
            'position_size_percent': self.config.get("risk.position_size_percent", 2.0),
        }

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Coerce a numeric value to Decimal; ``None`` becomes ``Decimal(0)``.

        Non-Decimal values go through ``str`` so that a float such as ``0.1``
        becomes ``Decimal("0.1")`` rather than its exact binary expansion.
        """
        if value is None:
            return Decimal(0)
        if isinstance(value, Decimal):
            return value
        return Decimal(str(value))

    def _limit_fraction(self, key: str, default_percent: float) -> Decimal:
        """Read a percentage limit from config and return it as a fraction (5.0 -> 0.05)."""
        return Decimal(str(self.config.get(key, default_percent))) / 100

    @classmethod
    def _trade_cash_flow(cls, trades: Iterable[Any]) -> Decimal:
        """Sum signed trade cash flow: sells add proceeds, buys subtract cost, fees always subtract.

        Any side whose ``.value`` is not ``"sell"`` is treated as a buy.
        """
        total = Decimal(0)
        for trade in trades:
            notional = cls._to_decimal(trade.quantity) * cls._to_decimal(trade.price)
            fee = cls._to_decimal(trade.fee)  # missing/None fee counts as 0
            if trade.side.value == "sell":
                total += notional - fee
            else:
                total -= notional + fee
        return total

    @classmethod
    def _drawdown_fraction(cls, peak_value: Any, current_value: Any) -> Decimal:
        """Return ``(peak - current) / peak``, or 0 when there is no drawdown to report."""
        if peak_value is None or current_value is None:
            return Decimal(0)
        peak = cls._to_decimal(peak_value)
        current = cls._to_decimal(current_value)
        # A non-positive peak would divide by zero or flip the sign.
        if peak <= 0 or current >= peak:
            return Decimal(0)
        return (peak - current) / peak
|
||
|
|
|