Files
crypto_trader/src/strategies/base.py

451 lines
14 KiB
Python

"""Base strategy class and strategy registry system."""
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Dict, Optional, List, Any, Callable

import pandas as pd

from src.core.logger import get_logger
from src.core.database import OrderSide, OrderType
# Module-level logger, requested from the project's logging helper under this module's name.
logger = get_logger(__name__)
class SignalType(str, Enum):
    """Closed set of actions a strategy signal can request.

    The ``str`` mixin makes every member compare equal to its lowercase
    string value (for example ``SignalType.BUY == "buy"``).
    """

    # Directional signals.
    BUY = "buy"
    SELL = "sell"

    # Non-directional signals.
    HOLD = "hold"
    CLOSE = "close"
class StrategySignal:
    """Trading signal emitted by a strategy.

    A signal is a plain value holder: every constructor argument is stored
    on the instance unchanged, and a creation timestamp is attached.
    """

    def __init__(
        self,
        signal_type: SignalType,
        symbol: str,
        strength: float = 1.0,
        price: Optional[Decimal] = None,
        quantity: Optional[Decimal] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Initialize strategy signal.

        Args:
            signal_type: Signal type (buy, sell, hold, close)
            symbol: Trading symbol
            strength: Signal strength (documented range 0.0 to 1.0; the
                value is stored as given and is not validated here)
            price: Suggested price
            quantity: Suggested quantity
            metadata: Additional metadata. ``None`` becomes a fresh empty
                dict; any dict that is passed in (even an empty one) is
                stored as-is so the caller keeps a shared reference.
        """
        self.signal_type = signal_type
        self.symbol = symbol
        self.strength = strength
        self.price = price
        self.quantity = quantity
        # Only substitute a new dict when nothing was supplied; `metadata or {}`
        # would also discard an empty dict the caller still holds.
        self.metadata = metadata if metadata is not None else {}
        # Naive UTC timestamp, identical in shape to the deprecated
        # datetime.utcnow() (deprecated since Python 3.12).
        self.timestamp = datetime.now(timezone.utc).replace(tzinfo=None)

    def __repr__(self) -> str:
        """Return a debug representation listing the signal's main fields."""
        return (
            f"{type(self).__name__}("
            f"signal_type={self.signal_type!r}, symbol={self.symbol!r}, "
            f"strength={self.strength!r}, price={self.price!r}, "
            f"quantity={self.quantity!r}, timestamp={self.timestamp!r})"
        )
class BaseStrategy(ABC):
    """Base class for all trading strategies.

    Subclasses must implement ``on_tick`` and ``on_signal``. The base class
    supplies position sizing, execution gating (signal strength and fee
    checks), an optional ADX/SMA trend filter and state (de)serialisation.
    """

    def __init__(
        self,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ):
        """Initialize strategy.

        Args:
            name: Strategy name
            parameters: Strategy parameters (defaults to an empty dict)
            timeframes: List of timeframes (e.g., ['1h', '15m']); defaults
                to ['1h']
        """
        self.name = name
        self.parameters = parameters or {}
        self.timeframes = timeframes or ['1h']
        # Strategies start disabled; should_execute() rejects every signal
        # until this flag is switched on.
        self.enabled = False
        self.logger = get_logger(f"strategy.{name}")
        self._data_cache: Dict[str, Any] = {}

    @abstractmethod
    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Called on each price update.

        Args:
            symbol: Trading symbol
            price: Current price
            timeframe: Timeframe of the update
            data: Additional market data

        Returns:
            StrategySignal or None
        """
        pass

    @abstractmethod
    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process and potentially modify signal.

        Args:
            signal: Generated signal

        Returns:
            Modified signal or None to cancel
        """
        pass

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Convert a number-like value to Decimal without float noise.

        Decimal inputs are returned unchanged; anything else is routed
        through ``str`` so a float such as 0.1 becomes Decimal('0.1').
        """
        if isinstance(value, Decimal):
            return value
        return Decimal(str(value))

    def _position_fraction(self) -> Decimal:
        """Return the 'position_size_percent' parameter as a fraction (default 2%)."""
        percent = self.parameters.get('position_size_percent', 2.0)
        # Divide in Decimal space: works for float, int, str and Decimal
        # parameters alike and avoids binary-float artefacts.
        return self._to_decimal(percent) / Decimal(100)

    def calculate_position_size(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> Decimal:
        """Calculate position size for signal, accounting for fees.

        Args:
            signal: Trading signal (not used by the default implementation)
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            Position size; Decimal(0) when price is not positive. The
            result is never negative.
        """
        # Default: use 2% of balance
        position_value = balance * self._position_fraction()

        # Imported at call time rather than at module load.
        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Reserve ~0.4% for round-trip fees (conservative estimate)
        fee_reserve = fee_calculator.calculate_fee_reserve(
            position_value=position_value,
            exchange_adapter=exchange_adapter,
            reserve_percent=0.004  # 0.4% for round-trip
        )

        # Whatever remains after the fee reserve is converted to quantity.
        adjusted_position_value = position_value - fee_reserve
        if price > 0:
            quantity = adjusted_position_value / price
            return max(Decimal(0), quantity)  # Ensure non-negative
        return Decimal(0)

    def should_execute(self, signal: StrategySignal) -> bool:
        """Check if signal should be executed.

        Args:
            signal: Trading signal

        Returns:
            True if the strategy is enabled and the signal strength is at
            least the 'min_signal_strength' parameter (default 0.5)
        """
        if not self.enabled:
            return False
        min_strength = self.parameters.get('min_signal_strength', 0.5)
        return not signal.strength < min_strength

    def should_execute_with_fees(
        self,
        signal: StrategySignal,
        balance: Decimal,
        price: Decimal,
        exchange_adapter=None
    ) -> bool:
        """Check if signal should be executed considering fees and minimum profit threshold.

        Args:
            signal: Trading signal
            balance: Available balance
            price: Current price
            exchange_adapter: Exchange adapter for fee calculation (optional)

        Returns:
            True if should execute after fee consideration. When the
            signal metadata has no truthy 'target_price', no profit
            estimate is made and the signal passes this stage.
        """
        # First check basic execution criteria
        if not self.should_execute(signal):
            return False

        # Use the signal's own quantity when it has one, otherwise size it.
        quantity = signal.quantity or self.calculate_position_size(signal, balance, price, exchange_adapter)
        if quantity <= 0:
            return False

        from src.trading.fee_calculator import get_fee_calculator
        fee_calculator = get_fee_calculator()

        # Get minimum profit multiplier from strategy parameters (default 2.0)
        min_profit_multiplier = self.parameters.get('min_profit_multiplier', 2.0)
        min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
            quantity=quantity,
            price=price,
            exchange_adapter=exchange_adapter,
            multiplier=min_profit_multiplier
        )

        # Profit can only be estimated when the signal carries a target
        # price; strategies can override this method for richer checks.
        target_price = signal.metadata.get('target_price')
        if not target_price:
            return True

        # Metadata is free-form, so the target may be a float or string;
        # normalise it so arithmetic with the Decimal price does not raise.
        target_price = self._to_decimal(target_price)

        # Enum comparison (as in apply_trend_filter) also accepts plain
        # string signal types because SignalType mixes in str.
        if signal.signal_type == SignalType.BUY:
            potential_profit = (target_price - price) * quantity
        else:  # sell
            # NOTE(review): HOLD and CLOSE signals that carry a target_price
            # also take this sell-side formula - confirm that is intended.
            potential_profit = (price - target_price) * quantity

        if potential_profit < min_profit_threshold:
            self.logger.debug(
                f"Signal filtered: potential profit {potential_profit} < "
                f"minimum threshold {min_profit_threshold}"
            )
            return False
        return True

    def apply_trend_filter(
        self,
        signal: StrategySignal,
        ohlcv_data: Any,
        adx_period: int = 14,
        min_adx: float = 25.0
    ) -> Optional[StrategySignal]:
        """Apply ADX-based trend filter to signal.

        Active only when the 'use_trend_filter' parameter is truthy.
        The filter then works in two stages:
        - Any signal is dropped when the latest ADX is below ``min_adx``
          (a NaN ADX counts as 0.0 and is therefore dropped too).
        - With at least 20 closes, BUY is dropped when the last close is
          below the 20-period SMA and SELL is dropped when it is above.
          Other signal types skip this direction check.

        The signal passes through unchanged when the data is missing,
        shorter than ``adx_period``, not a DataFrame, or when any error
        occurs while filtering (a warning is logged).

        Args:
            signal: Trading signal to filter
            ohlcv_data: OHLCV DataFrame with columns: high, low, close
            adx_period: ADX calculation period (default 14)
            min_adx: Minimum ADX value for signal (default 25.0)

        Returns:
            Filtered signal or None if filtered out
        """
        if not self.parameters.get('use_trend_filter', False):
            return signal

        # Deliberately best-effort: a failure inside the filter must not
        # block trading, so every error falls through to "allow".
        try:
            from src.data.indicators import get_indicators

            if ohlcv_data is None or len(ohlcv_data) < adx_period:
                # Not enough data, allow signal
                return signal

            # Ensure we have a DataFrame
            if not isinstance(ohlcv_data, pd.DataFrame):
                return signal

            indicators = get_indicators()

            # Calculate ADX
            high = ohlcv_data['high']
            low = ohlcv_data['low']
            close = ohlcv_data['close']
            adx = indicators.adx(high, low, close, period=adx_period)
            latest_adx = adx.iloc[-1]
            current_adx = 0.0 if pd.isna(latest_adx) else latest_adx

            # Check if trend is strong enough
            if current_adx < min_adx:
                # Weak trend - filter out signal
                self.logger.debug(
                    f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
                    f"filtering {signal.signal_type.value} signal"
                )
                return None

            # Direction check: price relative to the 20-period SMA.
            if len(close) >= 20:
                sma_20 = indicators.sma(close, period=20)
                current_price = close.iloc[-1]
                latest_sma = sma_20.iloc[-1]
                # A NaN SMA falls back to the price itself, which makes
                # both comparisons below false (signal is kept).
                sma_value = current_price if pd.isna(latest_sma) else latest_sma

                if signal.signal_type == SignalType.BUY:
                    # BUY only in uptrend (price above SMA)
                    if current_price < sma_value:
                        self.logger.debug(
                            f"Trend filter: BUY signal filtered - price below SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None
                elif signal.signal_type == SignalType.SELL:
                    # SELL only in downtrend (price below SMA)
                    if current_price > sma_value:
                        self.logger.debug(
                            f"Trend filter: SELL signal filtered - price above SMA "
                            f"(price: {current_price}, SMA: {sma_value})"
                        )
                        return None

            return signal
        except Exception as e:
            self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
            return signal

    def get_required_indicators(self) -> List[str]:
        """Get list of required indicators.

        Returns:
            List of indicator names (empty in the base class)
        """
        return []

    def validate_parameters(self) -> bool:
        """Validate strategy parameters.

        Returns:
            True if parameters are valid (the base class accepts anything)
        """
        return True

    def get_state(self) -> Dict[str, Any]:
        """Get strategy state for persistence.

        Returns:
            State dictionary with 'name', 'parameters', 'timeframes' and
            'enabled'. The two containers are shallow copies, so editing
            the snapshot does not alter the running strategy.
        """
        return {
            'name': self.name,
            'parameters': dict(self.parameters),
            'timeframes': list(self.timeframes),
            'enabled': self.enabled,
        }

    def set_state(self, state: Dict[str, Any]):
        """Restore strategy state.

        Missing or empty 'parameters' / 'timeframes' fall back to the same
        defaults ``__init__`` uses ({} and ['1h']). The name is not
        restored.

        Args:
            state: State dictionary
        """
        # Copy so later edits to `state` by the caller do not leak in.
        self.parameters = dict(state.get('parameters') or {})
        self.timeframes = list(state.get('timeframes') or ['1h'])
        self.enabled = state.get('enabled', False)
class StrategyRegistry:
    """Registry for managing strategies.

    Holds two maps: registered strategy classes keyed by lowercase name,
    and created strategy instances keyed by strategy ID.
    """

    def __init__(self):
        """Initialize strategy registry."""
        self._strategies: Dict[str, type] = {}
        self._instances: Dict[int, BaseStrategy] = {}
        self.logger = get_logger(__name__)

    def register(self, name: str, strategy_class: type):
        """Register a strategy class.

        Registering a name that already exists replaces the earlier class.

        Args:
            name: Strategy name (stored lowercased)
            strategy_class: Strategy class (subclass of BaseStrategy)

        Raises:
            ValueError: If ``strategy_class`` is not a class derived from
                BaseStrategy.
        """
        # issubclass() raises TypeError for non-class arguments; checking
        # isinstance(..., type) first keeps the failure a ValueError.
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, BaseStrategy)):
            raise ValueError("Strategy class must inherit from BaseStrategy")
        self._strategies[name.lower()] = strategy_class
        self.logger.info(f"Registered strategy: {name}")

    def create_instance(
        self,
        strategy_id: int,
        name: str,
        parameters: Optional[Dict[str, Any]] = None,
        timeframes: Optional[List[str]] = None
    ) -> Optional[BaseStrategy]:
        """Create strategy instance.

        The new instance is stored under ``strategy_id``, replacing any
        instance previously stored under that ID.

        Args:
            strategy_id: Strategy ID from src.database
            name: Strategy name (looked up case-insensitively; passed to
                the constructor as given)
            parameters: Strategy parameters
            timeframes: List of timeframes

        Returns:
            Strategy instance, or None when the name is not registered or
            the constructor raises (both cases are logged as errors)
        """
        strategy_class = self._strategies.get(name.lower())
        if not strategy_class:
            self.logger.error(f"Strategy {name} not registered")
            return None

        try:
            instance = strategy_class(name, parameters, timeframes)
        except Exception as e:
            self.logger.error(f"Failed to create strategy instance: {e}")
            return None

        self._instances[strategy_id] = instance
        return instance

    def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
        """Get strategy instance by ID.

        Args:
            strategy_id: Strategy ID

        Returns:
            Strategy instance or None
        """
        return self._instances.get(strategy_id)

    def list_available(self) -> List[str]:
        """List available strategy types.

        Returns:
            List of registered (lowercase) strategy names
        """
        return list(self._strategies)

    def unregister(self, name: str):
        """Unregister a strategy.

        Only the class registration is removed; instances already created
        stay available through get_instance(). Unknown names are ignored.

        Args:
            name: Strategy name
        """
        key = name.lower()
        if key in self._strategies:
            del self._strategies[key]
            self.logger.info(f"Unregistered strategy: {name}")
# Process-wide registry holder; stays None until the accessor below is first called.
_registry: Optional[StrategyRegistry] = None


def get_strategy_registry() -> StrategyRegistry:
    """Return the shared StrategyRegistry, building it on the first call."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = StrategyRegistry()
    return _registry