"""Base strategy class and strategy registry system."""

from abc import ABC, abstractmethod
from datetime import datetime
from decimal import Decimal, InvalidOperation
from enum import Enum
from typing import Dict, Optional, List, Any, Callable

import pandas as pd

from src.core.logger import get_logger
from src.core.database import OrderSide, OrderType

logger = get_logger(__name__)


def _to_decimal(value: Any) -> Optional[Decimal]:
    """Convert a numeric-like value to a finite Decimal.

    Floats are converted through their string form so that binary
    representation noise does not leak into Decimal arithmetic.

    Args:
        value: Decimal, int, float or numeric string

    Returns:
        Finite Decimal, or None if the value cannot be converted
    """
    if isinstance(value, Decimal):
        result = value
    else:
        try:
            result = Decimal(str(value))
        except (InvalidOperation, ValueError, TypeError):
            return None
    # NaN/Infinity would make later ordering comparisons raise or misbehave.
    return result if result.is_finite() else None


class SignalType(str, Enum):
    """Trading signal types."""

    BUY = "buy"
    SELL = "sell"
    HOLD = "hold"
    CLOSE = "close"


class StrategySignal:
    """Trading signal from strategy."""

    def __init__(
        self,
        signal_type: SignalType,
        symbol: str,
        strength: float = 1.0,
        price: Optional[Decimal] = None,
        quantity: Optional[Decimal] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Initialize strategy signal.

        Args:
            signal_type: Signal type (buy, sell, hold, close)
            symbol: Trading symbol
            strength: Signal strength (0.0 to 1.0)
            price: Suggested price
            quantity: Suggested quantity
            metadata: Additional metadata
        """
        self.signal_type = signal_type
        self.symbol = symbol
        self.strength = strength
        self.price = price
        self.quantity = quantity
        self.metadata = metadata or {}
        # Naive datetime expressed in UTC (no tzinfo attached).
        self.timestamp = datetime.utcnow()


class BaseStrategy(ABC):
    """Base class for all trading strategies."""

    def __init__(
        self,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ):
        """Initialize strategy.

        Strategies start disabled; `should_execute` rejects every signal
        until `enabled` is set to True.

        Args:
            name: Strategy name
            parameters: Strategy parameters
            timeframes: List of timeframes (e.g., ['1h', '15m'])
        """
        self.name = name
        self.parameters = parameters or {}
        self.timeframes = timeframes or ['1h']
        self.enabled = False
        self.logger = get_logger(f"strategy.{name}")
        self._data_cache: Dict[str, Any] = {}

    @abstractmethod
    async def on_tick(
        self,
        symbol: str,
        price: Decimal,
        timeframe: str,
        data: Dict[str, Any]
    ) -> Optional[StrategySignal]:
        """Called on each price update.

        Args:
            symbol: Trading symbol
            price: Current price
            timeframe: Timeframe of the update
            data: Additional market data

        Returns:
            StrategySignal or None
        """
        pass

    @abstractmethod
    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process and potentially modify signal.

        Args:
            signal: Generated signal

        Returns:
            Modified signal or None to cancel
        """
        pass

    def calculate_position_size(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> Decimal:
        """Calculate position size for signal, accounting for fees.

        The position value is `position_size_percent` (default 2%) of the
        balance, reduced by the fee reserve reported by the fee calculator.

        Args:
            signal: Trading signal
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            Position size (never negative; zero when price is not positive)
        """
        # Default: use 2% of balance. Convert to Decimal *before* dividing so
        # float division artefacts (e.g. 0.7 / 100.0) never reach the money
        # math, and so Decimal-valued parameters are accepted as well.
        percent = self.parameters.get('position_size_percent', 2.0)
        risk_fraction = Decimal(str(percent)) / Decimal(100)
        position_value = balance * risk_fraction

        # Account for fees by reserving fee amount
        # (imported here rather than at module level, as in the original code)
        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Reserve ~0.4% for round-trip fees (conservative estimate)
        fee_reserve = fee_calculator.calculate_fee_reserve(
            position_value=position_value,
            exchange_adapter=exchange_adapter,
            reserve_percent=0.004  # 0.4% for round-trip
        )

        # Adjust position value to account for fees
        adjusted_position_value = position_value - fee_reserve

        # Calculate quantity
        if price > 0:
            quantity = adjusted_position_value / price
            return max(Decimal(0), quantity)  # Ensure non-negative

        return Decimal(0)

    def should_execute(self, signal: StrategySignal) -> bool:
        """Check if signal should be executed.

        Args:
            signal: Trading signal

        Returns:
            True if the strategy is enabled and the signal strength reaches
            the `min_signal_strength` parameter (default 0.5)
        """
        if not self.enabled:
            return False

        # Check signal strength threshold
        min_strength = self.parameters.get('min_signal_strength', 0.5)
        return not signal.strength < min_strength

    def should_execute_with_fees(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> bool:
        """Check if signal should be executed considering fees and minimum profit threshold.

        The profit check only runs when the signal metadata carries a truthy
        `target_price`; without one the signal passes once the basic checks
        and the position-size check succeed. Strategies can override this
        method for more sophisticated checks.

        Args:
            signal: Trading signal
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            True if should execute after fee consideration
        """
        # First check basic execution criteria
        if not self.should_execute(signal):
            return False

        # Calculate position size (a falsy signal quantity falls back to the
        # strategy's own sizing)
        quantity = signal.quantity or self.calculate_position_size(
            signal, balance, price, exchange_adapter
        )

        if quantity <= 0:
            return False

        # Check minimum profit threshold
        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Get minimum profit multiplier from strategy parameters (default 2.0)
        min_profit_multiplier = self.parameters.get('min_profit_multiplier', 2.0)

        min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
            quantity=quantity,
            price=price,
            exchange_adapter=exchange_adapter,
            multiplier=min_profit_multiplier
        )

        # If we have a target price in signal metadata, use it
        raw_target = signal.metadata.get('target_price')
        if not raw_target:
            return True

        # Metadata is free-form, so the target is frequently a float. Mixing
        # float and Decimal in arithmetic raises TypeError, hence everything
        # that enters the profit estimate is normalised to Decimal first.
        target_price = _to_decimal(raw_target)
        current_price = _to_decimal(price)
        size = _to_decimal(quantity)
        if target_price is None or current_price is None or size is None:
            self.logger.warning(
                f"Cannot evaluate profit threshold for {signal.symbol}: "
                f"non-numeric value (target_price={raw_target!r}, "
                f"price={price!r}, quantity={quantity!r}), skipping check"
            )
            return True

        # SignalType is a str-based enum, so this comparison also holds for a
        # plain "buy" string. Every non-buy type is treated as a sell.
        if signal.signal_type == SignalType.BUY:
            potential_profit = (target_price - current_price) * size
        else:  # sell
            potential_profit = (current_price - target_price) * size

        if potential_profit < min_profit_threshold:
            self.logger.debug(
                f"Signal filtered: potential profit {potential_profit} < "
                f"minimum threshold {min_profit_threshold}"
            )
            return False

        return True

    def apply_trend_filter(
        self,
        signal: StrategySignal,
        ohlcv_data: Any,
        adx_period: int = 14,
        min_adx: float = 25.0
    ) -> Optional[StrategySignal]:
        """Apply ADX-based trend filter to signal.

        Filters signals based on trend strength:
        - Any signal is dropped when ADX < threshold (choppy/ranging market)
        - BUY signals are dropped when the last close is below the 20-period SMA
        - SELL signals are dropped when the last close is above the 20-period SMA

        The filter is best-effort: when it is disabled (`use_trend_filter`
        parameter falsy), when the data is unusable, or when any error occurs
        during evaluation, the signal is returned unchanged.

        Args:
            signal: Trading signal to filter
            ohlcv_data: OHLCV DataFrame with columns: high, low, close
            adx_period: ADX calculation period (default 14)
            min_adx: Minimum ADX value for signal (default 25.0)

        Returns:
            Filtered signal or None if filtered out
        """
        if not self.parameters.get('use_trend_filter', False):
            return signal

        try:
            from src.data.indicators import get_indicators

            # Type check first: calling len() on an arbitrary object could
            # raise and end up in the except branch with a misleading warning.
            if not isinstance(ohlcv_data, pd.DataFrame):
                return signal

            if len(ohlcv_data) < adx_period:
                # Not enough data, allow signal
                return signal

            indicators = get_indicators()

            # Calculate ADX
            high = ohlcv_data['high']
            low = ohlcv_data['low']
            close = ohlcv_data['close']

            adx = indicators.adx(high, low, close, period=adx_period)
            latest_adx = adx.iloc[-1]
            # A missing ADX value is treated as "no trend".
            current_adx = 0.0 if pd.isna(latest_adx) else latest_adx

            # Check if trend is strong enough
            if current_adx < min_adx:
                # Weak trend - filter out signal
                self.logger.debug(
                    f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
                    f"filtering {signal.signal_type.value} signal"
                )
                return None

            # Additional check: for BUY signals, ensure uptrend
            # For SELL signals, ensure downtrend
            # We can use price vs moving average to determine trend direction
            if len(close) >= 20:
                sma_20 = indicators.sma(close, period=20)
                current_price = close.iloc[-1]
                latest_sma = sma_20.iloc[-1]
                # A missing SMA falls back to the price, which disables the
                # direction check (neither strict comparison below can hold).
                sma_value = current_price if pd.isna(latest_sma) else latest_sma

                if signal.signal_type == SignalType.BUY:
                    # BUY only in uptrend (price above SMA)
                    if current_price < sma_value:
                        self.logger.debug(
                            f"Trend filter: BUY signal filtered - price below SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None
                elif signal.signal_type == SignalType.SELL:
                    # SELL only in downtrend (price below SMA)
                    if current_price > sma_value:
                        self.logger.debug(
                            f"Trend filter: SELL signal filtered - price above SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None

            return signal

        except Exception as e:
            # Deliberate best-effort: a broken filter must not block trading.
            self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
            return signal

    def get_required_indicators(self) -> List[str]:
        """Get list of required indicators.

        Returns:
            List of indicator names (empty in the base class)
        """
        return []

    def validate_parameters(self) -> bool:
        """Validate strategy parameters.

        Returns:
            True if parameters are valid (the base class accepts everything)
        """
        return True

    def get_state(self) -> Dict[str, Any]:
        """Get strategy state for persistence.

        Returns:
            State dictionary with name, parameters, timeframes and enabled flag
        """
        return {
            'name': self.name,
            'parameters': self.parameters,
            'timeframes': self.timeframes,
            'enabled': self.enabled,
        }

    def set_state(self, state: Dict[str, Any]):
        """Restore strategy state.

        Missing or empty `parameters` / `timeframes` entries fall back to the
        same defaults as the constructor. The strategy name is not restored.

        Args:
            state: State dictionary
        """
        # `or` instead of a .get() default: an explicit None in the state
        # would otherwise be stored and break later self.parameters.get().
        self.parameters = state.get('parameters') or {}
        self.timeframes = state.get('timeframes') or ['1h']
        self.enabled = state.get('enabled', False)


class StrategyRegistry:
    """Registry for managing strategies."""

    def __init__(self):
        """Initialize strategy registry."""
        # Registered strategy classes, keyed by lower-cased name.
        self._strategies: Dict[str, type] = {}
        # Created strategy instances, keyed by strategy ID.
        self._instances: Dict[int, BaseStrategy] = {}
        self.logger = get_logger(__name__)

    def register(self, name: str, strategy_class: type):
        """Register a strategy class.

        Names are stored lower-cased, so lookups are case-insensitive.

        Args:
            name: Strategy name
            strategy_class: Strategy class (subclass of BaseStrategy)

        Raises:
            ValueError: If strategy_class is not a subclass of BaseStrategy
        """
        # issubclass() itself raises TypeError for non-class arguments; check
        # that first so every invalid input yields the same ValueError.
        is_strategy = isinstance(strategy_class, type) and issubclass(
            strategy_class, BaseStrategy
        )
        if not is_strategy:
            raise ValueError("Strategy class must inherit from BaseStrategy")

        self._strategies[name.lower()] = strategy_class
        self.logger.info(f"Registered strategy: {name}")

    def create_instance(
        self,
        strategy_id: int,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ) -> Optional[BaseStrategy]:
        """Create strategy instance.

        The new instance is stored under `strategy_id`, replacing any
        instance previously stored under the same ID.

        Args:
            strategy_id: Strategy ID from src.database
            name: Strategy name
            parameters: Strategy parameters
            timeframes: List of timeframes

        Returns:
            Strategy instance, or None if the name is not registered or the
            constructor raised
        """
        strategy_class = self._strategies.get(name.lower())
        if not strategy_class:
            self.logger.error(f"Strategy {name} not registered")
            return None

        try:
            instance = strategy_class(name, parameters, timeframes)
        except Exception as e:
            self.logger.error(f"Failed to create strategy instance: {e}")
            return None

        self._instances[strategy_id] = instance
        return instance

    def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
        """Get strategy instance by ID.

        Args:
            strategy_id: Strategy ID

        Returns:
            Strategy instance or None
        """
        return self._instances.get(strategy_id)

    def list_available(self) -> List[str]:
        """List available strategy types.

        Returns:
            List of strategy names (lower-cased, as registered)
        """
        return list(self._strategies)

    def unregister(self, name: str):
        """Unregister a strategy.

        Unknown names are ignored. Instances already created from the
        strategy class are left in place.

        Args:
            name: Strategy name
        """
        key = name.lower()
        if key in self._strategies:
            del self._strategies[key]
            self.logger.info(f"Unregistered strategy: {name}")


# Global strategy registry
_registry: Optional[StrategyRegistry] = None


def get_strategy_registry() -> StrategyRegistry:
    """Get global strategy registry instance.

    The registry is created on first call and reused afterwards.
    """
    global _registry
    if _registry is None:
        _registry = StrategyRegistry()
    return _registry