451 lines
14 KiB
Python
451 lines
14 KiB
Python
|
|
"""Base strategy class and strategy registry system."""
|
||
|
|
|
||
|
|
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import pandas as pd

from src.core.database import OrderSide, OrderType
from src.core.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger. The classes below log through their own loggers
# (BaseStrategy per strategy name, StrategyRegistry via get_logger(__name__)).
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class SignalType(str, Enum):
    """Enumeration of the signal kinds a strategy can emit.

    The class also derives from ``str``, so each member compares equal to its
    lower-case string value (``SignalType.BUY == "buy"``) and can be rebuilt
    from it (``SignalType("buy")``).
    """

    BUY = "buy"
    SELL = "sell"
    HOLD = "hold"
    CLOSE = "close"
|
||
|
|
|
||
|
|
|
||
|
|
class StrategySignal:
    """Trading signal from strategy.

    A plain value object: it stores what the strategy suggests and stamps the
    creation time. Field values are not validated here.
    """

    def __init__(
        self,
        signal_type: SignalType,
        symbol: str,
        strength: float = 1.0,
        price: Optional[Decimal] = None,
        quantity: Optional[Decimal] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Initialize strategy signal.

        Args:
            signal_type: Signal type (buy, sell, hold, close)
            symbol: Trading symbol
            strength: Signal strength, nominally 0.0 to 1.0 (not enforced here)
            price: Suggested price
            quantity: Suggested quantity
            metadata: Additional metadata; ``None`` or an empty mapping is
                replaced by a fresh dict
        """
        self.signal_type = signal_type
        self.symbol = symbol
        self.strength = strength
        self.price = price
        self.quantity = quantity
        self.metadata = metadata or {}
        # Naive UTC timestamp -- the same value datetime.utcnow() produced.
        # utcnow() is deprecated since Python 3.12, so take an aware UTC "now"
        # and drop tzinfo to stay comparable with existing naive timestamps.
        self.timestamp = datetime.now(timezone.utc).replace(tzinfo=None)

    def __repr__(self) -> str:
        """Return a debug representation listing the signal's fields."""
        return (
            f"{type(self).__name__}(signal_type={self.signal_type!r}, "
            f"symbol={self.symbol!r}, strength={self.strength!r}, "
            f"price={self.price!r}, quantity={self.quantity!r}, "
            f"timestamp={self.timestamp.isoformat()})"
        )
|
||
|
|
|
||
|
|
|
||
|
|
class BaseStrategy(ABC):
    """Base class for all trading strategies.

    Subclasses must implement :meth:`on_tick` (turn a market update into an
    optional signal) and :meth:`on_signal` (adjust or veto a signal). The other
    methods are overridable defaults for position sizing, execution gating,
    ADX/SMA trend filtering and state persistence, configured through
    ``self.parameters``.
    """

    def __init__(
        self,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ):
        """Initialize strategy.

        Args:
            name: Strategy name; also used for the logger name ``strategy.<name>``
            parameters: Strategy parameters (``None`` or empty -> new empty dict)
            timeframes: List of timeframes (e.g., ['1h', '15m']); defaults to ['1h']
        """
        self.name = name
        self.parameters = parameters or {}
        self.timeframes = timeframes or ['1h']
        # New strategies start disabled: should_execute() rejects every signal
        # until ``enabled`` is set to True (directly or via set_state()).
        self.enabled = False
        self.logger = get_logger(f"strategy.{name}")
        self._data_cache: Dict[str, Any] = {}

    @abstractmethod
    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Called on each price update.

        Args:
            symbol: Trading symbol
            price: Current price
            timeframe: Timeframe of the update
            data: Additional market data

        Returns:
            StrategySignal or None
        """
        pass

    @abstractmethod
    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process and potentially modify signal.

        Args:
            signal: Generated signal

        Returns:
            Modified signal or None to cancel
        """
        pass

    def calculate_position_size(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> Decimal:
        """Calculate position size for signal, accounting for fees.

        The position value is ``position_size_percent`` (default 2) percent of
        ``balance``, minus a fee reserve obtained from the fee calculator; the
        result is divided by ``price``.

        Args:
            signal: Trading signal (unused by this default implementation)
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            Position size (never negative); ``Decimal(0)`` when price <= 0
        """
        # A non-positive price cannot be sized.
        if price <= 0:
            return Decimal(0)

        # Fraction of the balance to commit (default 2%). Going through str()
        # keeps the Decimal exact and accepts int/float/str/Decimal parameter
        # values alike: dividing a Decimal parameter by a float raises
        # TypeError, and float division adds noise (1.1 / 100 -> 0.011000000000000001).
        percent = Decimal(str(self.parameters.get('position_size_percent', 2.0)))
        position_value = balance * percent / Decimal(100)

        # Imported lazily (presumably to avoid an import cycle with the
        # trading package).
        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Reserve ~0.4% for round-trip fees (conservative estimate)
        fee_reserve = fee_calculator.calculate_fee_reserve(
            position_value=position_value,
            exchange_adapter=exchange_adapter,
            reserve_percent=0.004  # 0.4% for round-trip
        )

        quantity = (position_value - fee_reserve) / price
        # The reserve can exceed a tiny allocation; never return a negative size.
        return max(Decimal(0), quantity)

    def should_execute(self, signal: StrategySignal) -> bool:
        """Check if signal should be executed.

        A signal passes when the strategy is enabled and its strength is not
        below ``min_signal_strength`` (default 0.5).

        Args:
            signal: Trading signal

        Returns:
            True if should execute
        """
        if not self.enabled:
            return False

        # Check signal strength threshold
        min_strength = self.parameters.get('min_signal_strength', 0.5)
        if signal.strength < min_strength:
            return False

        return True

    def should_execute_with_fees(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> bool:
        """Check if signal should be executed considering fees and minimum profit threshold.

        On top of :meth:`should_execute`, the quantity must be positive and, for
        BUY/SELL signals carrying a ``target_price`` in their metadata, the
        profit at that target must reach the fee calculator's minimum profit
        threshold (scaled by ``min_profit_multiplier``, default 2.0). Strategies
        can override this method for more sophisticated checks.

        Args:
            signal: Trading signal
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            True if should execute after fee consideration
        """
        if not self.should_execute(signal):
            return False

        # An explicit quantity on the signal wins; the sizing model is only a
        # fallback when none was supplied. (Testing with ``or`` would treat an
        # explicit zero as "missing" and replace it with a computed size.)
        if signal.quantity is not None:
            quantity = signal.quantity
        else:
            quantity = self.calculate_position_size(signal, balance, price, exchange_adapter)

        if quantity <= 0:
            return False

        # The profit check needs a target price and only makes sense for
        # directional entries; HOLD/CLOSE signals (e.g. exits) must never be
        # blocked by it.
        target_price = signal.metadata.get('target_price')
        if not target_price or signal.signal_type not in (SignalType.BUY, SignalType.SELL):
            return True

        # Metadata may carry a float/int/str target; Decimal arithmetic with a
        # float raises TypeError, so normalise through str() first.
        if isinstance(target_price, Decimal):
            target = target_price
        else:
            target = Decimal(str(target_price))

        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Minimum acceptable profit from the fee calculator, scaled by the
        # strategy's min_profit_multiplier (default 2.0).
        min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
            quantity=quantity,
            price=price,
            exchange_adapter=exchange_adapter,
            multiplier=self.parameters.get('min_profit_multiplier', 2.0)
        )

        if signal.signal_type == SignalType.BUY:
            potential_profit = (target - price) * quantity
        else:  # SignalType.SELL
            potential_profit = (price - target) * quantity

        if potential_profit < min_profit_threshold:
            self.logger.debug(
                f"Signal filtered: potential profit {potential_profit} < "
                f"minimum threshold {min_profit_threshold}"
            )
            return False

        return True

    def apply_trend_filter(
        self,
        signal: StrategySignal,
        ohlcv_data: Any,
        adx_period: int = 14,
        min_adx: float = 25.0
    ) -> Optional[StrategySignal]:
        """Apply ADX-based trend filter to signal.

        Active only when the ``use_trend_filter`` parameter is truthy.
        Filters signals based on trend strength:
        - Any signal is dropped when the latest ADX is below ``min_adx``
          (choppy/ranging market)
        - BUY signals are dropped when the last close is below its 20-period SMA
        - SELL signals are dropped when the last close is above its 20-period SMA
        Insufficient data (not a DataFrame, fewer than ``adx_period`` rows, or
        a NaN latest ADX) and any error while filtering let the signal through.

        Args:
            signal: Trading signal to filter
            ohlcv_data: OHLCV DataFrame with columns: high, low, close
            adx_period: ADX calculation period (default 14)
            min_adx: Minimum ADX value for signal (default 25.0)

        Returns:
            Filtered signal or None if filtered out
        """
        if not self.parameters.get('use_trend_filter', False):
            return signal

        try:
            from src.data.indicators import get_indicators

            # Only a DataFrame with at least adx_period rows can be evaluated;
            # anything else counts as "not enough data" and the signal passes.
            # The type check comes first so objects without len() never reach it.
            if not isinstance(ohlcv_data, pd.DataFrame) or len(ohlcv_data) < adx_period:
                return signal

            indicators = get_indicators()
            close = ohlcv_data['close']

            adx = indicators.adx(ohlcv_data['high'], ohlcv_data['low'], close, period=adx_period)
            current_adx = adx.iloc[-1]
            if pd.isna(current_adx):
                # No ADX value for the latest bar yet (presumably too little
                # history for the indicator's smoothing). That is the same
                # "not enough data" case as above, not evidence of a weak trend,
                # so the signal passes instead of being treated as ADX 0.
                return signal

            # Weak trend - filter out signal
            if current_adx < min_adx:
                self.logger.debug(
                    f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
                    f"filtering {signal.signal_type.value} signal"
                )
                return None

            # Direction check against the 20-period SMA: BUY only at/above it,
            # SELL only at/below it. Skipped when fewer than 20 closes exist.
            if len(close) >= 20:
                sma_20 = indicators.sma(close, period=20)
                current_price = close.iloc[-1]
                sma_value = sma_20.iloc[-1] if not pd.isna(sma_20.iloc[-1]) else current_price

                if signal.signal_type == SignalType.BUY:
                    if current_price < sma_value:
                        self.logger.debug(
                            f"Trend filter: BUY signal filtered - price below SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None
                elif signal.signal_type == SignalType.SELL:
                    if current_price > sma_value:
                        self.logger.debug(
                            f"Trend filter: SELL signal filtered - price above SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None

            return signal

        except Exception as e:
            # Best-effort filter: any failure (missing module, bad columns, ...)
            # lets the signal through rather than blocking trading.
            self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
            return signal

    def get_required_indicators(self) -> List[str]:
        """Get list of required indicators.

        Returns:
            List of indicator names (empty by default; subclasses override)
        """
        return []

    def validate_parameters(self) -> bool:
        """Validate strategy parameters.

        Returns:
            True if parameters are valid (always True by default; subclasses override)
        """
        return True

    def get_state(self) -> Dict[str, Any]:
        """Get strategy state for persistence.

        Returns:
            State dictionary with ``name``, ``parameters``, ``timeframes`` and
            ``enabled``. ``parameters`` and ``timeframes`` are shallow copies, so
            mutating the returned state does not alter the live strategy.
        """
        return {
            'name': self.name,
            'parameters': dict(self.parameters),
            'timeframes': list(self.timeframes),
            'enabled': self.enabled,
        }

    def set_state(self, state: Dict[str, Any]):
        """Restore strategy state.

        ``name`` is not restored. Missing or empty ``parameters``/``timeframes``
        fall back to the same defaults as ``__init__`` ({} and ['1h']); both are
        copied so the strategy does not share mutable containers with ``state``.

        Args:
            state: State dictionary, typically produced by :meth:`get_state`
        """
        self.parameters = dict(state.get('parameters') or {})
        self.timeframes = list(state.get('timeframes') or ['1h'])
        self.enabled = state.get('enabled', False)
|
||
|
|
|
||
|
|
|
||
|
|
class StrategyRegistry:
    """Registry for managing strategies.

    Maps strategy type names (stored lower-cased, so lookups are
    case-insensitive) to strategy classes, and keeps created instances keyed
    by strategy ID.
    """

    def __init__(self):
        """Initialize strategy registry."""
        self._strategies: Dict[str, type] = {}
        self._instances: Dict[int, BaseStrategy] = {}
        self.logger = get_logger(__name__)

    def register(self, name: str, strategy_class: type):
        """Register a strategy class.

        Registering an existing name (case-insensitively) replaces the class.

        Args:
            name: Strategy name
            strategy_class: Strategy class (subclass of BaseStrategy)

        Raises:
            ValueError: If strategy_class is not a class deriving from BaseStrategy
        """
        # issubclass() itself raises TypeError for non-class arguments (e.g. an
        # instance passed by mistake); check for a class first so callers always
        # get the documented ValueError.
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, BaseStrategy)):
            raise ValueError("Strategy class must inherit from BaseStrategy")

        self._strategies[name.lower()] = strategy_class
        self.logger.info(f"Registered strategy: {name}")

    def create_instance(
        self,
        strategy_id: int,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ) -> Optional[BaseStrategy]:
        """Create strategy instance.

        The new instance is stored under ``strategy_id``, replacing any
        instance previously stored under that ID.

        Args:
            strategy_id: Strategy ID from src.database
            name: Strategy name
            parameters: Strategy parameters
            timeframes: List of timeframes

        Returns:
            Strategy instance, or None if the name is not registered or the
            constructor raised (the error is logged)
        """
        strategy_class = self._strategies.get(name.lower())
        if not strategy_class:
            self.logger.error(f"Strategy {name} not registered")
            return None

        try:
            instance = strategy_class(name, parameters, timeframes)
        except Exception as e:
            self.logger.error(f"Failed to create strategy instance: {e}")
            return None

        self._instances[strategy_id] = instance
        return instance

    def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
        """Get strategy instance by ID.

        Args:
            strategy_id: Strategy ID

        Returns:
            Strategy instance or None
        """
        return self._instances.get(strategy_id)

    def list_available(self) -> List[str]:
        """List available strategy types.

        Returns:
            List of (lower-cased) strategy names
        """
        return list(self._strategies.keys())

    def unregister(self, name: str):
        """Unregister a strategy.

        Unknown names are ignored. Instances already created from the class
        are kept.

        Args:
            name: Strategy name
        """
        if self._strategies.pop(name.lower(), None) is not None:
            self.logger.info(f"Unregistered strategy: {name}")
|
||
|
|
|
||
|
|
|
||
|
|
# Global strategy registry, created lazily by get_strategy_registry().
_registry: Optional[StrategyRegistry] = None


def get_strategy_registry() -> StrategyRegistry:
    """Return the shared StrategyRegistry, building it on first access."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = StrategyRegistry()
    return _registry
|
||
|
|
|