Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
450
src/strategies/base.py
Normal file
450
src/strategies/base.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Base strategy class and strategy registry system."""
|
||||
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Optional, List, Any, Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from src.core.logger import get_logger
|
||||
from src.core.database import OrderSide, OrderType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SignalType(str, Enum):
|
||||
"""Trading signal types."""
|
||||
BUY = "buy"
|
||||
SELL = "sell"
|
||||
HOLD = "hold"
|
||||
CLOSE = "close"
|
||||
|
||||
|
||||
class StrategySignal:
|
||||
"""Trading signal from strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal_type: SignalType,
|
||||
symbol: str,
|
||||
strength: float = 1.0,
|
||||
price: Optional[Decimal] = None,
|
||||
quantity: Optional[Decimal] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Initialize strategy signal.
|
||||
|
||||
Args:
|
||||
signal_type: Signal type (buy, sell, hold, close)
|
||||
symbol: Trading symbol
|
||||
strength: Signal strength (0.0 to 1.0)
|
||||
price: Suggested price
|
||||
quantity: Suggested quantity
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
self.signal_type = signal_type
|
||||
self.symbol = symbol
|
||||
self.strength = strength
|
||||
self.price = price
|
||||
self.quantity = quantity
|
||||
self.metadata = metadata or {}
|
||||
self.timestamp = datetime.utcnow()
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Base class for all trading strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
timeframes: Optional[List[str]] = None
|
||||
):
|
||||
"""Initialize strategy.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
parameters: Strategy parameters
|
||||
timeframes: List of timeframes (e.g., ['1h', '15m'])
|
||||
"""
|
||||
self.name = name
|
||||
self.parameters = parameters or {}
|
||||
self.timeframes = timeframes or ['1h']
|
||||
self.enabled = False
|
||||
self.logger = get_logger(f"strategy.{name}")
|
||||
self._data_cache: Dict[str, Any] = {}
|
||||
|
||||
@abstractmethod
|
||||
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
|
||||
"""Called on each price update.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Current price
|
||||
timeframe: Timeframe of the update
|
||||
data: Additional market data
|
||||
|
||||
Returns:
|
||||
StrategySignal or None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
|
||||
"""Process and potentially modify signal.
|
||||
|
||||
Args:
|
||||
signal: Generated signal
|
||||
|
||||
Returns:
|
||||
Modified signal or None to cancel
|
||||
"""
|
||||
pass
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
balance: Decimal,
|
||||
price: Decimal,
|
||||
exchange_adapter=None
|
||||
) -> Decimal:
|
||||
"""Calculate position size for signal, accounting for fees.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
balance: Available balance
|
||||
price: Current price
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
Position size
|
||||
"""
|
||||
# Default: use 2% of balance
|
||||
risk_percent = self.parameters.get('position_size_percent', 2.0) / 100.0
|
||||
position_value = balance * Decimal(str(risk_percent))
|
||||
|
||||
# Account for fees by reserving fee amount
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Reserve ~0.4% for round-trip fees (conservative estimate)
|
||||
fee_reserve = fee_calculator.calculate_fee_reserve(
|
||||
position_value=position_value,
|
||||
exchange_adapter=exchange_adapter,
|
||||
reserve_percent=0.004 # 0.4% for round-trip
|
||||
)
|
||||
|
||||
# Adjust position value to account for fees
|
||||
adjusted_position_value = position_value - fee_reserve
|
||||
|
||||
# Calculate quantity
|
||||
if price > 0:
|
||||
quantity = adjusted_position_value / price
|
||||
return max(Decimal(0), quantity) # Ensure non-negative
|
||||
|
||||
return Decimal(0)
|
||||
|
||||
def should_execute(self, signal: StrategySignal) -> bool:
|
||||
"""Check if signal should be executed.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
|
||||
Returns:
|
||||
True if should execute
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Check signal strength threshold
|
||||
min_strength = self.parameters.get('min_signal_strength', 0.5)
|
||||
if signal.strength < min_strength:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def should_execute_with_fees(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
balance: Decimal,
|
||||
price: Decimal,
|
||||
exchange_adapter=None
|
||||
) -> bool:
|
||||
"""Check if signal should be executed considering fees and minimum profit threshold.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
balance: Available balance
|
||||
price: Current price
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
True if should execute after fee consideration
|
||||
"""
|
||||
# First check basic execution criteria
|
||||
if not self.should_execute(signal):
|
||||
return False
|
||||
|
||||
# Calculate position size
|
||||
quantity = signal.quantity or self.calculate_position_size(signal, balance, price, exchange_adapter)
|
||||
|
||||
if quantity <= 0:
|
||||
return False
|
||||
|
||||
# Check minimum profit threshold
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Get minimum profit multiplier from strategy parameters (default 2.0)
|
||||
min_profit_multiplier = self.parameters.get('min_profit_multiplier', 2.0)
|
||||
|
||||
min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
|
||||
quantity=quantity,
|
||||
price=price,
|
||||
exchange_adapter=exchange_adapter,
|
||||
multiplier=min_profit_multiplier
|
||||
)
|
||||
|
||||
# Estimate potential profit (simplified - strategies can override)
|
||||
# For buy signals, we'd need to estimate exit price
|
||||
# For now, we'll use a basic check: if signal strength is high enough
|
||||
# Strategies should override this method for more sophisticated checks
|
||||
|
||||
# If we have a target price in signal metadata, use it
|
||||
target_price = signal.metadata.get('target_price')
|
||||
if target_price:
|
||||
if signal.signal_type.value == "buy":
|
||||
potential_profit = (target_price - price) * quantity
|
||||
else: # sell
|
||||
potential_profit = (price - target_price) * quantity
|
||||
|
||||
if potential_profit < min_profit_threshold:
|
||||
self.logger.debug(
|
||||
f"Signal filtered: potential profit {potential_profit} < "
|
||||
f"minimum threshold {min_profit_threshold}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply_trend_filter(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
ohlcv_data: Any,
|
||||
adx_period: int = 14,
|
||||
min_adx: float = 25.0
|
||||
) -> Optional[StrategySignal]:
|
||||
"""Apply ADX-based trend filter to signal.
|
||||
|
||||
Filters signals based on trend strength:
|
||||
- Only allow BUY signals when ADX > threshold (strong trend)
|
||||
- Only allow SELL signals in downtrends with ADX > threshold
|
||||
- Filters out choppy/ranging markets
|
||||
|
||||
Args:
|
||||
signal: Trading signal to filter
|
||||
ohlcv_data: OHLCV DataFrame with columns: high, low, close
|
||||
adx_period: ADX calculation period (default 14)
|
||||
min_adx: Minimum ADX value for signal (default 25.0)
|
||||
|
||||
Returns:
|
||||
Filtered signal or None if filtered out
|
||||
"""
|
||||
if not self.parameters.get('use_trend_filter', False):
|
||||
return signal
|
||||
|
||||
try:
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
if ohlcv_data is None or len(ohlcv_data) < adx_period:
|
||||
# Not enough data, allow signal
|
||||
return signal
|
||||
|
||||
# Ensure we have a DataFrame
|
||||
if not isinstance(ohlcv_data, pd.DataFrame):
|
||||
return signal
|
||||
|
||||
indicators = get_indicators()
|
||||
|
||||
# Calculate ADX
|
||||
high = ohlcv_data['high']
|
||||
low = ohlcv_data['low']
|
||||
close = ohlcv_data['close']
|
||||
|
||||
adx = indicators.adx(high, low, close, period=adx_period)
|
||||
current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0
|
||||
|
||||
# Check if trend is strong enough
|
||||
if current_adx < min_adx:
|
||||
# Weak trend - filter out signal
|
||||
self.logger.debug(
|
||||
f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
|
||||
f"filtering {signal.signal_type.value} signal"
|
||||
)
|
||||
return None
|
||||
|
||||
# Additional check: for BUY signals, ensure uptrend
|
||||
# For SELL signals, ensure downtrend
|
||||
# We can use price vs moving average to determine trend direction
|
||||
if len(close) >= 20:
|
||||
sma_20 = indicators.sma(close, period=20)
|
||||
current_price = close.iloc[-1]
|
||||
sma_value = sma_20.iloc[-1] if not pd.isna(sma_20.iloc[-1]) else current_price
|
||||
|
||||
if signal.signal_type == SignalType.BUY:
|
||||
# BUY only in uptrend (price above SMA)
|
||||
if current_price < sma_value:
|
||||
self.logger.debug(
|
||||
f"Trend filter: BUY signal filtered - price below SMA "
|
||||
f"(price: {current_price}, SMA: {sma_value})"
|
||||
)
|
||||
return None
|
||||
elif signal.signal_type == SignalType.SELL:
|
||||
# SELL only in downtrend (price below SMA)
|
||||
if current_price > sma_value:
|
||||
self.logger.debug(
|
||||
f"Trend filter: SELL signal filtered - price above SMA "
|
||||
f"(price: {current_price}, SMA: {sma_value})"
|
||||
)
|
||||
return None
|
||||
|
||||
return signal
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
|
||||
return signal
|
||||
|
||||
def get_required_indicators(self) -> List[str]:
|
||||
"""Get list of required indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator names
|
||||
"""
|
||||
return []
|
||||
|
||||
def validate_parameters(self) -> bool:
|
||||
"""Validate strategy parameters.
|
||||
|
||||
Returns:
|
||||
True if parameters are valid
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get strategy state for persistence.
|
||||
|
||||
Returns:
|
||||
State dictionary
|
||||
"""
|
||||
return {
|
||||
'name': self.name,
|
||||
'parameters': self.parameters,
|
||||
'timeframes': self.timeframes,
|
||||
'enabled': self.enabled,
|
||||
}
|
||||
|
||||
def set_state(self, state: Dict[str, Any]):
|
||||
"""Restore strategy state.
|
||||
|
||||
Args:
|
||||
state: State dictionary
|
||||
"""
|
||||
self.parameters = state.get('parameters', {})
|
||||
self.timeframes = state.get('timeframes', ['1h'])
|
||||
self.enabled = state.get('enabled', False)
|
||||
|
||||
|
||||
class StrategyRegistry:
|
||||
"""Registry for managing strategies."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize strategy registry."""
|
||||
self._strategies: Dict[str, type] = {}
|
||||
self._instances: Dict[int, BaseStrategy] = {}
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def register(self, name: str, strategy_class: type):
|
||||
"""Register a strategy class.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
strategy_class: Strategy class (subclass of BaseStrategy)
|
||||
"""
|
||||
if not issubclass(strategy_class, BaseStrategy):
|
||||
raise ValueError(f"Strategy class must inherit from BaseStrategy")
|
||||
|
||||
self._strategies[name.lower()] = strategy_class
|
||||
self.logger.info(f"Registered strategy: {name}")
|
||||
|
||||
def create_instance(
|
||||
self,
|
||||
strategy_id: int,
|
||||
name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
timeframes: Optional[List[str]] = None
|
||||
) -> Optional[BaseStrategy]:
|
||||
"""Create strategy instance.
|
||||
|
||||
Args:
|
||||
strategy_id: Strategy ID from src.database
|
||||
name: Strategy name
|
||||
parameters: Strategy parameters
|
||||
timeframes: List of timeframes
|
||||
|
||||
Returns:
|
||||
Strategy instance or None
|
||||
"""
|
||||
strategy_class = self._strategies.get(name.lower())
|
||||
if not strategy_class:
|
||||
self.logger.error(f"Strategy {name} not registered")
|
||||
return None
|
||||
|
||||
try:
|
||||
instance = strategy_class(name, parameters, timeframes)
|
||||
self._instances[strategy_id] = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create strategy instance: {e}")
|
||||
return None
|
||||
|
||||
def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
|
||||
"""Get strategy instance by ID.
|
||||
|
||||
Args:
|
||||
strategy_id: Strategy ID
|
||||
|
||||
Returns:
|
||||
Strategy instance or None
|
||||
"""
|
||||
return self._instances.get(strategy_id)
|
||||
|
||||
def list_available(self) -> List[str]:
|
||||
"""List available strategy types.
|
||||
|
||||
Returns:
|
||||
List of strategy names
|
||||
"""
|
||||
return list(self._strategies.keys())
|
||||
|
||||
def unregister(self, name: str):
|
||||
"""Unregister a strategy.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
"""
|
||||
if name.lower() in self._strategies:
|
||||
del self._strategies[name.lower()]
|
||||
self.logger.info(f"Unregistered strategy: {name}")
|
||||
|
||||
|
||||
# Global strategy registry
|
||||
_registry: Optional[StrategyRegistry] = None
|
||||
|
||||
|
||||
def get_strategy_registry() -> StrategyRegistry:
|
||||
"""Get global strategy registry instance."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = StrategyRegistry()
|
||||
return _registry
|
||||
|
||||
Reference in New Issue
Block a user