Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

0
src/__init__.py Normal file
View File

0
src/alerts/__init__.py Normal file
View File

26
src/alerts/channels.py Normal file
View File

@@ -0,0 +1,26 @@
"""Alert delivery channels."""
from typing import Optional
from src.ui.utils.notifications import get_notification_manager
from src.core.logger import get_logger
logger = get_logger(__name__)
class AlertChannel:
"""Manages alert delivery channels."""
def __init__(self):
"""Initialize alert channel."""
self.notifications = get_notification_manager()
self.logger = get_logger(__name__)
def send(self, alert_type: str, message: str):
"""Send alert through all channels.
Args:
alert_type: Alert type
message: Alert message
"""
self.notifications.notify_alert(alert_type, message)

139
src/alerts/engine.py Normal file
View File

@@ -0,0 +1,139 @@
"""Alert rule engine."""
from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy.orm import Session
from src.core.database import get_database, Alert
from src.core.logger import get_logger
from .channels import AlertChannel
logger = get_logger(__name__)
class AlertEngine:
"""Alert rule engine."""
def __init__(self):
"""Initialize alert engine."""
self.db = get_database()
self.logger = get_logger(__name__)
self.channel = AlertChannel()
self._active_alerts: Dict[int, Dict] = {}
def evaluate_price_alert(
self,
alert_id: int,
symbol: str,
current_price: Decimal
) -> bool:
"""Evaluate price alert.
Args:
alert_id: Alert ID
symbol: Trading symbol
current_price: Current price
Returns:
True if alert should trigger
"""
session = self.db.get_session()
try:
alert = session.query(Alert).filter_by(id=alert_id).first()
if not alert or not alert.enabled:
return False
condition = alert.condition
alert_type = condition.get('type')
threshold = Decimal(str(condition.get('threshold', 0)))
operator = condition.get('operator', '>') # >, <, >=, <=
if alert_type == 'price_above':
return current_price > threshold
elif alert_type == 'price_below':
return current_price < threshold
elif alert_type == 'price_change':
# Would need previous price
return False
return False
finally:
session.close()
def evaluate_indicator_alert(
self,
alert_id: int,
indicator_value: float
) -> bool:
"""Evaluate indicator alert.
Args:
alert_id: Alert ID
indicator_value: Current indicator value
Returns:
True if alert should trigger
"""
session = self.db.get_session()
try:
alert = session.query(Alert).filter_by(id=alert_id).first()
if not alert or not alert.enabled:
return False
condition = alert.condition
threshold = float(condition.get('threshold', 0))
operator = condition.get('operator', '>')
if operator == '>':
return indicator_value > threshold
elif operator == '<':
return indicator_value < threshold
elif operator == 'crosses_above':
# Would need previous value
return False
elif operator == 'crosses_below':
# Would need previous value
return False
return False
finally:
session.close()
def trigger_alert(self, alert_id: int, message: str):
"""Trigger an alert.
Args:
alert_id: Alert ID
message: Alert message
"""
session = self.db.get_session()
try:
alert = session.query(Alert).filter_by(id=alert_id).first()
if not alert:
return
alert.triggered = True
alert.triggered_at = datetime.utcnow()
session.commit()
# Send notification
self.channel.send(alert.alert_type, message)
logger.info(f"Alert {alert_id} triggered: {message}")
except Exception as e:
session.rollback()
logger.error(f"Failed to trigger alert {alert_id}: {e}")
finally:
session.close()
# Global alert engine
_alert_engine: Optional[AlertEngine] = None
def get_alert_engine() -> AlertEngine:
"""Get global alert engine instance."""
global _alert_engine
if _alert_engine is None:
_alert_engine = AlertEngine()
return _alert_engine

78
src/alerts/manager.py Normal file
View File

@@ -0,0 +1,78 @@
"""Alert management."""
from typing import Dict, Optional, List
from sqlalchemy import select
from src.core.database import get_database, Alert
from src.core.logger import get_logger
logger = get_logger(__name__)
class AlertManager:
"""Manages alerts."""
def __init__(self):
"""Initialize alert manager."""
self.db = get_database()
self.logger = get_logger(__name__)
async def create_alert(
self,
name: str,
alert_type: str,
condition: Dict
) -> Alert:
"""Create a new alert.
Args:
name: Alert name
alert_type: Alert type (price, indicator, risk, system)
condition: Alert condition configuration
Returns:
Alert object
"""
async with self.db.get_session() as session:
try:
alert = Alert(
name=name,
alert_type=alert_type,
condition=condition,
enabled=True
)
session.add(alert)
await session.commit()
await session.refresh(alert)
return alert
except Exception as e:
await session.rollback()
logger.error(f"Failed to create alert: {e}")
raise
async def list_alerts(self, enabled_only: bool = False) -> List[Alert]:
"""List all alerts.
Args:
enabled_only: Only return enabled alerts
Returns:
List of alerts
"""
async with self.db.get_session() as session:
stmt = select(Alert)
if enabled_only:
stmt = stmt.where(Alert.enabled == True)
result = await session.execute(stmt)
return result.scalars().all()
# Global alert manager
_alert_manager: Optional[AlertManager] = None
def get_alert_manager() -> AlertManager:
"""Get global alert manager instance."""
global _alert_manager
if _alert_manager is None:
_alert_manager = AlertManager()
return _alert_manager

150
src/autopilot/__init__.py Normal file
View File

@@ -0,0 +1,150 @@
"""AutoPilot autonomous trading engine module.
This module provides an autonomous background service that combines
geometric pattern recognition with NLP-based sentiment analysis
to generate trade signals.
Supported Patterns:
- Head and Shoulders (bullish/bearish)
- Double Top/Bottom
- Triple Top/Bottom
- Triangle patterns (Ascending, Descending, Symmetrical)
- Wedge patterns (Rising, Falling)
- Flag and Pennant patterns
- Candlestick patterns (Engulfing, Hammer, Doji, Stars, etc.)
- MA Crossovers (Golden Cross, Death Cross)
- Support/Resistance levels
- Harmonic patterns (ABCD)
- Gap patterns
"""
from .intelligent_autopilot import (
IntelligentAutopilot,
get_intelligent_autopilot,
stop_all_autopilots,
)
from .market_analyzer import (
MarketAnalyzer,
MarketConditions,
MarketRegime,
get_market_analyzer,
)
from .strategy_selector import (
StrategySelector,
get_strategy_selector,
)
from .performance_tracker import (
PerformanceTracker,
get_performance_tracker,
)
from typing import Dict, Any
def get_autopilot_mode_info() -> Dict[str, Any]:
"""Get information about available autopilot modes.
Returns:
Dictionary with mode information including descriptions, capabilities, and tradeoffs
"""
return {
"modes": {
"pattern": {
"name": "Pattern-Based Autopilot",
"description": "Detects technical chart patterns (Head & Shoulders, triangles, etc.) and combines with sentiment analysis",
"how_it_works": "Rule-based logic - pattern + sentiment alignment = signal",
"best_for": [
"Users who want transparency",
"Users who understand technical analysis",
"Users who prefer explainable decisions"
],
"tradeoffs": {
"pros": [
"Transparent and explainable",
"No training data required",
"Fast and lightweight"
],
"cons": [
"Less adaptive to market changes",
"Fixed decision rules"
]
},
"features": [
"Pattern recognition (40+ patterns)",
"Sentiment analysis (FinBERT)",
"Real-time signal generation",
"Transparent decision logic"
],
"requirements": {
"training_data": False,
"setup_complexity": "Low",
"resource_usage": "Low"
}
},
"intelligent": {
"name": "ML-Based Autopilot",
"description": "Uses machine learning to select the best strategy based on current market conditions",
"how_it_works": "ML model analyzes market conditions and selects optimal strategy from available strategies",
"best_for": [
"Users who want adaptive, data-driven decisions",
"Users who don't need to understand every decision",
"Advanced users seeking optimization"
],
"tradeoffs": {
"pros": [
"Adapts to market conditions",
"Learns from historical performance",
"Can optimize strategy selection"
],
"cons": [
"Requires training data",
"Less transparent (black box)",
"More complex setup"
]
},
"features": [
"ML-based strategy selection",
"Market condition analysis",
"Performance tracking",
"Auto-execution support"
],
"requirements": {
"training_data": True,
"setup_complexity": "Medium",
"resource_usage": "Medium"
}
}
},
"comparison": {
"transparency": {
"pattern": "High - All decisions are explainable",
"intelligent": "Low - ML model decisions are less transparent"
},
"adaptability": {
"pattern": "Low - Fixed rules",
"intelligent": "High - Learns and adapts"
},
"setup_time": {
"pattern": "Immediate - No setup required",
"intelligent": "Requires training data collection"
},
"resource_usage": {
"pattern": "Low - Lightweight",
"intelligent": "Medium - ML model overhead"
}
}
}
__all__ = [
"stop_all_autopilots",
"IntelligentAutopilot",
"get_intelligent_autopilot",
"MarketAnalyzer",
"MarketConditions",
"MarketRegime",
"get_market_analyzer",
"StrategySelector",
"get_strategy_selector",
"PerformanceTracker",
"get_performance_tracker",
"get_autopilot_mode_info",
]

View File

@@ -0,0 +1,717 @@
"""Intelligent autopilot orchestrator with ML-based strategy selection."""
import asyncio
import json
from decimal import Decimal
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import pandas as pd
from src.core.logger import get_logger
from src.core.database import get_database, OrderSide, OrderType
from src.core.config import get_config
from src.core.redis import get_redis_client
from .market_analyzer import MarketConditions, get_market_analyzer
from .strategy_selector import get_strategy_selector
from .performance_tracker import get_performance_tracker
from src.strategies.base import get_strategy_registry, BaseStrategy, StrategySignal, SignalType
from src.trading.engine import get_trading_engine
from src.risk.manager import get_risk_manager
from src.data.collector import get_data_collector
logger = get_logger(__name__)
# Strategies that require user configuration (excluded from default autopilot)
# These are only loaded if the user has configured them in the Strategies page
STRATEGIES_REQUIRING_CONFIG = {
'pairs_trading', # Requires second_symbol
'grid', # Requires grid_spacing, num_levels
'dca', # Requires amount, interval
'market_making', # Requires spread_percent, max_inventory
}
class IntelligentAutopilot:
"""Intelligent autopilot with ML-based strategy selection and auto-execution."""
def __init__(
self,
symbol: str,
exchange_id: int = 1,
timeframe: str = "1h",
interval: float = 60.0,
paper_trading: bool = True
):
"""Initialize intelligent autopilot.
Args:
symbol: Trading symbol (e.g., "BTC/USD")
exchange_id: Exchange ID
timeframe: Analysis timeframe
interval: Analysis cycle interval in seconds
paper_trading: Paper trading mode
"""
self.symbol = symbol
self.exchange_id = exchange_id
self.timeframe = timeframe
self.interval = interval
self.paper_trading = paper_trading
self.logger = get_logger(__name__)
self.config = get_config()
self.db = get_database()
self.redis = get_redis_client()
# Core components
self.market_analyzer = get_market_analyzer()
self.strategy_selector = get_strategy_selector()
self.performance_tracker = get_performance_tracker()
self.strategy_registry = get_strategy_registry()
self.trading_engine = get_trading_engine()
self.risk_manager = get_risk_manager()
self.data_collector = get_data_collector()
# Configuration
self.min_confidence_threshold = self.config.get(
"autopilot.intelligent.min_confidence_threshold",
0.75
)
self.max_trades_per_day = self.config.get(
"autopilot.intelligent.max_trades_per_day",
10
)
self.min_profit_target = self.config.get(
"autopilot.intelligent.min_profit_target",
0.02 # 2% minimum expected profit
)
self.enable_auto_execution = self.config.get(
"autopilot.intelligent.enable_auto_execution",
True
)
# State
self._local_running = False # Local flag for the loop
self._last_analysis: Optional[Dict[str, Any]] = None
self._selected_strategy: Optional[str] = None
self._strategy_instances: Dict[str, BaseStrategy] = {}
self._ohlcv_data: Optional[pd.DataFrame] = None
# Redis Keys
safe_symbol = symbol.replace("/", "-")
self.key_running = f"autopilot:running:{safe_symbol}:{timeframe}"
self.key_trades = f"autopilot:trades:{safe_symbol}"
# Initialize strategy instances
self._initialize_strategies()
@property
def _trades_today_key(self) -> str:
"""Get Redis key for today's trade count."""
today = datetime.utcnow().strftime("%Y-%m-%d")
return f"{self.key_trades}:{today}"
async def _fetch_market_data(self):
"""Fetch OHLCV data from src.database."""
from sqlalchemy import select
try:
async with self.db.get_session() as session:
from src.core.database import MarketData, Exchange
# Get exchange name
stmt = select(Exchange).filter_by(id=self.exchange_id).limit(1)
result = await session.execute(stmt)
exchange = result.scalar_one_or_none()
exchange_name = exchange.name if exchange else "coinbase"
# Fetch recent OHLCV data (last 200 candles for analysis)
ohlcv_stmt = select(MarketData).filter_by(
exchange=exchange_name,
symbol=self.symbol,
timeframe=self.timeframe
).order_by(MarketData.timestamp.desc()).limit(200)
ohlcv_result = await session.execute(ohlcv_stmt)
market_data = ohlcv_result.scalars().all()
if market_data:
# Convert to DataFrame
data = {
'timestamp': [md.timestamp for md in reversed(market_data)],
'open': [float(md.open) for md in reversed(market_data)],
'high': [float(md.high) for md in reversed(market_data)],
'low': [float(md.low) for md in reversed(market_data)],
'close': [float(md.close) for md in reversed(market_data)],
'volume': [float(md.volume) for md in reversed(market_data)],
}
self._ohlcv_data = pd.DataFrame(data)
self.logger.info(f"Loaded {len(self._ohlcv_data)} candles from src.database")
else:
self.logger.warning("No market data found in database")
except Exception as e:
self.logger.error(f"Error fetching market data: {e}")
def _initialize_strategies(self):
"""Initialize strategy instances for strategies that work without configuration."""
available_strategies = self.strategy_registry.list_available()
for strategy_name in available_strategies:
# Skip strategies that require user configuration
if strategy_name.lower() in STRATEGIES_REQUIRING_CONFIG:
self.logger.debug(f"Skipping {strategy_name} (requires configuration)")
continue
try:
# Create instance with default parameters
strategy_class = self.strategy_registry._strategies.get(strategy_name.lower())
if strategy_class:
instance = strategy_class(
name=strategy_name,
parameters={},
timeframes=[self.timeframe]
)
instance.enabled = True
self._strategy_instances[strategy_name] = instance
self.logger.debug(f"Initialized strategy: {strategy_name}")
except Exception as e:
self.logger.warning(f"Failed to initialize strategy {strategy_name}: {e}")
async def _load_user_configured_strategies(self):
"""Load user-configured strategies from database.
This enables strategies that require configuration (like pairs_trading)
when the user has set them up in the Strategies page.
Also overrides default parameters if user has configured them.
"""
from sqlalchemy import select
from src.core.database import Strategy
try:
async with self.db.get_session() as session:
# Get all enabled strategies from database
stmt = select(Strategy).where(Strategy.enabled == True)
result = await session.execute(stmt)
user_strategies = result.scalars().all()
for db_strategy in user_strategies:
strategy_type = db_strategy.strategy_type
if not strategy_type:
continue
strategy_type_lower = strategy_type.lower()
# Check if this strategy type is registered
strategy_class = self.strategy_registry._strategies.get(strategy_type_lower)
if not strategy_class:
self.logger.debug(f"Strategy type {strategy_type} not registered")
continue
try:
# Create instance with user's parameters
instance = strategy_class(
name=db_strategy.name,
parameters=db_strategy.parameters or {},
timeframes=db_strategy.timeframes or [self.timeframe]
)
instance.enabled = True
# Use strategy_type as key (overwrites default if exists)
self._strategy_instances[strategy_type_lower] = instance
self.logger.info(
f"Loaded user-configured strategy: {db_strategy.name} "
f"(type: {strategy_type}, params: {db_strategy.parameters})"
)
except Exception as e:
self.logger.warning(
f"Failed to load user strategy {db_strategy.name}: {e}"
)
except Exception as e:
self.logger.error(f"Error loading user-configured strategies: {e}")
def update_market_data(self, ohlcv_data: pd.DataFrame):
"""Update OHLCV market data.
Args:
ohlcv_data: DataFrame with columns: open, high, low, close, volume
"""
self._ohlcv_data = ohlcv_data.copy()
self.logger.debug(f"Updated market data: {len(ohlcv_data)} candles")
async def start(self):
"""Start the intelligent autopilot loop."""
redis_client = self.redis.get_client()
# Check if already running (distributed lock style)
is_running = await redis_client.get(self.key_running)
if is_running:
self.logger.warning(f"Autopilot for {self.symbol} is already running elsewhere")
return
# Set running state in Redis
await redis_client.set(self.key_running, "1")
self._local_running = True
self.logger.info(
f"Intelligent autopilot started for {self.symbol} "
f"(timeframe: {self.timeframe}, interval: {self.interval}s)"
)
# Initial data fetch
await self._fetch_market_data()
# Try to load or train model
try:
# Check if this task should be offloaded to Celery in future
await self.strategy_selector.train_model(force_retrain=False)
except Exception as e:
self.logger.warning(f"Model training failed, will use fallback: {e}")
# Load user-configured strategies from database
await self._load_user_configured_strategies()
self.logger.info(
f"Autopilot has {len(self._strategy_instances)} strategies available: "
f"{list(self._strategy_instances.keys())}"
)
try:
while self._local_running:
# Distributed check: verify key still exists
if not await redis_client.exists(self.key_running):
self.logger.info("Remote stop signal received")
break
try:
await self._analysis_cycle()
except Exception as e:
self.logger.error(f"Analysis cycle error: {e}")
await asyncio.sleep(self.interval)
finally:
self._local_running = False
# Ensure key is deleted (cleanup)
await redis_client.delete(self.key_running)
self.logger.info("Intelligent autopilot stopped")
async def _analysis_cycle(self):
"""Perform one analysis and execution cycle."""
# Check daily trade limit from Redis
trades_today = await self._get_trades_today()
if trades_today >= self.max_trades_per_day:
self.logger.debug(f"Daily trade limit reached: {trades_today}/{self.max_trades_per_day}")
return
# Get market data (refresh if needed)
if self._ohlcv_data is None or len(self._ohlcv_data) < 50:
await self._fetch_market_data()
if self._ohlcv_data is None or len(self._ohlcv_data) < 50:
self.logger.warning("Insufficient market data for analysis")
return
# Analyze market conditions
market_conditions = self.market_analyzer.analyze_current_conditions(
symbol=self.symbol,
timeframe=self.timeframe,
ohlcv_data=self._ohlcv_data
)
# Select best strategy using ML
selection_result = self.strategy_selector.select_best_strategy(
market_conditions,
min_confidence=self.min_confidence_threshold
)
if selection_result is None:
self.logger.debug("No strategy selected (confidence too low)")
self._last_analysis = {
'market_conditions': market_conditions.to_dict(),
'selected_strategy': None,
'confidence': 0.0
}
return
selected_strategy_name, confidence, all_predictions = selection_result
self._selected_strategy = selected_strategy_name
# Get strategy instance
strategy_instance = self._strategy_instances.get(selected_strategy_name)
if not strategy_instance:
self.logger.error(f"Strategy instance not found: {selected_strategy_name}")
return
# Generate signal from selected strategy
current_price = Decimal(str(self._ohlcv_data['close'].iloc[-1]))
signal = await self._generate_strategy_signal(
strategy_instance,
current_price
)
if signal is None or signal.signal_type == SignalType.HOLD:
self.logger.debug(f"No actionable signal from {selected_strategy_name}")
self._last_analysis = {
'market_conditions': market_conditions.to_dict(),
'selected_strategy': selected_strategy_name,
'confidence': confidence,
'signal': None
}
return
# Evaluate opportunity
if not self._evaluate_opportunity(signal, confidence, market_conditions):
self.logger.debug("Opportunity does not meet execution criteria")
return
# Execute trade
if self.enable_auto_execution:
await self._execute_opportunity(strategy_instance, signal, market_conditions)
else:
self.logger.info(f"Auto-execution disabled. Signal: {signal.signal_type.value}")
# Store analysis result
self._last_analysis = {
'market_conditions': market_conditions.to_dict(),
'selected_strategy': selected_strategy_name,
'confidence': confidence,
'signal': {
'type': signal.signal_type.value,
'strength': signal.strength,
'price': float(signal.price) if signal.price else None
}
}
async def _generate_strategy_signal(
self,
strategy: BaseStrategy,
current_price: Decimal
) -> Optional[StrategySignal]:
"""Generate signal from strategy."""
try:
# Prepare market data for strategy
latest_candle = self._ohlcv_data.iloc[-1]
data = {
'open': latest_candle['open'],
'high': latest_candle['high'],
'low': latest_candle['low'],
'volume': latest_candle['volume']
}
# Generate signal
signal = await strategy.on_tick(
symbol=self.symbol,
price=current_price,
timeframe=self.timeframe,
data=data
)
if signal:
# Process signal
signal = strategy.on_signal(signal)
return signal
except Exception as e:
self.logger.error(f"Error generating strategy signal: {e}")
return None
def _evaluate_opportunity(
self,
signal: StrategySignal,
confidence: float,
market_conditions: MarketConditions
) -> bool:
"""Evaluate if opportunity meets execution criteria."""
# Check confidence threshold
if confidence < self.min_confidence_threshold:
return False
# Check signal strength
if signal.strength < 0.5:
return False
return True
async def _can_execute_order(
self,
side: OrderSide,
quantity: Decimal,
price: Decimal
) -> Tuple[bool, str]:
"""Pre-flight check if order can be executed."""
# 1. Check minimum order value ($1 USD)
order_value = quantity * price
if order_value < Decimal('1.0'):
return False, f"Order value ${order_value:.2f} below minimum $1.00"
# 2. For BUY: check sufficient balance (include fee buffer)
if side == OrderSide.BUY:
if self.paper_trading:
balance = self.trading_engine.paper_trading.get_balance()
else:
balance = Decimal('0') # Live trading balance check handled elsewhere
# Add 1% fee buffer
fee_estimate = order_value * Decimal('0.01')
total_required = order_value + fee_estimate
if balance < total_required:
return False, f"Insufficient funds: ${balance:.2f} < ${total_required:.2f}"
# 3. For SELL: check position exists with sufficient quantity
if side == OrderSide.SELL:
if self.paper_trading:
positions = self.trading_engine.paper_trading.get_positions()
position = next((p for p in positions if p.symbol == self.symbol), None)
if not position:
return False, f"No position to sell for {self.symbol}"
if position.quantity < quantity:
return False, f"Insufficient position: {position.quantity} < {quantity}"
return True, "OK"
def _determine_order_type_and_price(
self,
side: OrderSide,
signal_strength: float,
current_price: Decimal,
is_stop_loss: bool = False
) -> Tuple[OrderType, Optional[Decimal]]:
"""Determine optimal order type and price for execution."""
# Strong signals (>80% strength) or stop-loss use MARKET for speed
if signal_strength > 0.8 or is_stop_loss:
reason = "stop-loss" if is_stop_loss else f"high strength ({signal_strength:.2f})"
self.logger.debug(f"Using MARKET order ({reason})")
return OrderType.MARKET, None
# Normal signals use LIMIT for better price
# BUY: bid slightly below market (0.1% discount)
# SELL (take-profit): ask slightly above market (0.1% premium)
if side == OrderSide.BUY:
limit_price = current_price * Decimal('0.999') # 0.1% below
else:
limit_price = current_price * Decimal('1.001') # 0.1% above
# Round to reasonable precision (2 decimal places for USD pairs)
limit_price = limit_price.quantize(Decimal('0.01'))
return OrderType.LIMIT, limit_price
async def _execute_opportunity(
self,
strategy: BaseStrategy,
signal: StrategySignal,
market_conditions: MarketConditions
):
"""Execute trade opportunity."""
try:
current_price = signal.price or Decimal(str(self._ohlcv_data['close'].iloc[-1]))
# Calculate position size
balance = self.trading_engine.paper_trading.get_balance() if self.paper_trading else Decimal(0)
quantity = strategy.calculate_position_size(signal, balance, current_price)
if quantity <= 0:
self.logger.warning("Invalid position size calculated")
return
# Determine order side
if signal.signal_type == SignalType.BUY:
side = OrderSide.BUY
elif signal.signal_type == SignalType.SELL:
side = OrderSide.SELL
else:
self.logger.warning(f"Unsupported signal type: {signal.signal_type}")
return
# Pre-flight validation - skip order if it would fail
can_execute, reason = await self._can_execute_order(side, quantity, current_price)
if not can_execute:
self.logger.info(f"Skipping order ({strategy.name}): {reason}")
return
# Determine optimal order type
is_stop_loss = False
if side == OrderSide.SELL and self.paper_trading:
positions = self.trading_engine.paper_trading.get_positions()
position = next((p for p in positions if p.symbol == self.symbol), None)
if position:
is_stop_loss = current_price < position.entry_price
order_type, limit_price = self._determine_order_type_and_price(
side=side,
signal_strength=signal.strength,
current_price=current_price,
is_stop_loss=is_stop_loss
)
# Execute order with smart order type
order = await self.trading_engine.execute_order(
exchange_id=self.exchange_id,
strategy_id=None, # Intelligent autopilot doesn't use strategy_id
symbol=self.symbol,
side=side,
order_type=order_type,
quantity=quantity,
price=limit_price, # None for MARKET, price for LIMIT
paper_trading=self.paper_trading
)
order_type_str = "LIMIT" if order_type == OrderType.LIMIT else "MARKET"
if order:
self.logger.info(
f"Executed {side.value} {order_type_str} order: {quantity} {self.symbol} "
f"at {limit_price or current_price} (strategy: {strategy.name})"
)
# Increment Redis trade counter
await self._increment_trades_today()
# Record trade for ML training (async, don't wait)
asyncio.create_task(self._record_trade_for_learning(
strategy.name,
market_conditions,
order
))
else:
self.logger.warning("Order execution failed")
except Exception as e:
self.logger.error(f"Error executing opportunity: {e}")
async def _record_trade_for_learning(
self,
strategy_name: str,
market_conditions: MarketConditions,
order: Any
):
"""Record trade for ML learning (called after trade completes)."""
try:
# Wait a bit for order to complete
await asyncio.sleep(5)
trade_result = {
'return_pct': 0.0,
'sharpe_ratio': 0.0,
'win_rate': 0.0,
'max_drawdown': 0.0,
'trade_count': 1
}
await self.performance_tracker.record_trade(
strategy_name=strategy_name,
market_conditions=market_conditions,
trade_result=trade_result
)
except Exception as e:
self.logger.error(f"Error recording trade for learning: {e}")
async def _get_trades_today(self) -> int:
"""Get number of trades executed today from Redis."""
try:
redis_client = self.redis.get_client()
count = await redis_client.get(self._trades_today_key)
return int(count) if count else 0
except Exception as e:
self.logger.error(f"Error getting trade count from Redis: {e}")
return 0
async def _increment_trades_today(self):
"""Increment daily trade counter in Redis."""
try:
redis_client = self.redis.get_client()
key = self._trades_today_key
# Increment
await redis_client.incr(key)
# Set expiry to 24 hours (if new key)
if await redis_client.ttl(key) == -1:
await redis_client.expire(key, 86400)
except Exception as e:
self.logger.error(f"Error incrementing trade count in Redis: {e}")
def stop(self):
"""Stop the intelligent autopilot (signals distributed stop)."""
# Synchronous method to trigger stop
# Since we need to delete Redis key async, we create a task
asyncio.create_task(self._stop_async())
async def _stop_async(self):
"""Async stop implementation."""
self._local_running = False
redis_client = self.redis.get_client()
await redis_client.delete(self.key_running)
self.logger.info("Intelligent autopilot stopping... (signal sent)")
@property
def is_running(self) -> bool:
"""Check if autopilot is running (Local check)."""
return self._local_running
def get_status(self) -> Dict[str, Any]:
"""Get current autopilot status (Synchronous wrapper)."""
return {
'symbol': self.symbol,
'timeframe': self.timeframe,
'running': self._local_running,
'selected_strategy': self._selected_strategy,
'trades_today': 0, # Fetch async if needed via get_distributed_status
'max_trades_per_day': self.max_trades_per_day,
'min_confidence_threshold': self.min_confidence_threshold,
'enable_auto_execution': self.enable_auto_execution,
'last_analysis': self._last_analysis,
'model_info': self.strategy_selector.get_model_info()
}
async def get_distributed_status(self) -> Dict[str, Any]:
"""Get full distributed status (Async)."""
try:
redis_client = self.redis.get_client()
is_running_distributed = await redis_client.exists(self.key_running)
trades_today = await self._get_trades_today()
status = self.get_status()
status['running'] = bool(is_running_distributed)
status['trades_today'] = trades_today
return status
except Exception:
# Fallback to local status if Redis fails
return self.get_status()
# Global instances (factory cache)
_intelligent_autopilots: Dict[str, IntelligentAutopilot] = {}
def get_intelligent_autopilot(
symbol: str,
exchange_id: int = 1,
timeframe: str = "1h",
**kwargs
) -> IntelligentAutopilot:
"""Get or create intelligent autopilot instance."""
key = f"{symbol}:{exchange_id}:{timeframe}"
if key not in _intelligent_autopilots:
_intelligent_autopilots[key] = IntelligentAutopilot(
symbol=symbol,
exchange_id=exchange_id,
timeframe=timeframe,
**kwargs
)
logger.info(f"Created new IntelligentAutopilot instance for {key}")
return _intelligent_autopilots[key]
def stop_all_autopilots():
"""Stop all running IntelligentAutoPilot instances."""
for key, autopilot in _intelligent_autopilots.items():
if autopilot.is_running:
autopilot.stop()
logger.info(f"Stopped IntelligentAutopilot for {key}")

View File

@@ -0,0 +1,485 @@
"""Market condition analyzer for intelligent autopilot.
Analyzes real-time market conditions and extracts features for ML model.
"""
from decimal import Decimal
from typing import Dict, Any, Optional
from enum import Enum
from dataclasses import dataclass
import pandas as pd
import numpy as np
from src.core.logger import get_logger
from src.data.indicators import get_indicators
logger = get_logger(__name__)
class MarketRegime(str, Enum):
"""Market regime classification."""
TRENDING_UP = "trending_up"
TRENDING_DOWN = "trending_down"
RANGING = "ranging"
HIGH_VOLATILITY = "high_volatility"
LOW_VOLATILITY = "low_volatility"
BREAKOUT = "breakout"
REVERSAL = "reversal"
UNKNOWN = "unknown"
@dataclass
class MarketConditions:
"""Market conditions snapshot."""
symbol: str
timeframe: str
regime: MarketRegime
features: Dict[str, float]
timestamp: pd.Timestamp
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage."""
return {
'symbol': self.symbol,
'timeframe': self.timeframe,
'regime': self.regime.value,
'features': self.features,
'timestamp': self.timestamp.isoformat()
}
class MarketAnalyzer:
"""Analyzes market conditions and extracts features for ML."""
def __init__(self):
"""Initialize market analyzer."""
self.logger = get_logger(__name__)
self.indicators = get_indicators()
def analyze_current_conditions(
self,
symbol: str,
timeframe: str,
ohlcv_data: pd.DataFrame
) -> MarketConditions:
"""Analyze current market conditions.
Args:
symbol: Trading symbol
timeframe: Timeframe
ohlcv_data: OHLCV DataFrame with columns: open, high, low, close, volume
Returns:
MarketConditions object
"""
if len(ohlcv_data) < 50:
self.logger.warning(f"Insufficient data for analysis: {len(ohlcv_data)} candles")
return MarketConditions(
symbol=symbol,
timeframe=timeframe,
regime=MarketRegime.UNKNOWN,
features={},
timestamp=pd.Timestamp.now()
)
# Extract features
features = self.extract_features(ohlcv_data)
# Classify market regime
regime = self.classify_market_regime(features, ohlcv_data)
return MarketConditions(
symbol=symbol,
timeframe=timeframe,
regime=regime,
features=features,
timestamp=pd.Timestamp.now()
)
def extract_features(self, df: pd.DataFrame) -> Dict[str, float]:
"""Extract comprehensive market features.
Args:
df: OHLCV DataFrame
Returns:
Dictionary of feature names to values
"""
features = {}
try:
close = df['close']
high = df['high']
low = df['low']
open_price = df['open']
volume = df['volume']
# Trend Features
features['sma_20'] = float(close.rolling(20).mean().iloc[-1])
features['sma_50'] = float(close.rolling(50).mean().iloc[-1]) if len(close) >= 50 else features['sma_20']
features['ema_12'] = float(self.indicators.ema(close, period=12).iloc[-1])
features['ema_26'] = float(self.indicators.ema(close, period=26).iloc[-1]) if len(close) >= 26 else features['ema_12']
# Price position relative to MAs
current_price = float(close.iloc[-1])
features['price_vs_sma20'] = (current_price - features['sma_20']) / features['sma_20'] if features['sma_20'] > 0 else 0.0
features['price_vs_sma50'] = (current_price - features['sma_50']) / features['sma_50'] if features['sma_50'] > 0 else 0.0
features['sma20_vs_sma50'] = (features['sma_20'] - features['sma_50']) / features['sma_50'] if features['sma_50'] > 0 else 0.0
# Trend strength (ADX)
if len(df) >= 14:
adx = self.indicators.adx(high, low, close, period=14)
features['adx'] = float(adx.iloc[-1]) if not pd.isna(adx.iloc[-1]) else 0.0
else:
features['adx'] = 0.0
# Momentum Features
if len(close) >= 14:
rsi = self.indicators.rsi(close, period=14)
features['rsi'] = float(rsi.iloc[-1]) if not pd.isna(rsi.iloc[-1]) else 50.0
else:
features['rsi'] = 50.0
# MACD
if len(close) >= 26:
macd_result = self.indicators.macd(close, fast=12, slow=26, signal=9)
features['macd'] = float(macd_result['macd'].iloc[-1]) if not pd.isna(macd_result['macd'].iloc[-1]) else 0.0
features['macd_signal'] = float(macd_result['signal'].iloc[-1]) if not pd.isna(macd_result['signal'].iloc[-1]) else 0.0
features['macd_histogram'] = float(macd_result['histogram'].iloc[-1]) if not pd.isna(macd_result['histogram'].iloc[-1]) else 0.0
else:
features['macd'] = 0.0
features['macd_signal'] = 0.0
features['macd_histogram'] = 0.0
# Volatility Features
if len(df) >= 14:
atr = self.indicators.atr(high, low, close, period=14)
features['atr'] = float(atr.iloc[-1]) if not pd.isna(atr.iloc[-1]) else 0.0
features['atr_percent'] = (features['atr'] / current_price) * 100 if current_price > 0 else 0.0
else:
features['atr'] = 0.0
features['atr_percent'] = 0.0
# Bollinger Bands
if len(close) >= 20:
bb = self.indicators.bollinger_bands(close, period=20, std_dev=2)
features['bb_upper'] = float(bb['upper'].iloc[-1]) if not pd.isna(bb['upper'].iloc[-1]) else current_price
features['bb_lower'] = float(bb['lower'].iloc[-1]) if not pd.isna(bb['lower'].iloc[-1]) else current_price
features['bb_middle'] = float(bb['middle'].iloc[-1]) if not pd.isna(bb['middle'].iloc[-1]) else current_price
features['bb_width'] = (features['bb_upper'] - features['bb_lower']) / features['bb_middle'] if features['bb_middle'] > 0 else 0.0
features['bb_position'] = (current_price - features['bb_lower']) / (features['bb_upper'] - features['bb_lower']) if (features['bb_upper'] - features['bb_lower']) > 0 else 0.5
else:
features['bb_upper'] = current_price
features['bb_lower'] = current_price
features['bb_middle'] = current_price
features['bb_width'] = 0.0
features['bb_position'] = 0.5
# Volume Features
if len(volume) >= 20:
volume_ma = volume.rolling(20).mean()
features['volume_ratio'] = float(volume.iloc[-1] / volume_ma.iloc[-1]) if volume_ma.iloc[-1] > 0 else 1.0
features['volume_trend'] = float((volume.iloc[-5:].mean() - volume.iloc[-20:-5].mean()) / volume.iloc[-20:-5].mean()) if volume.iloc[-20:-5].mean() > 0 else 0.0
else:
features['volume_ratio'] = 1.0
features['volume_trend'] = 0.0
# OBV (On-Balance Volume)
if len(close) >= 10:
obv = self.indicators.obv(close, volume)
features['obv_trend'] = float((obv.iloc[-1] - obv.iloc[-10]) / abs(obv.iloc[-10])) if obv.iloc[-10] != 0 else 0.0
else:
features['obv_trend'] = 0.0
# Price Action Features
features['price_change_1'] = float((close.iloc[-1] - close.iloc[-2]) / close.iloc[-2]) if len(close) >= 2 and close.iloc[-2] > 0 else 0.0
features['price_change_5'] = float((close.iloc[-1] - close.iloc[-6]) / close.iloc[-6]) if len(close) >= 6 and close.iloc[-6] > 0 else 0.0
features['price_change_20'] = float((close.iloc[-1] - close.iloc[-21]) / close.iloc[-21]) if len(close) >= 21 and close.iloc[-21] > 0 else 0.0
# High-Low Range
features['high_low_range'] = float((high.iloc[-1] - low.iloc[-1]) / current_price) if current_price > 0 else 0.0
features['high_low_range_5'] = float((high.iloc[-5:].max() - low.iloc[-5:].min()) / current_price) if current_price > 0 else 0.0
# Candlestick Patterns (simplified)
if len(df) >= 3:
features['is_bullish_candle'] = 1.0 if close.iloc[-1] > open_price.iloc[-1] else 0.0
features['candle_body_size'] = float(abs(close.iloc[-1] - open_price.iloc[-1]) / current_price) if current_price > 0 else 0.0
features['candle_wick_upper'] = float((high.iloc[-1] - max(close.iloc[-1], open_price.iloc[-1])) / current_price) if current_price > 0 else 0.0
features['candle_wick_lower'] = float((min(close.iloc[-1], open_price.iloc[-1]) - low.iloc[-1]) / current_price) if current_price > 0 else 0.0
else:
features['is_bullish_candle'] = 0.5
features['candle_body_size'] = 0.0
features['candle_wick_upper'] = 0.0
features['candle_wick_lower'] = 0.0
# Support/Resistance Proximity (simplified - using recent highs/lows)
if len(df) >= 20:
recent_high = float(high.iloc[-20:].max())
recent_low = float(low.iloc[-20:].min())
features['distance_to_resistance'] = (recent_high - current_price) / current_price if current_price > 0 else 0.0
features['distance_to_support'] = (current_price - recent_low) / current_price if current_price > 0 else 0.0
else:
features['distance_to_resistance'] = 0.0
features['distance_to_support'] = 0.0
# ==========================================
# NEW FEATURES FOR IMPROVED ML ACCURACY
# ==========================================
# Volatility Percentile - current ATR vs 30-day rolling ATR
if len(df) >= 30:
atr_series = self.indicators.atr(high, low, close, period=14)
atr_30_mean = float(atr_series.iloc[-30:].mean()) if len(atr_series) >= 30 else features.get('atr', 0)
features['volatility_percentile'] = features.get('atr', 0) / atr_30_mean if atr_30_mean > 0 else 1.0
else:
features['volatility_percentile'] = 1.0
# Momentum Delta - 3-period vs 10-period price change (acceleration)
if len(close) >= 11:
momentum_3 = float((close.iloc[-1] - close.iloc[-4]) / close.iloc[-4]) if close.iloc[-4] > 0 else 0.0
momentum_10 = float((close.iloc[-1] - close.iloc[-11]) / close.iloc[-11]) if close.iloc[-11] > 0 else 0.0
features['momentum_3'] = momentum_3
features['momentum_10'] = momentum_10
features['momentum_delta'] = momentum_3 - momentum_10 # Positive = accelerating
else:
features['momentum_3'] = 0.0
features['momentum_10'] = 0.0
features['momentum_delta'] = 0.0
# Trend Strength Change - ADX rate of change
if len(df) >= 20:
adx_series = self.indicators.adx(high, low, close, period=14)
if len(adx_series) >= 5:
adx_current = float(adx_series.iloc[-1]) if not pd.isna(adx_series.iloc[-1]) else 0.0
adx_prev = float(adx_series.iloc[-5]) if not pd.isna(adx_series.iloc[-5]) else adx_current
features['adx_change'] = adx_current - adx_prev # Positive = strengthening trend
else:
features['adx_change'] = 0.0
else:
features['adx_change'] = 0.0
# Volume Climax Detection - extreme volume with reversal candle
if len(df) >= 20:
volume_std = float(volume.iloc[-20:].std())
volume_mean = float(volume.iloc[-20:].mean())
current_volume = float(volume.iloc[-1])
z_score = (current_volume - volume_mean) / volume_std if volume_std > 0 else 0.0
is_reversal = (close.iloc[-1] < open_price.iloc[-1]) != (close.iloc[-2] < open_price.iloc[-2])
features['volume_z_score'] = z_score
features['volume_climax'] = 1.0 if (z_score > 2.0 and is_reversal) else 0.0
else:
features['volume_z_score'] = 0.0
features['volume_climax'] = 0.0
# RSI Extremes - overbought/oversold signals
rsi_value = features.get('rsi', 50.0)
features['rsi_oversold'] = 1.0 if rsi_value < 30 else 0.0
features['rsi_overbought'] = 1.0 if rsi_value > 70 else 0.0
features['rsi_extreme'] = 1.0 if (rsi_value < 25 or rsi_value > 75) else 0.0
# Price Position in Range - where is price in N-day range
if len(df) >= 20:
range_high = float(high.iloc[-20:].max())
range_low = float(low.iloc[-20:].min())
range_size = range_high - range_low
features['price_in_range'] = (current_price - range_low) / range_size if range_size > 0 else 0.5
else:
features['price_in_range'] = 0.5
# Trend Alignment - short term vs medium term
sma_10 = float(close.rolling(10).mean().iloc[-1]) if len(close) >= 10 else current_price
features['sma_10'] = sma_10
features['trend_aligned'] = 1.0 if (sma_10 > features.get('sma_20', sma_10) > features.get('sma_50', sma_10)) or \
(sma_10 < features.get('sma_20', sma_10) < features.get('sma_50', sma_10)) else 0.0
# Consecutive Candles - streak of bullish/bearish candles
if len(df) >= 5:
bullish_streak = 0
bearish_streak = 0
for i in range(1, 6):
if close.iloc[-i] > open_price.iloc[-i]:
if bearish_streak == 0:
bullish_streak += 1
else:
break
else:
if bullish_streak == 0:
bearish_streak += 1
else:
break
features['bullish_streak'] = float(bullish_streak)
features['bearish_streak'] = float(bearish_streak)
else:
features['bullish_streak'] = 0.0
features['bearish_streak'] = 0.0
# MACD Crossover proximity
macd_val = features.get('macd', 0.0)
macd_sig = features.get('macd_signal', 0.0)
if macd_sig != 0:
features['macd_signal_ratio'] = macd_val / abs(macd_sig) if macd_sig != 0 else 0.0
else:
features['macd_signal_ratio'] = 0.0
# Bollinger Band squeeze detection
bb_width = features.get('bb_width', 0.0)
if len(df) >= 30:
bb_history = []
for i in range(30):
if len(close) > i + 20:
bb_temp = self.indicators.bollinger_bands(close.iloc[:-i-1] if i > 0 else close, period=20, std_dev=2)
if 'upper' in bb_temp and 'lower' in bb_temp:
width = (float(bb_temp['upper'].iloc[-1]) - float(bb_temp['lower'].iloc[-1])) / float(bb_temp['middle'].iloc[-1]) if not pd.isna(bb_temp['middle'].iloc[-1]) and bb_temp['middle'].iloc[-1] > 0 else 0
bb_history.append(width)
if bb_history:
avg_width = np.mean(bb_history)
features['bb_squeeze'] = 1.0 if bb_width < avg_width * 0.7 else 0.0
else:
features['bb_squeeze'] = 0.0
else:
features['bb_squeeze'] = 0.0
# ==========================================
# SENTIMENT FEATURES
# ==========================================
# Fear & Greed Index (0-100, 0=extreme fear, 100=extreme greed)
# Fetch from API or use cached value
fear_greed = self._get_fear_greed_index()
features['fear_greed_index'] = fear_greed
# Normalize to -1 to 1 range (fear negative, greed positive)
features['fear_greed_normalized'] = (fear_greed - 50) / 50
# Binary indicators
features['is_extreme_fear'] = 1.0 if fear_greed < 25 else 0.0
features['is_extreme_greed'] = 1.0 if fear_greed > 75 else 0.0
except Exception as e:
self.logger.error(f"Error extracting features: {e}")
# Return empty features on error
return {}
return features
def _get_fear_greed_index(self) -> float:
"""Fetch the current Fear & Greed Index.
Uses alternative.me API for Bitcoin Fear & Greed Index.
Falls back to neutral (50) if unavailable.
Returns:
Fear & Greed value 0-100
"""
try:
import requests
# Check cache
cache_key = '_fear_greed_cache'
cache_time_key = '_fear_greed_cache_time'
if hasattr(self, cache_key) and hasattr(self, cache_time_key):
# Cache for 1 hour
cache_age = (pd.Timestamp.now() - getattr(self, cache_time_key)).total_seconds()
if cache_age < 3600:
return getattr(self, cache_key)
# Fetch from API
response = requests.get(
"https://api.alternative.me/fng/?limit=1",
timeout=5
)
if response.status_code == 200:
data = response.json()
if data.get('data') and len(data['data']) > 0:
value = float(data['data'][0].get('value', 50))
# Cache the value
setattr(self, cache_key, value)
setattr(self, cache_time_key, pd.Timestamp.now())
return value
except Exception as e:
self.logger.debug(f"Could not fetch Fear & Greed Index: {e}")
# Return neutral if unavailable
return 50.0
def classify_market_regime(
self,
features: Dict[str, float],
df: pd.DataFrame
) -> MarketRegime:
"""Classify market regime based on features.
Args:
features: Extracted features
df: OHLCV DataFrame
Returns:
MarketRegime classification
"""
if not features:
return MarketRegime.UNKNOWN
try:
close = df['close']
current_price = float(close.iloc[-1])
# Check for trending conditions
adx = features.get('adx', 0.0)
price_vs_sma20 = features.get('price_vs_sma20', 0.0)
sma20_vs_sma50 = features.get('sma20_vs_sma50', 0.0)
# Strong uptrend
if adx > 25 and price_vs_sma20 > 0.01 and sma20_vs_sma50 > 0.01:
return MarketRegime.TRENDING_UP
# Strong downtrend
if adx > 25 and price_vs_sma20 < -0.01 and sma20_vs_sma50 < -0.01:
return MarketRegime.TRENDING_DOWN
# Check volatility
atr_percent = features.get('atr_percent', 0.0)
bb_width = features.get('bb_width', 0.0)
if atr_percent > 3.0 or bb_width > 0.05:
return MarketRegime.HIGH_VOLATILITY
if atr_percent < 1.0 and bb_width < 0.02:
return MarketRegime.LOW_VOLATILITY
# Check for ranging (low ADX, price oscillating)
if adx < 20:
# Check if price is oscillating around MAs
price_change_20 = features.get('price_change_20', 0.0)
if abs(price_change_20) < 0.05: # Less than 5% change over 20 periods
return MarketRegime.RANGING
# Check for breakout
bb_position = features.get('bb_position', 0.5)
volume_ratio = features.get('volume_ratio', 1.0)
if (bb_position > 0.95 or bb_position < 0.05) and volume_ratio > 1.5:
return MarketRegime.BREAKOUT
# Check for reversal signals
rsi = features.get('rsi', 50.0)
macd_histogram = features.get('macd_histogram', 0.0)
if (rsi > 70 and macd_histogram < 0) or (rsi < 30 and macd_histogram > 0):
return MarketRegime.REVERSAL
# Default to ranging if unclear
return MarketRegime.RANGING
except Exception as e:
self.logger.error(f"Error classifying market regime: {e}")
return MarketRegime.UNKNOWN
# Global instance
_market_analyzer: Optional[MarketAnalyzer] = None
def get_market_analyzer() -> MarketAnalyzer:
"""Get global market analyzer instance."""
global _market_analyzer
if _market_analyzer is None:
_market_analyzer = MarketAnalyzer()
return _market_analyzer

519
src/autopilot/models.py Normal file
View File

@@ -0,0 +1,519 @@
"""ML model definitions and persistence for strategy selection.
Implements automated model selection with:
- XGBoost
- LightGBM
- RandomForest
- MLP Neural Network
Uses TimeSeriesSplit cross-validation and walk-forward validation.
"""
import os
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import TimeSeriesSplit, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import joblib
import warnings
from src.core.logger import get_logger
from src.core.config import get_config
logger = get_logger(__name__)
# Try importing optional dependencies
try:
from xgboost import XGBClassifier
HAS_XGBOOST = True
except ImportError:
HAS_XGBOOST = False
logger.warning("XGBoost not installed. Run: pip install xgboost")
try:
from lightgbm import LGBMClassifier
HAS_LIGHTGBM = True
except ImportError:
HAS_LIGHTGBM = False
logger.warning("LightGBM not installed. Run: pip install lightgbm")
class ModelSelector:
"""Automated model selection for strategy prediction.
Trains multiple models, compares via cross-validation, picks the best.
"""
def __init__(self):
"""Initialize model selector."""
self.logger = get_logger(__name__)
self.config = get_config()
self.scaler = StandardScaler()
self.label_encoder = LabelEncoder()
self.feature_names: List[str] = []
self.strategy_names: List[str] = []
self.is_trained = False
self.training_metadata: Dict[str, Any] = {}
# Best model after training
self.best_model = None
self.best_model_name: str = ""
# Model storage path
self.model_dir = Path.home() / ".local" / "share" / "crypto_trader" / "models"
self.model_dir.mkdir(parents=True, exist_ok=True)
def _get_candidate_models(self) -> Dict[str, Any]:
"""Get dictionary of candidate models to compare.
Returns:
Dictionary of model name -> model instance
"""
models = {}
# RandomForest - stable, interpretable baseline
models['random_forest'] = RandomForestClassifier(
n_estimators=200,
max_depth=15,
min_samples_split=10,
min_samples_leaf=4,
class_weight='balanced',
random_state=42,
n_jobs=-1
)
# XGBoost - typically best for tabular data
if HAS_XGBOOST:
models['xgboost'] = XGBClassifier(
n_estimators=200,
max_depth=8,
learning_rate=0.1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
n_jobs=-1,
use_label_encoder=False,
eval_metric='mlogloss'
)
# LightGBM - fast, handles noisy data well
if HAS_LIGHTGBM:
models['lightgbm'] = LGBMClassifier(
n_estimators=200,
max_depth=10,
learning_rate=0.1,
subsample=0.8,
colsample_bytree=0.8,
class_weight='balanced',
random_state=42,
n_jobs=-1,
verbose=-1
)
# MLP Neural Network - high ceiling with enough data
models['mlp'] = MLPClassifier(
hidden_layer_sizes=(128, 64, 32),
activation='relu',
solver='adam',
alpha=0.01, # L2 regularization
batch_size='auto',
learning_rate='adaptive',
learning_rate_init=0.001,
max_iter=500,
early_stopping=True,
validation_fraction=0.1,
n_iter_no_change=20,
random_state=42
)
return models
def train(
self,
X: pd.DataFrame,
y: np.ndarray,
strategy_names: List[str],
use_ensemble: bool = False,
n_splits: int = 5,
training_symbols: List[str] = None
) -> Dict[str, Any]:
"""Train and select best model via cross-validation.
Args:
X: Feature matrix (market conditions)
y: Target values (strategy names)
strategy_names: List of strategy names
use_ensemble: If True, use voting ensemble of top models instead of single best
n_splits: Number of cross-validation splits
training_symbols: List of symbols used for training
Returns:
Training metrics dictionary
"""
self.strategy_names = strategy_names
self.feature_names = list(X.columns)
self.training_symbols = training_symbols or []
# Encode labels
y_encoded = self.label_encoder.fit_transform(y)
# Scale features
X_scaled = self.scaler.fit_transform(X)
# Time series cross-validation
tscv = TimeSeriesSplit(n_splits=n_splits)
# Get candidate models
models = self._get_candidate_models()
self.logger.info(f"Training {len(models)} candidate models with {len(X)} samples...")
# Evaluate each model
model_scores = {}
model_cv_results = {}
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for name, model in models.items():
try:
self.logger.info(f"Evaluating {name}...")
scores = cross_val_score(
model, X_scaled, y_encoded,
cv=tscv, scoring='accuracy', n_jobs=-1
)
# Filter out nan scores before calculating mean
valid_scores = [s for s in scores if not np.isnan(s)]
if valid_scores:
mean_score = np.mean(valid_scores)
std_score = np.std(valid_scores) if len(valid_scores) > 1 else 0.0
model_scores[name] = mean_score
model_cv_results[name] = {
'mean': float(mean_score),
'std': float(std_score),
'scores': [float(s) if not np.isnan(s) else 0.0 for s in scores]
}
self.logger.info(f" {name}: {mean_score:.3f} (+/- {std_score:.3f})")
else:
self.logger.warning(f" {name}: All scores were nan, skipping")
model_scores[name] = 0.0
except Exception as e:
self.logger.warning(f" {name} failed: {e}")
model_scores[name] = 0.0
# Filter out models with zero scores before selecting best
valid_model_scores = {k: v for k, v in model_scores.items() if v > 0}
# Select best model or create ensemble
if not valid_model_scores:
self.logger.warning("No models trained successfully, using RandomForest as fallback")
self.best_model_name = "random_forest"
self.best_model = models["random_forest"]
elif use_ensemble and len(valid_model_scores) >= 2:
# Use top 3 models in voting ensemble, but only if they have at least 10% accuracy
MIN_ENSEMBLE_ACCURACY = 0.10
good_models = {k: v for k, v in valid_model_scores.items() if v >= MIN_ENSEMBLE_ACCURACY}
if len(good_models) >= 2:
top_models = sorted(good_models.items(), key=lambda x: x[1], reverse=True)[:3]
estimators = [(name, models[name]) for name, _ in top_models]
self.best_model = VotingClassifier(estimators=estimators, voting='soft')
self.best_model_name = "ensemble"
self.logger.info(f"Using ensemble of: {[n for n, _ in top_models]} (excluding models below {MIN_ENSEMBLE_ACCURACY:.0%} accuracy)")
else:
# Not enough good models for ensemble, use single best
self.best_model_name = max(valid_model_scores, key=valid_model_scores.get)
self.best_model = models[self.best_model_name]
self.logger.info(f"Not enough models for ensemble (min {MIN_ENSEMBLE_ACCURACY:.0%}), using best: {self.best_model_name}")
else:
# Use single best model
self.best_model_name = max(valid_model_scores, key=valid_model_scores.get)
self.best_model = models[self.best_model_name]
self.logger.info(f"Best model: {self.best_model_name} ({valid_model_scores[self.best_model_name]:.3f})")
# Train best model on full dataset
self.best_model.fit(X_scaled, y_encoded)
# Calculate final metrics with walk-forward validation
train_pred = self.best_model.predict(X_scaled)
train_acc = accuracy_score(y_encoded, train_pred)
# Top-3 accuracy if we have probabilities
top3_acc = None
try:
if hasattr(self.best_model, 'predict_proba'):
proba = self.best_model.predict_proba(X_scaled)
k = min(3, len(self.strategy_names))
top3_acc = float(top_k_accuracy_score(y_encoded, proba, k=k))
except Exception:
pass
# Run walk-forward validation
wf_scores = self._walk_forward_validation(X_scaled, y_encoded)
# Calculate estimated CV accuracy for ensemble (average of component models)
if self.best_model_name == "ensemble" and len(valid_model_scores) > 0:
# Use average of models in the ensemble for CV estimate
ensemble_cv_acc = np.mean([v for v in valid_model_scores.values() if v >= 0.10])
else:
ensemble_cv_acc = model_scores.get(self.best_model_name, train_acc)
metrics = {
'train_accuracy': float(train_acc),
'test_accuracy': float(ensemble_cv_acc),
'cv_mean_accuracy': float(ensemble_cv_acc),
'walk_forward_accuracy': float(np.mean(wf_scores)) if wf_scores else None,
'top3_accuracy': top3_acc,
'n_samples': len(X),
'n_features': len(self.feature_names),
'n_strategies': len(strategy_names),
'best_model': self.best_model_name,
'all_model_scores': model_cv_results
}
self.is_trained = True
self.training_metadata = {
'trained_at': datetime.utcnow().isoformat(),
'metrics': metrics,
'model_type': 'classifier',
'best_model_name': self.best_model_name,
'training_symbols': self.training_symbols if hasattr(self, 'training_symbols') else []
}
self.logger.info(f"Training complete. Best: {self.best_model_name}, "
f"CV accuracy: {metrics['cv_mean_accuracy']:.1%}")
return metrics
def _walk_forward_validation(
self,
X: np.ndarray,
y: np.ndarray,
train_ratio: float = 0.7,
step_size: int = None
) -> List[float]:
"""Perform walk-forward validation.
Train on months 1-6, test on 7. Train on 1-7, test on 8. Etc.
Args:
X: Scaled feature matrix
y: Encoded labels
train_ratio: Initial training set ratio
step_size: Step size for rolling window
Returns:
List of accuracy scores for each step
"""
n_samples = len(X)
if n_samples < 100:
self.logger.warning("Insufficient data for walk-forward validation")
return []
initial_train_size = int(n_samples * train_ratio)
if step_size is None:
step_size = max(10, n_samples // 20)
scores = []
models = self._get_candidate_models()
model = models.get(self.best_model_name, list(models.values())[0])
for i in range(initial_train_size, n_samples - step_size, step_size):
X_train = X[:i]
y_train = y[:i]
X_test = X[i:i + step_size]
y_test = y[i:i + step_size]
try:
model_clone = type(model)(**model.get_params())
model_clone.fit(X_train, y_train)
pred = model_clone.predict(X_test)
scores.append(accuracy_score(y_test, pred))
except Exception as e:
self.logger.warning(f"Walk-forward step failed: {e}")
if scores:
self.logger.info(f"Walk-forward validation: {np.mean(scores):.3f} "
f"(+/- {np.std(scores):.3f}) over {len(scores)} steps")
return scores
def predict(
self,
features: Dict[str, float]
) -> Tuple[str, float, Dict[str, float]]:
"""Predict best strategy for given market conditions.
Args:
features: Market condition features
Returns:
Tuple of (best_strategy_name, confidence_score, all_predictions)
"""
if not self.is_trained:
raise ValueError("Model not trained. Call train() first.")
# Convert features to DataFrame
feature_df = pd.DataFrame([features])
# Ensure all required features are present
for feat in self.feature_names:
if feat not in feature_df.columns:
feature_df[feat] = 0.0
# Select only required features in correct order
X = feature_df[self.feature_names]
# Scale
X_scaled = self.scaler.transform(X)
# Get probabilities for each strategy
if hasattr(self.best_model, 'predict_proba'):
probabilities = self.best_model.predict_proba(X_scaled)[0]
else:
# Fallback for models without predict_proba
pred = self.best_model.predict(X_scaled)[0]
probabilities = np.zeros(len(self.strategy_names))
probabilities[pred] = 1.0
# Map probabilities to strategy names
encoded_classes = self.label_encoder.classes_
strategy_probs = {}
for idx, prob in enumerate(probabilities):
if idx < len(encoded_classes):
strategy_name = encoded_classes[idx]
strategy_probs[strategy_name] = float(prob)
# Get best strategy
best_idx = np.argmax(probabilities)
best_strategy = encoded_classes[best_idx]
confidence = float(probabilities[best_idx])
return best_strategy, confidence, strategy_probs
def get_feature_importance(self) -> Dict[str, float]:
"""Get feature importance scores.
Returns:
Dictionary of feature names to importance scores (JSON serializable)
"""
if not self.is_trained or self.best_model is None:
return {}
try:
if hasattr(self.best_model, 'feature_importances_'):
importances = self.best_model.feature_importances_
# Convert numpy types to native Python floats for JSON serialization
return {name: float(val) for name, val in zip(self.feature_names, importances)}
elif hasattr(self.best_model, 'coefs_'):
# For MLP, use absolute mean of first layer weights
importances = np.abs(self.best_model.coefs_[0]).mean(axis=1)
return {name: float(val) for name, val in zip(self.feature_names, importances)}
except Exception as e:
self.logger.warning(f"Could not get feature importance: {e}")
return {}
def save(self, filename: Optional[str] = None) -> str:
"""Save model to disk.
Args:
filename: Optional custom filename
Returns:
Path to saved model
"""
if filename is None:
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"strategy_selector_{self.best_model_name}_{timestamp}.joblib"
filepath = self.model_dir / filename
model_data = {
'model': self.best_model,
'scaler': self.scaler,
'label_encoder': self.label_encoder,
'feature_names': self.feature_names,
'strategy_names': self.strategy_names,
'is_trained': self.is_trained,
'training_metadata': self.training_metadata,
'model_type': 'classifier',
'best_model_name': self.best_model_name
}
joblib.dump(model_data, filepath)
self.logger.info(f"Model saved to {filepath}")
return str(filepath)
def load(self, filepath: str) -> bool:
"""Load model from disk.
Args:
filepath: Path to model file
Returns:
True if loaded successfully
"""
try:
if not os.path.exists(filepath):
self.logger.error(f"Model file not found: {filepath}")
return False
model_data = joblib.load(filepath)
self.best_model = model_data['model']
self.scaler = model_data['scaler']
self.label_encoder = model_data.get('label_encoder', LabelEncoder())
self.feature_names = model_data['feature_names']
self.strategy_names = model_data['strategy_names']
self.is_trained = model_data['is_trained']
self.training_metadata = model_data.get('training_metadata', {})
self.best_model_name = model_data.get('best_model_name', 'unknown')
self.logger.info(f"Model loaded from {filepath}")
return True
except Exception as e:
self.logger.error(f"Failed to load model: {e}")
return False
def get_latest_model_path(self) -> Optional[str]:
"""Get path to latest saved model.
Returns:
Path to latest model or None
"""
model_files = list(self.model_dir.glob("strategy_selector_*.joblib"))
if not model_files:
return None
# Sort by modification time
latest = max(model_files, key=lambda p: p.stat().st_mtime)
return str(latest)
# Backwards-compatible alias
class StrategySelectorModel(ModelSelector):
"""Backwards-compatible alias for ModelSelector.
Now uses automated model selection instead of just RandomForest.
"""
def __init__(self, model_type: str = "classifier"):
"""Initialize model.
Args:
model_type: Model type (only 'classifier' supported now)
"""
super().__init__()
self.model_type = model_type
if model_type != "classifier":
self.logger.warning("Only 'classifier' model type is supported. Using classifier.")

View File

@@ -0,0 +1,351 @@
"""Strategy performance tracker for ML training data collection."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
import pandas as pd
import numpy as np
from src.core.database import get_database, Trade, Strategy
from src.core.logger import get_logger
from .market_analyzer import MarketConditions
logger = get_logger(__name__)
class PerformanceTracker:
"""Tracks strategy performance for ML training."""
def __init__(self):
"""Initialize performance tracker."""
self.db = get_database()
self.logger = get_logger(__name__)
async def record_trade(
self,
strategy_name: str,
market_conditions: MarketConditions,
trade_result: Dict[str, Any]
) -> bool:
"""Record a trade for ML training.
Args:
strategy_name: Name of strategy used
market_conditions: Market conditions at trade time
trade_result: Trade result with performance metrics
Returns:
True if recorded successfully
"""
try:
async with self.db.get_session() as session:
try:
# Store market conditions snapshot
from src.core.database import MarketConditionsSnapshot
snapshot = MarketConditionsSnapshot(
symbol=market_conditions.symbol,
timeframe=market_conditions.timeframe,
regime=market_conditions.regime.value,
features=market_conditions.features,
strategy_name=strategy_name,
timestamp=market_conditions.timestamp
)
session.add(snapshot)
# Store performance record
from src.core.database import StrategyPerformance
performance = StrategyPerformance(
strategy_name=strategy_name,
symbol=market_conditions.symbol,
timeframe=market_conditions.timeframe,
market_regime=market_conditions.regime.value,
return_pct=float(trade_result.get('return_pct', 0.0)),
sharpe_ratio=float(trade_result.get('sharpe_ratio', 0.0)),
win_rate=float(trade_result.get('win_rate', 0.0)),
max_drawdown=float(trade_result.get('max_drawdown', 0.0)),
trade_count=int(trade_result.get('trade_count', 1)),
timestamp=datetime.utcnow()
)
session.add(performance)
await session.commit()
return True
except Exception as e:
await session.rollback()
self.logger.error(f"Failed to record trade: {e}")
return False
except Exception as e:
self.logger.error(f"Error recording trade: {e}")
return False
async def get_performance_history(
self,
strategy_name: Optional[str] = None,
market_regime: Optional[str] = None,
days: int = 30
) -> pd.DataFrame:
"""Get performance history for training.
Args:
strategy_name: Filter by strategy name
market_regime: Filter by market regime
days: Number of days to look back
Returns:
DataFrame with performance history
"""
from sqlalchemy import select
try:
async with self.db.get_session() as session:
from src.core.database import StrategyPerformance, MarketConditionsSnapshot
# Query performance records
stmt = select(StrategyPerformance)
if strategy_name:
stmt = stmt.where(StrategyPerformance.strategy_name == strategy_name)
if market_regime:
stmt = stmt.where(StrategyPerformance.market_regime == market_regime)
cutoff_date = datetime.utcnow() - timedelta(days=days)
stmt = stmt.where(StrategyPerformance.timestamp >= cutoff_date)
# Limit to prevent excessive queries - if we have lots of data, sample it
stmt = stmt.order_by(StrategyPerformance.timestamp.desc()).limit(10000)
result = await session.execute(stmt)
records = result.scalars().all()
if not records:
return pd.DataFrame()
self.logger.info(f"Processing {len(records)} performance records for training data")
# Convert to DataFrame - optimize by batching snapshot queries
data = []
# Batch snapshot lookups to reduce N+1 query problem
snapshot_cache = {}
for record in records:
cache_key = f"{record.strategy_name}:{record.symbol}:{record.timeframe}:{record.timestamp.date()}"
if cache_key not in snapshot_cache:
# Get corresponding market conditions (only once per day per strategy)
snapshot_stmt = select(MarketConditionsSnapshot).filter_by(
strategy_name=record.strategy_name,
symbol=record.symbol,
timeframe=record.timeframe
).where(
MarketConditionsSnapshot.timestamp <= record.timestamp
).order_by(
MarketConditionsSnapshot.timestamp.desc()
).limit(1)
snapshot_result = await session.execute(snapshot_stmt)
snapshot = snapshot_result.scalar_one_or_none()
snapshot_cache[cache_key] = snapshot
else:
snapshot = snapshot_cache[cache_key]
row = {
'strategy_name': record.strategy_name,
'symbol': record.symbol,
'timeframe': record.timeframe,
'market_regime': record.market_regime,
'return_pct': float(record.return_pct),
'sharpe_ratio': float(record.sharpe_ratio),
'win_rate': float(record.win_rate),
'max_drawdown': float(record.max_drawdown),
'trade_count': int(record.trade_count),
'timestamp': record.timestamp
}
# Add market features if available
if snapshot and snapshot.features:
row.update(snapshot.features)
data.append(row)
return pd.DataFrame(data)
except Exception as e:
self.logger.error(f"Error getting performance history: {e}")
return pd.DataFrame()
async def calculate_metrics(
self,
strategy_name: str,
period_days: int = 30
) -> Dict[str, Any]:
"""Calculate performance metrics for a strategy.
Args:
strategy_name: Strategy name
period_days: Period to calculate metrics over
Returns:
Dictionary of performance metrics
"""
from sqlalchemy import select
try:
async with self.db.get_session() as session:
from src.core.database import StrategyPerformance
cutoff_date = datetime.utcnow() - timedelta(days=period_days)
stmt = select(StrategyPerformance).where(
StrategyPerformance.strategy_name == strategy_name,
StrategyPerformance.timestamp >= cutoff_date
)
result = await session.execute(stmt)
records = result.scalars().all()
if not records:
return {
'total_trades': 0,
'win_rate': 0.0,
'avg_return': 0.0,
'total_return': 0.0,
'sharpe_ratio': 0.0,
'max_drawdown': 0.0
}
returns = [float(r.return_pct) for r in records]
wins = [r for r in returns if r > 0]
total_return = sum(returns)
avg_return = np.mean(returns) if returns else 0.0
win_rate = len(wins) / len(returns) if returns else 0.0
# Calculate Sharpe ratio (simplified)
if len(returns) > 1:
sharpe = (avg_return / np.std(returns)) * np.sqrt(252) if np.std(returns) > 0 else 0.0
else:
sharpe = 0.0
# Max drawdown
max_dd = max([float(r.max_drawdown) for r in records]) if records else 0.0
return {
'total_trades': len(records),
'win_rate': float(win_rate),
'avg_return': float(avg_return),
'total_return': float(total_return),
'sharpe_ratio': float(sharpe),
'max_drawdown': float(max_dd)
}
except Exception as e:
self.logger.error(f"Error calculating metrics: {e}")
return {}
async def prepare_training_data(
self,
min_samples_per_strategy: int = 10
) -> Optional[Dict[str, Any]]:
"""Prepare training data for ML model.
Args:
min_samples_per_strategy: Minimum samples required per strategy
Returns:
Dictionary with 'X' (features), 'y' (targets), and 'strategy_names'
"""
try:
# Get all performance history - extended to 365 days for better training
df = await self.get_performance_history(days=365)
if df.empty:
self.logger.warning("No performance history available for training")
return None
# Filter strategies with enough samples
strategy_counts = df['strategy_name'].value_counts()
valid_strategies = strategy_counts[strategy_counts >= min_samples_per_strategy].index.tolist()
if not valid_strategies:
self.logger.warning(f"No strategies with at least {min_samples_per_strategy} samples")
return None
df = df[df['strategy_name'].isin(valid_strategies)]
# Extract features (exclude metadata columns)
feature_cols = [col for col in df.columns if col not in [
'strategy_name', 'symbol', 'timeframe', 'market_regime',
'return_pct', 'sharpe_ratio', 'win_rate', 'max_drawdown',
'trade_count', 'timestamp'
]]
# Fill missing features with 0
for col in feature_cols:
if col not in df.columns:
df[col] = 0.0
df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
X = df[feature_cols]
y = df['strategy_name'].values
return {
'X': X,
'y': y,
'strategy_names': valid_strategies,
'feature_names': feature_cols,
'training_symbols': sorted(df['symbol'].unique().tolist())
}
except Exception as e:
self.logger.error(f"Error preparing training data: {e}")
return None
async def get_strategy_sample_counts(self, days: int = 365) -> Dict[str, int]:
"""Get sample counts per strategy.
Args:
days: Number of days to look back
Returns:
Dictionary mapping strategy names to sample counts
"""
from sqlalchemy import select, func
try:
async with self.db.get_session() as session:
from src.core.database import StrategyPerformance
cutoff_date = datetime.utcnow() - timedelta(days=days)
# Group by strategy name and count
stmt = select(
StrategyPerformance.strategy_name,
func.count(StrategyPerformance.id)
).where(
StrategyPerformance.timestamp >= cutoff_date
).group_by(
StrategyPerformance.strategy_name
)
result = await session.execute(stmt)
counts = dict(result.all())
return counts
except Exception as e:
self.logger.error(f"Error getting strategy sample counts: {e}")
return {}
# Global instance
_performance_tracker: Optional[PerformanceTracker] = None
def get_performance_tracker() -> PerformanceTracker:
"""Get global performance tracker instance."""
global _performance_tracker
if _performance_tracker is None:
_performance_tracker = PerformanceTracker()
return _performance_tracker

View File

@@ -0,0 +1,190 @@
"""Strategy grouping for improved ML accuracy.
Groups 14 individual strategies into 5 logical groups to reduce classification
complexity and improve model accuracy.
"""
from typing import Dict, List, Optional, Tuple
from enum import Enum
from src.core.logger import get_logger
logger = get_logger(__name__)
class StrategyGroup(str, Enum):
"""Strategy group classifications."""
TREND_FOLLOWING = "trend_following"
MEAN_REVERSION = "mean_reversion"
MOMENTUM = "momentum"
MARKET_MAKING = "market_making"
SENTIMENT_BASED = "sentiment_based"
# Map each strategy to its group
STRATEGY_TO_GROUP: Dict[str, StrategyGroup] = {
# Trend Following: Follow established trends
"moving_average": StrategyGroup.TREND_FOLLOWING,
"macd": StrategyGroup.TREND_FOLLOWING,
"confirmed": StrategyGroup.TREND_FOLLOWING,
# Mean Reversion: Bet on price returning to mean
"rsi": StrategyGroup.MEAN_REVERSION,
"bollinger_mean_reversion": StrategyGroup.MEAN_REVERSION,
"grid": StrategyGroup.MEAN_REVERSION,
"divergence": StrategyGroup.MEAN_REVERSION,
# Momentum: Capture fast price moves
"momentum": StrategyGroup.MOMENTUM,
"volatility_breakout": StrategyGroup.MOMENTUM,
# Market Making: Profit from bid-ask spread
"market_making": StrategyGroup.MARKET_MAKING,
"dca": StrategyGroup.MARKET_MAKING,
# Sentiment Based: Use external signals
"sentiment": StrategyGroup.SENTIMENT_BASED,
"pairs_trading": StrategyGroup.SENTIMENT_BASED,
"consensus": StrategyGroup.SENTIMENT_BASED,
}
# Reverse mapping: group -> list of strategies
GROUP_TO_STRATEGIES: Dict[StrategyGroup, List[str]] = {}
for strategy, group in STRATEGY_TO_GROUP.items():
if group not in GROUP_TO_STRATEGIES:
GROUP_TO_STRATEGIES[group] = []
GROUP_TO_STRATEGIES[group].append(strategy)
def get_strategy_group(strategy_name: str) -> Optional[StrategyGroup]:
"""Get the group a strategy belongs to.
Args:
strategy_name: Name of the strategy (case-insensitive)
Returns:
StrategyGroup or None if strategy not found
"""
return STRATEGY_TO_GROUP.get(strategy_name.lower())
def get_strategies_in_group(group: StrategyGroup) -> List[str]:
"""Get all strategies belonging to a group.
Args:
group: Strategy group
Returns:
List of strategy names in the group
"""
return GROUP_TO_STRATEGIES.get(group, [])
def get_all_groups() -> List[StrategyGroup]:
"""Get all available strategy groups.
Returns:
List of all strategy groups
"""
return list(StrategyGroup)
def get_best_strategy_in_group(
group: StrategyGroup,
market_features: Dict[str, float],
available_strategies: Optional[List[str]] = None
) -> Tuple[str, float]:
"""Select the best individual strategy within a group based on market conditions.
Uses rule-based heuristics to pick the optimal strategy when the ML model
has predicted a group.
Args:
group: The predicted strategy group
market_features: Current market condition features
available_strategies: Optional list of available strategies (filters choices)
Returns:
Tuple of (best_strategy_name, confidence_score)
"""
strategies = get_strategies_in_group(group)
# Filter by availability if provided
if available_strategies:
strategies = [s for s in strategies if s in available_strategies]
if not strategies:
# Fallback if no strategies available in group
logger.warning(f"No strategies available in group {group}")
return ("rsi", 0.5) # Safe default
# If only one option, return it
if len(strategies) == 1:
return (strategies[0], 0.7)
# Rule-based selection within group based on market features
rsi = market_features.get("rsi", 50.0)
adx = market_features.get("adx", 25.0)
volatility = market_features.get("atr_percent", 2.0)
volume_ratio = market_features.get("volume_ratio", 1.0)
if group == StrategyGroup.TREND_FOLLOWING:
# Strong trends: use confirmed, moderate: moving_average, diverging: macd
if adx > 30:
return ("confirmed", 0.75) if "confirmed" in strategies else ("moving_average", 0.7)
elif adx > 20:
return ("moving_average", 0.7) if "moving_average" in strategies else ("macd", 0.65)
else:
return ("macd", 0.6) if "macd" in strategies else (strategies[0], 0.55)
elif group == StrategyGroup.MEAN_REVERSION:
# Extreme RSI: rsi strategy, tight range: bollinger, low vol: grid
if rsi < 30 or rsi > 70:
return ("rsi", 0.75) if "rsi" in strategies else ("bollinger_mean_reversion", 0.7)
elif volatility < 1.5:
return ("grid", 0.7) if "grid" in strategies else ("bollinger_mean_reversion", 0.65)
else:
return ("bollinger_mean_reversion", 0.65) if "bollinger_mean_reversion" in strategies else (strategies[0], 0.6)
elif group == StrategyGroup.MOMENTUM:
# High volume spike: volatility_breakout, otherwise momentum
if volume_ratio > 1.5:
return ("volatility_breakout", 0.75) if "volatility_breakout" in strategies else ("momentum", 0.7)
else:
return ("momentum", 0.7) if "momentum" in strategies else (strategies[0], 0.65)
elif group == StrategyGroup.MARKET_MAKING:
# Low volatility: market_making, otherwise dca
if volatility < 2.0:
return ("market_making", 0.7) if "market_making" in strategies else ("dca", 0.65)
else:
return ("dca", 0.7) if "dca" in strategies else (strategies[0], 0.6)
elif group == StrategyGroup.SENTIMENT_BASED:
# Default to sentiment, fall back to consensus
if "sentiment" in strategies:
return ("sentiment", 0.65)
elif "consensus" in strategies:
return ("consensus", 0.6)
else:
return (strategies[0], 0.55)
# Fallback
return (strategies[0], 0.5)
def convert_strategy_to_group_label(strategy_name: str) -> str:
"""Convert a strategy name to its group label for ML training.
Args:
strategy_name: Individual strategy name
Returns:
Group label string (e.g., "trend_following")
"""
group = get_strategy_group(strategy_name)
if group:
return group.value
else:
logger.warning(f"Strategy '{strategy_name}' not in any group, using as-is")
return strategy_name

View File

@@ -0,0 +1,581 @@
"""ML-based strategy selector for intelligent autopilot."""
import asyncio
from typing import Dict, Any, Optional, List, Tuple
from src.core.logger import get_logger
from src.core.config import get_config
from .market_analyzer import MarketConditions, MarketRegime, get_market_analyzer
from .models import StrategySelectorModel
from .performance_tracker import get_performance_tracker
from .strategy_groups import (
StrategyGroup, get_strategy_group, get_strategies_in_group,
get_best_strategy_in_group, convert_strategy_to_group_label, get_all_groups
)
from src.strategies import get_strategy_registry
logger = get_logger(__name__)
class StrategySelector:
"""ML-based strategy selector."""
def __init__(self):
"""Initialize strategy selector."""
self.logger = get_logger(__name__)
self.config = get_config()
self.market_analyzer = get_market_analyzer()
self.performance_tracker = get_performance_tracker()
self.model = StrategySelectorModel(model_type="classifier")
self.strategy_registry = get_strategy_registry()
self._available_strategies: List[str] = []
# Bootstrap configuration - improved defaults for better ML accuracy
self.bootstrap_days = self.config.get("autopilot.intelligent.bootstrap.days", 365)
self.bootstrap_timeframe = self.config.get("autopilot.intelligent.bootstrap.timeframe", "1h")
self.min_samples_per_strategy = self.config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10)
self.bootstrap_symbols = self.config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"])
self.bootstrap_timeframes = ["15m", "1h", "4h"] # Multi-timeframe for more data
self._available_strategies: List[str] = []
self._load_available_strategies()
self._last_model_ts = 0.0
# Auto-load persisted model if available
self._try_load_saved_model()
def _try_load_saved_model(self):
"""Try to load a previously saved model."""
import os
try:
latest_model = self.model.get_latest_model_path()
if latest_model:
mtime = os.path.getmtime(latest_model)
# Only load if it's new or we haven't loaded anything
if mtime > self._last_model_ts:
if self.model.load(latest_model):
self._last_model_ts = mtime
self.logger.info(f"Loaded persisted model: {latest_model}")
else:
self.logger.warning(f"Failed to load model {latest_model}")
else:
if self._last_model_ts == 0:
self.logger.info("No saved model found, model needs training")
except Exception as e:
self.logger.warning(f"Failed to load saved model: {e}")
def _load_available_strategies(self):
"""Load list of available strategies."""
self._available_strategies = self.strategy_registry.list_available()
self.logger.info(f"Available strategies: {self._available_strategies}")
async def train_model(
self,
min_samples_per_strategy: int = 10,
force_retrain: bool = False
) -> Dict[str, Any]:
"""Train the ML model on historical performance data.
Args:
min_samples_per_strategy: Minimum samples required per strategy
force_retrain: Force retraining even if model exists
Returns:
Training metrics dictionary
"""
# Try to load existing model first
if not force_retrain:
latest_model = self.model.get_latest_model_path()
if latest_model and self.model.load(latest_model):
self.logger.info("Loaded existing trained model")
return self.model.training_metadata.get('metrics', {})
# Prepare training data
training_data = await self.performance_tracker.prepare_training_data(
min_samples_per_strategy=min_samples_per_strategy
)
if training_data is None:
self.logger.warning("Insufficient training data. Using fallback rule-based selection.")
return {}
X = training_data['X']
y = training_data['y']
strategy_names = training_data['strategy_names']
# Convert individual strategy names to group labels for better ML accuracy
# This reduces the number of classes from 14 to 5
y_groups = [convert_strategy_to_group_label(name) for name in y]
group_names = [g.value for g in get_all_groups()]
self.logger.info(f"Training with {len(set(y_groups))} strategy groups (from {len(strategy_names)} strategies)")
if len(X) < min_samples_per_strategy * len(group_names):
self.logger.warning("Insufficient training data. Using fallback rule-based selection.")
return {}
# Train model with ensemble mode for better accuracy
# Ensemble combines XGBoost + LightGBM + RandomForest via voting
# Now predicts GROUPS instead of individual strategies
metrics = self.model.train(
X,
y_groups, # Use group labels
group_names, # Use group names
use_ensemble=True,
training_symbols=training_data.get('training_symbols', [])
)
# Save model
self.model.save()
return metrics
def select_best_strategy(
self,
market_conditions: MarketConditions,
min_confidence: float = 0.5
) -> Optional[Tuple[str, float, Dict[str, float]]]:
"""Select best strategy for current market conditions.
The ML model predicts a strategy GROUP, then we use rule-based logic
to select the best individual strategy within that group.
Args:
market_conditions: Current market conditions
min_confidence: Minimum confidence threshold
Returns:
Tuple of (strategy_name, confidence, all_predictions) or None
"""
# If model not trained, use fallback
if not self.model.is_trained:
self.logger.debug("Model not trained, using rule-based fallback")
return self._fallback_strategy_selection(market_conditions)
try:
# Predict strategy GROUP using ML model
predicted_group, group_confidence, all_group_predictions = self.model.predict(
market_conditions.features
)
if group_confidence < min_confidence:
self.logger.debug(
f"ML group confidence {group_confidence:.2f} below threshold {min_confidence}, "
"using fallback"
)
return self._fallback_strategy_selection(market_conditions)
# Convert group string to enum
try:
group_enum = StrategyGroup(predicted_group)
except ValueError:
self.logger.warning(f"Unknown group '{predicted_group}', using fallback")
return self._fallback_strategy_selection(market_conditions)
# Select best individual strategy within the predicted group
best_strategy, strategy_confidence = get_best_strategy_in_group(
group_enum,
market_conditions.features,
available_strategies=self._available_strategies
)
# Combined confidence: group confidence * strategy confidence
combined_confidence = group_confidence * strategy_confidence
# Build all_predictions dict with individual strategies
all_predictions = {}
for group_name, group_score in all_group_predictions.items():
try:
group_e = StrategyGroup(group_name)
strategies_in_group = get_strategies_in_group(group_e)
for strat in strategies_in_group:
if strat in self._available_strategies:
all_predictions[strat] = group_score
except ValueError:
pass
self.logger.info(
f"ML selected group: {predicted_group} (conf: {group_confidence:.2f}) "
f"-> strategy: {best_strategy} (combined conf: {combined_confidence:.2f})"
)
return best_strategy, combined_confidence, all_predictions
except Exception as e:
self.logger.error(f"Error in ML prediction: {e}")
return self._fallback_strategy_selection(market_conditions)
def _fallback_strategy_selection(
self,
market_conditions: MarketConditions
) -> Optional[Tuple[str, float, Dict[str, float]]]:
"""Fallback rule-based strategy selection.
Args:
market_conditions: Current market conditions
Returns:
Tuple of (strategy_name, confidence, all_predictions)
"""
features = market_conditions.features
regime = market_conditions.regime
# Rule-based selection based on market regime
strategy_rules = {
MarketRegime.TRENDING_UP: "moving_average",
MarketRegime.TRENDING_DOWN: "moving_average",
MarketRegime.RANGING: "rsi",
MarketRegime.HIGH_VOLATILITY: "momentum",
MarketRegime.LOW_VOLATILITY: "grid",
MarketRegime.BREAKOUT: "momentum",
MarketRegime.REVERSAL: "rsi"
}
# Map regime enum to string
regime_str = regime.value if hasattr(regime, 'value') else str(regime)
# Select strategy based on regime
selected_strategy = strategy_rules.get(regime, "rsi")
# Check if strategy is available
if selected_strategy not in self._available_strategies:
# Fallback to first available
if self._available_strategies:
selected_strategy = self._available_strategies[0]
else:
return None
# Calculate confidence based on regime clarity
rsi = features.get('rsi', 50.0)
adx = features.get('adx', 0.0)
# Higher confidence for clear signals
if regime in [MarketRegime.TRENDING_UP, MarketRegime.TRENDING_DOWN] and adx > 25:
confidence = 0.7
elif regime == MarketRegime.RANGING and (rsi < 30 or rsi > 70):
confidence = 0.65
else:
confidence = 0.5
all_predictions = {selected_strategy: confidence}
self.logger.info(
f"Fallback selected strategy: {selected_strategy} "
f"(confidence: {confidence:.2f}, regime: {regime_str})"
)
return selected_strategy, confidence, all_predictions
def get_strategy_rankings(
self,
market_conditions: MarketConditions
) -> List[Tuple[str, float]]:
"""Get all strategies ranked by expected performance.
Args:
market_conditions: Current market conditions
Returns:
List of (strategy_name, score) tuples, sorted by score descending
"""
result = self.select_best_strategy(market_conditions, min_confidence=0.0)
if result is None:
return []
_, _, all_predictions = result
# Sort by score descending
rankings = sorted(
all_predictions.items(),
key=lambda x: x[1],
reverse=True
)
return rankings
def update_model(self, trade_result: Dict[str, Any]) -> bool:
"""Update model with new trade result (incremental learning).
Args:
trade_result: Trade result with performance metrics
Returns:
True if update successful
"""
# For now, we'll retrain periodically rather than incremental updates
# This can be enhanced later with online learning algorithms
self.logger.debug("Model update requested (will retrain on next cycle)")
return True
async def bootstrap_training_data(
self,
symbol: str = "BTC/USD",
timeframe: Optional[str] = None,
days: Optional[int] = None,
exchange_name: str = "Binance Public"
) -> Dict[str, Any]:
"""Bootstrap training data by running backtests on historical data.
Args:
symbol: Trading symbol
timeframe: Timeframe (defaults to config value)
days: Number of days of historical data (defaults to config value)
exchange_name: Exchange name
Returns:
Dictionary with bootstrap results
"""
# Use config values as defaults
if timeframe is None:
timeframe = self.bootstrap_timeframe
if days is None:
days = self.bootstrap_days
from src.backtesting.engine import BacktestingEngine
from src.exchanges.public_data import PublicDataAdapter
from src.data.collector import get_data_collector
from src.core.database import get_database, MarketData
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import select, func
import time
self.logger.info(f"Bootstrapping training data for {symbol} ({timeframe})")
db = get_database()
# Step 1: Check and fetch historical data
async with db.get_session() as session:
try:
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days)
# Check if we have data
stmt = select(func.count()).select_from(MarketData).where(
MarketData.exchange == exchange_name,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_date
)
result = await session.execute(stmt)
existing_data = result.scalar()
if existing_data < 100: # Need at least 100 candles
self.logger.info(f"Fetching historical data ({days} days)...")
adapter = PublicDataAdapter()
if await adapter.connect():
collector = get_data_collector()
current_date = start_date
chunk_days = 30
while current_date < end_date:
chunk_end = min(current_date + timedelta(days=chunk_days), end_date)
ohlcv = await adapter.get_ohlcv(
symbol=symbol,
timeframe=timeframe,
since=current_date,
limit=1000
)
if ohlcv:
await collector.store_ohlcv(exchange_name, symbol, timeframe, ohlcv)
current_date = chunk_end
time.sleep(1)
await adapter.disconnect()
self.logger.info("Historical data fetched")
else:
self.logger.error("Failed to connect to exchange")
return {"error": "Failed to fetch historical data"}
except Exception as e:
self.logger.error(f"Error checking/fetching historical data: {e}")
return {"error": f"Database error: {e}"}
# Step 2: Run backtests for each strategy
backtest_engine = BacktestingEngine()
bootstrap_results = []
self.logger.info(f"Available strategies for bootstrap: {self._available_strategies}")
total_strategies = len(self._available_strategies)
for strategy_idx, strategy_name in enumerate(self._available_strategies):
try:
strategy_class = self.strategy_registry._strategies.get(strategy_name.lower())
if not strategy_class:
continue
strategy = strategy_class(
name=strategy_name,
parameters={},
timeframes=[timeframe]
)
strategy.enabled = True
self.logger.info(f"Running backtest for {strategy_name} ({strategy_idx + 1}/{total_strategies})...")
# Run backtest
results = await backtest_engine.run_backtest(
strategy=strategy,
symbol=symbol,
exchange=exchange_name,
timeframe=timeframe,
start_date=start_date,
end_date=end_date,
initial_capital=Decimal("10000.0"),
slippage=0.001
)
if "error" in results:
self.logger.warning(f"Backtest failed for {strategy_name}: {results['error']}")
continue
self.logger.info(f"Backtest successful for {strategy_name}, return: {results.get('total_return', 0.0)}%")
# Get market conditions for the period
market_data = await backtest_engine._get_historical_data(
exchange_name, symbol, timeframe, start_date, end_date
)
if len(market_data) > 0:
# =====================================================
# IMPROVED SAMPLING: Regime-change detection + time spacing
# =====================================================
# This creates more diverse, independent training samples
# by sampling at meaningful market transitions rather than
# every N candles (which creates nearly-identical samples)
# Minimum time spacing between samples (in candles)
# For 1h: 24 candles = 24 hours
# For 15m: 96 candles = 24 hours
# For 4h: 6 candles = 24 hours
timeframe_to_spacing = {
"1m": 1440, # 24 hours
"5m": 288,
"15m": 96,
"30m": 48,
"1h": 24,
"4h": 6,
"1d": 1
}
min_spacing = timeframe_to_spacing.get(timeframe, 24)
samples_recorded = 0
last_regime = None
last_sample_idx = -min_spacing # Allow first sample immediately
# Limit processing to prevent excessive computation
# For small datasets (5 days), process all points but yield periodically
data_points = len(market_data) - 50
self.logger.info(f"Processing {data_points} data points for {strategy_name}...")
# Need at least 50 candles for feature calculation
for i in range(50, len(market_data)):
# Yield control periodically to prevent blocking (every 10 iterations)
if i % 10 == 0:
await asyncio.sleep(0) # Yield to event loop
sample_data = market_data.iloc[i-50:i]
conditions = self.market_analyzer.analyze_current_conditions(
symbol, timeframe, sample_data
)
current_regime = conditions.regime
# Determine if we should sample at this point
# 1. Sample on regime change (market transition = valuable data)
# 2. Sample after minimum time spacing (ensure independence)
regime_changed = (last_regime is not None and current_regime != last_regime)
time_elapsed = (i - last_sample_idx) >= min_spacing
should_sample = regime_changed or (time_elapsed and last_regime is None) or (time_elapsed and i == 50)
# Also sample periodically even without regime change (every 2x min_spacing)
periodic_sample = (i - last_sample_idx) >= (min_spacing * 2)
if should_sample or periodic_sample:
# Record as training data
trade_result = {
'return_pct': results.get('total_return', 0.0) / 100.0,
'sharpe_ratio': results.get('sharpe_ratio', 0.0),
'win_rate': results.get('win_rate', 0.5),
'max_drawdown': abs(results.get('max_drawdown', 0.0)) / 100.0,
'trade_count': results.get('total_trades', 0)
}
await self.performance_tracker.record_trade(
strategy_name=strategy_name,
market_conditions=conditions,
trade_result=trade_result
)
samples_recorded += 1
last_sample_idx = i
if regime_changed:
self.logger.debug(
f"Sampled at regime change: {last_regime} -> {current_regime}"
)
last_regime = current_regime
self.logger.info(f"Recorded {samples_recorded} samples for {strategy_name}")
bootstrap_results.append({
'strategy': strategy_name,
'trades_recorded': samples_recorded,
'backtest_return': results.get('total_return', 0.0)
})
except Exception as e:
self.logger.error(f"Error bootstrapping {strategy_name}: {e}", exc_info=True)
continue
total_samples = sum(r['trades_recorded'] for r in bootstrap_results)
self.logger.info(f"Bootstrap complete: {total_samples} training samples created")
return {
'status': 'success',
'strategies': bootstrap_results,
'total_samples': total_samples
}
def get_model_info(self) -> Dict[str, Any]:
"""Get information about the current model.
Returns:
Dictionary with model information
"""
# Always check if a newer model is available on disk
self._try_load_saved_model()
info = {
'is_trained': self.model.is_trained,
'model_type': self.model.model_type,
'available_strategies': self._available_strategies,
'feature_count': len(self.model.feature_names) if self.model.is_trained else 0
}
if self.model.is_trained:
info.update({
'training_metadata': self.model.training_metadata,
'feature_importance': dict(
sorted(
self.model.get_feature_importance().items(),
key=lambda x: x[1],
reverse=True
)[:10] # Top 10 features
)
})
return info
# Global instance
_strategy_selector: Optional[StrategySelector] = None
def get_strategy_selector() -> StrategySelector:
"""Get global strategy selector instance."""
global _strategy_selector
if _strategy_selector is None:
_strategy_selector = StrategySelector()
return _strategy_selector

View File

View File

@@ -0,0 +1,51 @@
"""Historical data management for backtesting."""
from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
logger = get_logger(__name__)
class DataProvider:
"""Provides historical data for backtesting."""
def __init__(self):
"""Initialize data provider."""
self.db = get_database()
self.logger = get_logger(__name__)
def get_data(
self,
exchange: str,
symbol: str,
timeframe: str,
start_date: datetime,
end_date: datetime
) -> List[MarketData]:
"""Get historical data.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
start_date: Start date
end_date: End date
Returns:
List of MarketData objects
"""
session = self.db.get_session()
try:
return session.query(MarketData).filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_date,
MarketData.timestamp <= end_date
).order_by(MarketData.timestamp).all()
finally:
session.close()

207
src/backtesting/engine.py Normal file
View File

@@ -0,0 +1,207 @@
"""Backtesting engine with historical data replay and performance metrics."""
import pandas as pd
from decimal import Decimal
from datetime import datetime
from typing import Dict, List, Optional, Any
from sqlalchemy import select
from src.core.database import get_database, MarketData, OrderType
from src.core.logger import get_logger
from src.strategies.base import BaseStrategy
from src.trading.paper_trading import PaperTradingSimulator
from src.exchanges.factory import get_exchange
from .metrics import BacktestMetrics
from .slippage import FeeModel
logger = get_logger(__name__)
class BacktestingEngine:
"""Backtesting engine for strategy evaluation."""
def __init__(self):
"""Initialize backtesting engine."""
self.db = get_database()
self.logger = get_logger(__name__)
async def run_backtest(
self,
strategy: BaseStrategy,
symbol: str,
exchange: str,
timeframe: str,
start_date: datetime,
end_date: datetime,
initial_capital: Decimal = Decimal("100.0"),
slippage: float = 0.001, # 0.1% slippage
fee_model: Optional[FeeModel] = None,
exchange_id: Optional[int] = None,
) -> Dict[str, Any]:
"""Run backtest on strategy.
Args:
strategy: Strategy instance
symbol: Trading symbol
exchange: Exchange name
timeframe: Timeframe
start_date: Start date
end_date: End date
initial_capital: Initial capital
slippage: Slippage percentage
fee_model: FeeModel instance (if None, creates default)
exchange_id: Exchange ID for fee structure retrieval
Returns:
Backtest results dictionary
"""
# Get historical data
data = await self._get_historical_data(exchange, symbol, timeframe, start_date, end_date)
if len(data) == 0:
return {"error": "No historical data available"}
# Initialize fee model
if fee_model is None:
# Try to get exchange adapter for fee structure
exchange_adapter = None
if exchange_id:
try:
exchange_adapter = await get_exchange(exchange_id)
except Exception as e:
self.logger.warning(f"Could not get exchange adapter for fees: {e}")
fee_model = FeeModel(exchange_adapter=exchange_adapter)
# Initialize paper trading simulator
simulator = PaperTradingSimulator(initial_capital)
await simulator.initialize()
strategy.enabled = True
# Run backtest
trades = []
total_fees = Decimal(0)
for i, row in data.iterrows():
price = Decimal(str(row['close']))
# Generate signal
signal = await strategy.on_tick(
symbol,
price,
timeframe,
{'open': row['open'], 'high': row['high'], 'low': row['low'], 'volume': row['volume']}
)
if signal and strategy.should_execute(signal):
# Determine order type (default to MARKET for backtesting)
# In real trading, strategies could specify limit orders
order_type = OrderType.MARKET
is_maker = (order_type == OrderType.LIMIT)
# Execute trade with slippage and fees
slippage_multiplier = (Decimal("1") + Decimal(str(slippage))) if signal.signal_type.value == "buy" else (Decimal("1") - Decimal(str(slippage)))
fill_price = price * slippage_multiplier
quantity = signal.quantity or strategy.calculate_position_size(signal, simulator.get_balance(), fill_price)
# Calculate fee using FeeModel with maker/taker distinction
fee = fee_model.calculate_fee(quantity, fill_price, is_maker=is_maker)
total_fees += fee
# Create and execute order
from src.core.database import Order as DBOrder, OrderSide
from src.trading.order_manager import get_order_manager
order_manager = get_order_manager()
order = await order_manager.create_order(
exchange_id=exchange_id or 1, # Use provided exchange_id or placeholder
strategy_id=None,
symbol=symbol,
order_type=order_type,
side=OrderSide.BUY if signal.signal_type.value == "buy" else OrderSide.SELL,
quantity=quantity,
paper_trading=True
)
if order and await simulator.execute_order(order, fill_price, fee):
trades.append({
'timestamp': i,
'price': float(fill_price),
'quantity': float(quantity),
'side': signal.signal_type.value,
'fee': float(fee),
})
# Calculate metrics
metrics = BacktestMetrics()
results = metrics.calculate_metrics(simulator, trades, initial_capital, total_fees)
# Add fee information to results
results['total_fees'] = float(total_fees)
results['fee_percentage'] = float((total_fees / initial_capital) * 100) if initial_capital > 0 else 0.0
return results
async def _get_historical_data(
self,
exchange: str,
symbol: str,
timeframe: str,
start_date: datetime,
end_date: datetime
) -> pd.DataFrame:
"""Get historical OHLCV data.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
start_date: Start date
end_date: End date
Returns:
DataFrame with OHLCV data
"""
async with self.db.get_session() as session:
try:
stmt = select(MarketData).filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_date,
MarketData.timestamp <= end_date
).order_by(MarketData.timestamp)
result = await session.execute(stmt)
market_data = result.scalars().all()
if len(market_data) == 0:
return pd.DataFrame()
data = {
'timestamp': [md.timestamp for md in market_data],
'open': [float(md.open) for md in market_data],
'high': [float(md.high) for md in market_data],
'low': [float(md.low) for md in market_data],
'close': [float(md.close) for md in market_data],
'volume': [float(md.volume) for md in market_data],
}
df = pd.DataFrame(data)
df.set_index('timestamp', inplace=True)
return df
except Exception as e:
logger.error(f"Failed to get historical data: {e}")
return pd.DataFrame()
# Global backtesting engine
_backtesting_engine: Optional[BacktestingEngine] = None
def get_backtest_engine() -> BacktestingEngine:
"""Get global backtesting engine instance."""
global _backtesting_engine
if _backtesting_engine is None:
_backtesting_engine = BacktestingEngine()
return _backtesting_engine

View File

@@ -0,0 +1,85 @@
"""Performance metrics for backtesting."""
from decimal import Decimal
from typing import Dict, List, Any, Optional
from src.trading.paper_trading import PaperTradingSimulator
class BacktestMetrics:
"""Calculates backtest performance metrics."""
def calculate_metrics(
self,
simulator: PaperTradingSimulator,
trades: List[Dict[str, Any]],
initial_capital: Decimal,
total_fees: Optional[Decimal] = None
) -> Dict[str, Any]:
"""Calculate backtest metrics, including fee-adjusted metrics.
Args:
simulator: Paper trading simulator
trades: List of executed trades (may include 'fee' field)
initial_capital: Initial capital
total_fees: Total fees paid (if None, calculated from trades)
Returns:
Dictionary of metrics
"""
final_value = simulator.get_portfolio_value()
# Calculate total fees if not provided
if total_fees is None:
total_fees = sum(
Decimal(str(trade.get('fee', 0))) for trade in trades
)
# Gross return (before fees)
gross_return = ((final_value - initial_capital) / initial_capital) * 100 if initial_capital > 0 else 0
# Net return (after fees) - fees are already deducted in simulator
# But we calculate what return would be without fees for comparison
value_without_fees = final_value + total_fees
net_return = ((final_value - initial_capital) / initial_capital) * 100 if initial_capital > 0 else 0
gross_return_without_fees = ((value_without_fees - initial_capital) / initial_capital) * 100 if initial_capital > 0 else 0
# Fee impact
fee_impact = gross_return_without_fees - net_return
# Calculate win rate from trades
win_rate = self._calculate_win_rate(trades)
return {
"initial_capital": float(initial_capital),
"final_capital": float(final_value),
"total_return": float(net_return), # Net return (after fees)
"gross_return": float(gross_return_without_fees), # Gross return (before fees)
"total_fees": float(total_fees),
"fee_impact_percent": float(fee_impact),
"fee_percentage": float((total_fees / initial_capital) * 100) if initial_capital > 0 else 0.0,
"total_trades": len(trades),
"win_rate": win_rate,
"sharpe_ratio": 0.0, # Placeholder - would need returns series
"max_drawdown": 0.0, # Placeholder - would need equity curve
}
def _calculate_win_rate(self, trades: List[Dict[str, Any]]) -> float:
"""Calculate win rate from trades.
Args:
trades: List of trades
Returns:
Win rate (0.0 to 1.0)
"""
if len(trades) < 2:
return 0.0
# Simple approach: count buy-sell pairs
# This is a simplified calculation - full implementation would track positions
wins = 0
total_pairs = 0
# Group trades by symbol and calculate pairs
# For now, return placeholder
return 0.5

175
src/backtesting/slippage.py Normal file
View File

@@ -0,0 +1,175 @@
"""Slippage modeling for realistic backtesting."""
from decimal import Decimal
from typing import Dict, Optional, Any
from src.core.logger import get_logger
from src.exchanges.base import BaseExchangeAdapter
logger = get_logger(__name__)
class SlippageModel:
"""Models slippage for realistic backtesting."""
def __init__(self, slippage_rate: float = 0.001):
"""Initialize slippage model.
Args:
slippage_rate: Slippage rate (0.001 = 0.1%)
"""
self.slippage_rate = slippage_rate
self.logger = get_logger(__name__)
def calculate_fill_price(
self,
order_price: Decimal,
side: str, # "buy" or "sell"
order_type: str, # "market" or "limit"
market_price: Decimal,
volume: Decimal = Decimal(0)
) -> Decimal:
"""Calculate fill price with slippage.
Args:
order_price: Order price
side: Buy or sell
order_type: Market or limit
market_price: Current market price
volume: Order volume (for market impact)
Returns:
Fill price with slippage
"""
if order_type == "limit":
# Limit orders fill at order price (if market reaches it)
return order_price
# Market orders have slippage
if side == "buy":
# Buy orders pay more (slippage up)
slippage = market_price * Decimal(str(self.slippage_rate))
# Add market impact based on volume
impact = market_price * Decimal(str(volume)) * Decimal("0.0001") # Simplified
return market_price + slippage + impact
else:
# Sell orders receive less (slippage down)
slippage = market_price * Decimal(str(self.slippage_rate))
impact = market_price * Decimal(str(volume)) * Decimal("0.0001")
return market_price - slippage - impact
class FeeModel:
"""Models exchange fees with support for dynamic fee retrieval."""
def __init__(
self,
maker_fee: Optional[float] = None,
taker_fee: Optional[float] = None,
exchange_adapter: Optional[BaseExchangeAdapter] = None,
minimum_fee: float = 0.0
):
"""Initialize fee model.
Args:
maker_fee: Maker fee rate (if None, retrieved from exchange or default)
taker_fee: Taker fee rate (if None, retrieved from exchange or default)
exchange_adapter: Exchange adapter for dynamic fee retrieval
minimum_fee: Minimum fee amount
"""
self.exchange_adapter = exchange_adapter
self.minimum_fee = Decimal(str(minimum_fee))
self.logger = get_logger(__name__)
# Set fees from parameters or retrieve from exchange
if maker_fee is not None and taker_fee is not None:
self.maker_fee = maker_fee
self.taker_fee = taker_fee
else:
fee_structure = self._get_fee_structure()
self.maker_fee = maker_fee if maker_fee is not None else fee_structure.get('maker', 0.001)
self.taker_fee = taker_fee if taker_fee is not None else fee_structure.get('taker', 0.001)
def _get_fee_structure(self) -> Dict[str, Any]:
"""Get fee structure from exchange adapter or defaults.
Returns:
Fee structure dictionary
"""
if self.exchange_adapter:
try:
return self.exchange_adapter.get_fee_structure()
except Exception as e:
self.logger.warning(f"Failed to get fee structure from exchange: {e}")
return {
'maker': 0.001, # 0.1%
'taker': 0.001, # 0.1%
'minimum': 0.0
}
def calculate_fee(
self,
quantity: Decimal,
price: Decimal,
is_maker: bool = False
) -> Decimal:
"""Calculate trading fee.
Args:
quantity: Trade quantity
price: Trade price
is_maker: True if maker order
Returns:
Trading fee
"""
if quantity <= 0 or price <= 0:
return Decimal(0)
trade_value = quantity * price
fee_rate = self.maker_fee if is_maker else self.taker_fee
fee = trade_value * Decimal(str(fee_rate))
# Apply minimum fee
if self.minimum_fee > 0 and fee < self.minimum_fee:
fee = self.minimum_fee
return fee
def estimate_round_trip_fee(
self,
quantity: Decimal,
price: Decimal
) -> Decimal:
"""Estimate total fees for a round-trip trade (buy + sell).
Args:
quantity: Trade quantity
price: Trade price
Returns:
Total estimated round-trip fee
"""
buy_fee = self.calculate_fee(quantity, price, is_maker=False)
sell_fee = self.calculate_fee(quantity, price, is_maker=False)
return buy_fee + sell_fee
def get_minimum_profit_threshold(
self,
quantity: Decimal,
price: Decimal,
multiplier: float = 2.0
) -> Decimal:
"""Calculate minimum profit threshold needed to break even after fees.
Args:
quantity: Trade quantity
price: Trade price
multiplier: Multiplier for minimum profit (default 2.0 = 2x fees)
Returns:
Minimum profit threshold
"""
round_trip_fee = self.estimate_round_trip_fee(quantity, price)
return round_trip_fee * Decimal(str(multiplier))

0
src/core/__init__.py Normal file
View File

256
src/core/config.py Normal file
View File

@@ -0,0 +1,256 @@
"""Configuration management system with YAML and environment variables."""
import os
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
class Config:
"""Configuration manager with XDG directory support."""
def __init__(self, config_file: Optional[str] = None):
"""Initialize configuration manager.
Args:
config_file: Optional path to config file. If None, uses XDG default.
"""
self._setup_xdg_directories()
# Determine config file priority:
# 1. Explicit argument
# 2. Local project config (dev mode)
# 3. XDG config (user mode)
local_config = Path(__file__).parent.parent.parent / "config" / "config.yaml"
if config_file:
self.config_file = Path(config_file)
elif local_config.exists():
self.config_file = local_config
else:
self.config_file = self.config_dir / "config.yaml"
self._config: Dict[str, Any] = {}
self._load_config()
def _setup_xdg_directories(self):
"""Set up XDG Base Directory Specification directories."""
home = Path.home()
# XDG_CONFIG_HOME or default
xdg_config = os.getenv("XDG_CONFIG_HOME", home / ".config")
self.config_dir = Path(xdg_config) / "crypto_trader"
self.config_dir.mkdir(parents=True, exist_ok=True)
# XDG_DATA_HOME or default
xdg_data = os.getenv("XDG_DATA_HOME", home / ".local" / "share")
self.data_dir = Path(xdg_data) / "crypto_trader"
self.data_dir.mkdir(parents=True, exist_ok=True)
# Create subdirectories
(self.data_dir / "historical").mkdir(exist_ok=True)
(self.data_dir / "backups").mkdir(exist_ok=True)
(self.data_dir / "logs").mkdir(exist_ok=True)
# XDG_CACHE_HOME or default
xdg_cache = os.getenv("XDG_CACHE_HOME", home / ".cache")
self.cache_dir = Path(xdg_cache) / "crypto_trader"
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _load_config(self):
"""Load configuration from YAML file and environment variables."""
# Load defaults
default_config = self._get_default_config()
self._config = default_config.copy()
# Load from file if it exists
if self.config_file.exists():
with open(self.config_file, 'r') as f:
file_config = yaml.safe_load(f) or {}
self._config.update(file_config)
# Override with environment variables
self._load_from_env()
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration."""
return {
"app": {
"name": "Crypto Trader",
"version": "0.1.0",
},
"database": {
"type": "postgresql",
"url": None, # For PostgreSQL
},
"logging": {
"level": os.getenv("LOG_LEVEL", "INFO"),
"dir": str(self.data_dir / "logs"),
"retention_days": 30,
"rotation": "daily",
},
"paper_trading": {
"enabled": True,
"default_capital": float(os.getenv("PAPER_TRADING_CAPITAL", "100.0")),
},
"updates": {
"check_on_startup": os.getenv("UPDATE_CHECK_ON_STARTUP", "true").lower() == "true",
"repository_url": os.getenv("UPDATE_REPOSITORY_URL", ""),
},
"exchanges": {},
"strategies": {
"default_timeframe": "1h",
},
"risk": {
"max_drawdown_percent": 20.0,
"daily_loss_limit_percent": 5.0,
"position_size_percent": 2.0,
},
"trading": {
"default_fees": {
"maker": 0.001, # 0.1%
"taker": 0.001, # 0.1%
"minimum": 0.0,
},
"exchanges": {},
},
"data_providers": {
"primary": [
{"name": "kraken", "enabled": True, "priority": 1},
{"name": "coinbase", "enabled": True, "priority": 2},
{"name": "binance", "enabled": True, "priority": 3},
],
"fallback": {
"name": "coingecko",
"enabled": True,
"api_key": "",
},
"caching": {
"ticker_ttl": 2, # seconds
"ohlcv_ttl": 60, # seconds
"max_cache_size": 1000,
},
"websocket": {
"enabled": True,
"reconnect_interval": 5, # seconds
"ping_interval": 30, # seconds
},
},
"redis": {
"host": os.getenv("REDIS_HOST", "127.0.0.1"),
"port": int(os.getenv("REDIS_PORT", 6379)),
"db": int(os.getenv("REDIS_DB", 0)),
"password": os.getenv("REDIS_PASSWORD", None),
"socket_connect_timeout": 5,
},
"celery": {
"broker_url": os.getenv("CELERY_BROKER_URL", "redis://127.0.0.1:6379/0"),
"result_backend": os.getenv("CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0"),
},
}
def _load_from_env(self):
"""Load configuration from environment variables."""
# Database
if db_url := os.getenv("DATABASE_URL"):
self._config["database"]["url"] = db_url
self._config["database"]["type"] = "postgresql"
# Logging
if log_level := os.getenv("LOG_LEVEL"):
self._config["logging"]["level"] = log_level
if log_dir := os.getenv("LOG_DIR"):
self._config["logging"]["dir"] = log_dir
# Paper trading
if capital := os.getenv("PAPER_TRADING_CAPITAL"):
self._config["paper_trading"]["default_capital"] = float(capital)
def get(self, key: str, default: Any = None) -> Any:
"""Get configuration value using dot notation.
Args:
key: Configuration key (e.g., "database.path")
default: Default value if key not found
Returns:
Configuration value or default
"""
keys = key.split(".")
value = self._config
for k in keys:
if isinstance(value, dict):
value = value.get(k)
if value is None:
return default
else:
return default
return value
def set(self, key: str, value: Any):
"""Set configuration value using dot notation.
Args:
key: Configuration key (e.g., "database.path")
value: Value to set
"""
keys = key.split(".")
config = self._config
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
config[keys[-1]] = value
def save(self):
"""Save configuration to file."""
with open(self.config_file, 'w') as f:
yaml.dump(self._config, f, default_flow_style=False, sort_keys=False)
@property
def config_dir(self) -> Path:
"""Get config directory path."""
return self._config_dir
@config_dir.setter
def config_dir(self, value: Path):
"""Set config directory."""
self._config_dir = value
@property
def data_dir(self) -> Path:
"""Get data directory path."""
return self._data_dir
@data_dir.setter
def data_dir(self, value: Path):
"""Set data directory."""
self._data_dir = value
@property
def cache_dir(self) -> Path:
"""Get cache directory path."""
return self._cache_dir
@cache_dir.setter
def cache_dir(self, value: Path):
"""Set cache directory."""
self._cache_dir = value
# Global config instance
_config_instance: Optional[Config] = None
def get_config() -> Config:
"""Get global configuration instance."""
global _config_instance
if _config_instance is None:
_config_instance = Config()
return _config_instance

416
src/core/database.py Normal file
View File

@@ -0,0 +1,416 @@
"""Database connection and models using SQLAlchemy."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Optional
from sqlalchemy import (
create_engine, Column, Integer, String, Float, Boolean, DateTime,
Text, ForeignKey, JSON, Enum as SQLEnum, Numeric
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship, Session
from .config import get_config
Base = declarative_base()
class OrderType(str, Enum):
"""Order type enumeration."""
MARKET = "market"
LIMIT = "limit"
STOP_LOSS = "stop_loss"
TAKE_PROFIT = "take_profit"
TRAILING_STOP = "trailing_stop"
OCO = "oco"
ICEBERG = "iceberg"
class OrderSide(str, Enum):
"""Order side enumeration."""
BUY = "buy"
SELL = "sell"
class OrderStatus(str, Enum):
"""Order status enumeration."""
PENDING = "pending"
OPEN = "open"
PARTIALLY_FILLED = "partially_filled"
FILLED = "filled"
CANCELLED = "cancelled"
REJECTED = "rejected"
EXPIRED = "expired"
class TradeType(str, Enum):
"""Trade type enumeration."""
SPOT = "spot"
FUTURES = "futures"
MARGIN = "margin"
class Exchange(Base):
"""Exchange configuration and credentials."""
__tablename__ = "exchanges"
id = Column(Integer, primary_key=True)
name = Column(String(50), nullable=False, unique=True)
api_key_encrypted = Column(Text) # Encrypted API key
api_secret_encrypted = Column(Text) # Encrypted API secret
sandbox = Column(Boolean, default=False)
read_only = Column(Boolean, default=True) # Read-only mode
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
trades = relationship("Trade", back_populates="exchange")
orders = relationship("Order", back_populates="exchange")
positions = relationship("Position", back_populates="exchange")
class Strategy(Base):
"""Strategy definitions and parameters."""
__tablename__ = "strategies"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
description = Column(Text)
strategy_type = Column(String(50)) # technical, momentum, grid, dca, etc.
class_name = Column(String(100)) # Python class name
parameters = Column(JSON) # Strategy parameters
timeframes = Column(JSON) # Multi-timeframe configuration
enabled = Column(Boolean, default=False) # Available to Autopilot
running = Column(Boolean, default=False) # Currently running manually
paper_trading = Column(Boolean, default=True)
version = Column(String(20), default="1.0.0")
schedule = Column(JSON) # Scheduling configuration
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
trades = relationship("Trade", back_populates="strategy")
backtest_results = relationship("BacktestResult", back_populates="strategy")
class Order(Base):
"""Order history with state tracking."""
__tablename__ = "orders"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
exchange_order_id = Column(String(100)) # Exchange's order ID
symbol = Column(String(20), nullable=False)
order_type = Column(SQLEnum(OrderType), nullable=False)
side = Column(SQLEnum(OrderSide), nullable=False)
status = Column(SQLEnum(OrderStatus), default=OrderStatus.PENDING)
quantity = Column(Numeric(20, 8), nullable=False)
price = Column(Numeric(20, 8)) # For limit orders
filled_quantity = Column(Numeric(20, 8), default=0)
average_fill_price = Column(Numeric(20, 8))
fee = Column(Numeric(20, 8), default=0)
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
leverage = Column(Integer, default=1) # For futures/margin
paper_trading = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
filled_at = Column(DateTime)
# Relationships
exchange = relationship("Exchange", back_populates="orders")
strategy = relationship("Strategy")
trades = relationship("Trade", back_populates="order")
class Trade(Base):
"""All executed trades (paper and live)."""
__tablename__ = "trades"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
order_id = Column(Integer, ForeignKey("orders.id"), nullable=True)
symbol = Column(String(20), nullable=False)
side = Column(SQLEnum(OrderSide), nullable=False)
quantity = Column(Numeric(20, 8), nullable=False)
price = Column(Numeric(20, 8), nullable=False)
fee = Column(Numeric(20, 8), default=0)
total = Column(Numeric(20, 8), nullable=False) # quantity * price + fee
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
paper_trading = Column(Boolean, default=True)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
# Relationships
exchange = relationship("Exchange", back_populates="trades")
strategy = relationship("Strategy", back_populates="trades")
order = relationship("Order", back_populates="trades")
class Position(Base):
"""Current open positions (spot and futures)."""
__tablename__ = "positions"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
symbol = Column(String(20), nullable=False)
side = Column(String(10)) # long, short
quantity = Column(Numeric(20, 8), nullable=False)
entry_price = Column(Numeric(20, 8), nullable=False)
current_price = Column(Numeric(20, 8))
unrealized_pnl = Column(Numeric(20, 8), default=0)
realized_pnl = Column(Numeric(20, 8), default=0)
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
leverage = Column(Integer, default=1)
margin = Column(Numeric(20, 8))
paper_trading = Column(Boolean, default=True)
opened_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
exchange = relationship("Exchange", back_populates="positions")
class PortfolioSnapshot(Base):
"""Historical portfolio values."""
__tablename__ = "portfolio_snapshots"
id = Column(Integer, primary_key=True)
total_value = Column(Numeric(20, 8), nullable=False)
cash = Column(Numeric(20, 8), nullable=False)
positions_value = Column(Numeric(20, 8), nullable=False)
unrealized_pnl = Column(Numeric(20, 8), default=0)
realized_pnl = Column(Numeric(20, 8), default=0)
paper_trading = Column(Boolean, default=True)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MarketData(Base):
"""OHLCV historical data (multiple timeframes)."""
__tablename__ = "market_data"
id = Column(Integer, primary_key=True)
exchange = Column(String(50), nullable=False)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False) # 1m, 5m, 15m, 1h, 1d, etc.
timestamp = Column(DateTime, nullable=False, index=True)
open = Column(Numeric(20, 8), nullable=False)
high = Column(Numeric(20, 8), nullable=False)
low = Column(Numeric(20, 8), nullable=False)
close = Column(Numeric(20, 8), nullable=False)
volume = Column(Numeric(20, 8), nullable=False)
class BacktestResult(Base):
"""Backtesting results and metrics."""
__tablename__ = "backtest_results"
id = Column(Integer, primary_key=True)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=False)
start_date = Column(DateTime, nullable=False)
end_date = Column(DateTime, nullable=False)
initial_capital = Column(Numeric(20, 8), nullable=False)
final_capital = Column(Numeric(20, 8), nullable=False)
total_return = Column(Numeric(10, 4)) # Percentage
sharpe_ratio = Column(Numeric(10, 4))
sortino_ratio = Column(Numeric(10, 4))
max_drawdown = Column(Numeric(10, 4))
win_rate = Column(Numeric(10, 4))
total_trades = Column(Integer, default=0)
parameters = Column(JSON) # Parameters used in backtest
metrics = Column(JSON) # Additional metrics
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
strategy = relationship("Strategy", back_populates="backtest_results")
class RiskLimit(Base):
"""Risk management configuration."""
__tablename__ = "risk_limits"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
limit_type = Column(String(50)) # max_drawdown, daily_loss, position_size, etc.
value = Column(Numeric(10, 4), nullable=False)
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class Alert(Base):
"""Alert definitions and history."""
__tablename__ = "alerts"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
alert_type = Column(String(50)) # price, indicator, risk, system
condition = Column(JSON) # Alert condition configuration
enabled = Column(Boolean, default=True)
triggered = Column(Boolean, default=False)
triggered_at = Column(DateTime)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class RebalancingEvent(Base):
"""Portfolio rebalancing history."""
__tablename__ = "rebalancing_events"
id = Column(Integer, primary_key=True)
trigger_type = Column(String(50)) # time, threshold, manual
target_allocations = Column(JSON) # Target portfolio allocations
before_allocations = Column(JSON) # Allocations before rebalancing
after_allocations = Column(JSON) # Allocations after rebalancing
orders_placed = Column(JSON) # Orders placed for rebalancing
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
class AppState(Base):
"""Application state for recovery."""
__tablename__ = "app_state"
id = Column(Integer, primary_key=True)
key = Column(String(100), unique=True, nullable=False)
value = Column(JSON)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class AuditLog(Base):
"""Security and action audit trail."""
__tablename__ = "audit_log"
id = Column(Integer, primary_key=True)
action = Column(String(100), nullable=False)
entity_type = Column(String(50)) # exchange, strategy, order, etc.
entity_id = Column(Integer)
details = Column(JSON)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MarketConditionsSnapshot(Base):
"""Market conditions snapshot for ML training."""
__tablename__ = "market_conditions_snapshot"
id = Column(Integer, primary_key=True)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False)
regime = Column(String(50)) # Market regime classification
features = Column(JSON) # Market condition features
strategy_name = Column(String(100)) # Strategy used at this time
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class StrategyPerformance(Base):
"""Strategy performance records for ML training."""
__tablename__ = "strategy_performance"
id = Column(Integer, primary_key=True)
strategy_name = Column(String(100), nullable=False, index=True)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False)
market_regime = Column(String(50), index=True) # Market regime when trade executed
return_pct = Column(Numeric(10, 4)) # Return percentage
sharpe_ratio = Column(Numeric(10, 4)) # Sharpe ratio
win_rate = Column(Numeric(5, 2)) # Win rate (0-100)
max_drawdown = Column(Numeric(10, 4)) # Maximum drawdown
trade_count = Column(Integer, default=1) # Number of trades in this period
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MLModelMetadata(Base):
"""ML model metadata and versions."""
__tablename__ = "ml_model_metadata"
id = Column(Integer, primary_key=True)
model_name = Column(String(100), nullable=False)
model_type = Column(String(50)) # classifier, regressor
version = Column(String(20))
file_path = Column(String(500)) # Path to saved model file
training_metrics = Column(JSON) # Training metrics (accuracy, MSE, etc.)
feature_names = Column(JSON) # List of feature names
strategy_names = Column(JSON) # List of strategy names
training_samples = Column(Integer) # Number of training samples
trained_at = Column(DateTime, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
class Database:
"""Database connection manager."""
def __init__(self):
"""Initialize database connection."""
self.config = get_config()
self.engine = self._create_engine()
self.SessionLocal = async_sessionmaker(
bind=self.engine,
class_=AsyncSession,
expire_on_commit=False
)
# self._create_tables() # Tables should be created via alembic or separate init script in async
def _create_engine(self):
"""Create database engine."""
db_type = self.config.get("database.type", "postgresql")
if db_type == "postgresql":
db_url = self.config.get("database.url")
if not db_url:
raise ValueError("PostgreSQL URL not configured")
# Ensure URL uses async driver (e.g. postgresql+asyncpg)
if "postgresql://" in db_url and "postgresql+asyncpg://" not in db_url:
# This is a naive replacement, in production we should handle this better
db_url = db_url.replace("postgresql://", "postgresql+asyncpg://")
# Add connection timeout to prevent hanging
# asyncpg connect timeout is set via connect_timeout in connect_args
return create_async_engine(
db_url,
echo=False,
connect_args={
"server_settings": {"application_name": "crypto_trader"},
"timeout": 5, # 5 second connection timeout
},
pool_pre_ping=True, # Verify connections before using
pool_recycle=3600, # Recycle connections after 1 hour
pool_timeout=5, # Timeout when getting connection from pool
)
else:
raise ValueError(f"Unsupported database type: {db_type}. Only 'postgresql' is supported.")
async def create_tables(self):
"""Create all database tables."""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
def get_session(self) -> AsyncSession:
"""Get a database session."""
return self.SessionLocal()
async def close(self):
"""Close database connection."""
await self.engine.dispose()
# Global database instance
_db_instance: Optional[Database] = None
def get_database() -> Database:
"""Get global database instance."""
global _db_instance
if _db_instance is None:
_db_instance = Database()
return _db_instance

128
src/core/logger.py Normal file
View File

@@ -0,0 +1,128 @@
"""Configurable logging system with XDG directory support."""
import logging
import logging.handlers
import yaml
from pathlib import Path
from typing import Optional
from .config import get_config
class LoggingConfig:
"""Logging configuration manager."""
def __init__(self):
"""Initialize logging configuration."""
self.config = get_config()
self.log_dir = Path(self.config.get("logging.dir", "~/.local/share/crypto_trader/logs")).expanduser()
self.log_dir.mkdir(parents=True, exist_ok=True)
self.retention_days = self.config.get("logging.retention_days", 30)
self._setup_logging()
def _setup_logging(self):
"""Set up logging configuration."""
log_level = self.config.get("logging.level", "INFO")
level = getattr(logging, log_level.upper(), logging.INFO)
# Root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
# Clear existing handlers
root_logger.handlers.clear()
# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
console_formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# File handler with rotation
log_file = self.log_dir / "crypto_trader.log"
file_handler = logging.handlers.TimedRotatingFileHandler(
log_file,
when='midnight',
interval=1,
backupCount=self.retention_days,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s:%(lineno)d: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
root_logger.addHandler(file_handler)
# Compress old logs
self._setup_log_compression()
def _setup_log_compression(self):
"""Set up log compression for old log files."""
import gzip
import glob
# Compress logs older than retention period
log_files = list(self.log_dir.glob("crypto_trader.log.*"))
for log_file in log_files:
if not log_file.name.endswith('.gz'):
try:
with open(log_file, 'rb') as f_in:
with gzip.open(f"{log_file}.gz", 'wb') as f_out:
f_out.writelines(f_in)
log_file.unlink()
except Exception:
pass # Skip if compression fails
def get_logger(self, name: str) -> logging.Logger:
"""Get a logger with the specified name.
Args:
name: Logger name (typically module name)
Returns:
Logger instance
"""
logger = logging.getLogger(name)
return logger
def set_level(self, level: str):
"""Set logging level.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
"""
log_level = getattr(logging, level.upper(), logging.INFO)
logging.getLogger().setLevel(log_level)
for handler in logging.getLogger().handlers:
handler.setLevel(log_level)
# Global logging config instance
_logging_config: Optional[LoggingConfig] = None
def get_logger(name: str) -> logging.Logger:
"""Get a logger instance.
Args:
name: Logger name (typically __name__)
Returns:
Logger instance
"""
global _logging_config
if _logging_config is None:
_logging_config = LoggingConfig()
return _logging_config.get_logger(name)
def setup_logging():
"""Set up logging system."""
global _logging_config
_logging_config = LoggingConfig()

261
src/core/pubsub.py Normal file
View File

@@ -0,0 +1,261 @@
"""Redis Pub/Sub for real-time event broadcasting across workers."""
import asyncio
import json
from typing import Callable, Dict, Any, Optional, List
from src.core.redis import get_redis_client
from src.core.logger import get_logger
logger = get_logger(__name__)
# Channel names
CHANNEL_MARKET_EVENTS = "crypto_trader:market_events"
CHANNEL_TRADE_EVENTS = "crypto_trader:trade_events"
CHANNEL_SYSTEM_EVENTS = "crypto_trader:system_events"
CHANNEL_AUTOPILOT_EVENTS = "crypto_trader:autopilot_events"
class RedisPubSub:
"""Redis Pub/Sub handler for real-time event broadcasting."""
def __init__(self):
"""Initialize Redis Pub/Sub."""
self.redis = get_redis_client()
self._subscribers: Dict[str, List[Callable]] = {}
self._pubsub = None
self._running = False
self._listen_task: Optional[asyncio.Task] = None
async def publish(self, channel: str, event_type: str, data: Dict[str, Any]) -> int:
"""Publish an event to a channel.
Args:
channel: Channel name
event_type: Type of event (e.g., 'price_update', 'trade_executed')
data: Event data
Returns:
Number of subscribers that received the message
"""
message = {
"type": event_type,
"data": data,
"timestamp": asyncio.get_event_loop().time()
}
try:
client = self.redis.get_client()
count = await client.publish(channel, json.dumps(message))
logger.debug(f"Published {event_type} to {channel} ({count} subscribers)")
return count
except Exception as e:
logger.error(f"Failed to publish to {channel}: {e}")
return 0
async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], None]) -> None:
"""Subscribe to a channel.
Args:
channel: Channel name
callback: Async function to call when message received
"""
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(callback)
logger.info(f"Subscribed to channel: {channel}")
# Start listener if not running
if not self._running:
await self._start_listener()
async def unsubscribe(self, channel: str, callback: Callable = None) -> None:
"""Unsubscribe from a channel.
Args:
channel: Channel name
callback: Specific callback to remove, or None to remove all
"""
if channel in self._subscribers:
if callback:
self._subscribers[channel] = [c for c in self._subscribers[channel] if c != callback]
else:
del self._subscribers[channel]
logger.info(f"Unsubscribed from channel: {channel}")
async def _start_listener(self) -> None:
"""Start the Pub/Sub listener."""
if self._running:
return
self._running = True
self._listen_task = asyncio.create_task(self._listen())
logger.info("Started Redis Pub/Sub listener")
async def _listen(self) -> None:
"""Listen for messages on subscribed channels."""
try:
client = self.redis.get_client()
self._pubsub = client.pubsub()
# Subscribe to all registered channels
channels = list(self._subscribers.keys())
if channels:
await self._pubsub.subscribe(*channels)
while self._running:
try:
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message and message['type'] == 'message':
channel = message['channel']
if isinstance(channel, bytes):
channel = channel.decode('utf-8')
data = message['data']
if isinstance(data, bytes):
data = data.decode('utf-8')
try:
parsed = json.loads(data)
except json.JSONDecodeError:
parsed = {"raw": data}
# Call all subscribers for this channel
callbacks = self._subscribers.get(channel, [])
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(parsed)
else:
callback(parsed)
except Exception as e:
logger.error(f"Subscriber callback error: {e}")
await asyncio.sleep(0.01) # Prevent busy loop
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Pub/Sub listener error: {e}")
await asyncio.sleep(1)
finally:
if self._pubsub:
await self._pubsub.close()
self._running = False
async def stop(self) -> None:
"""Stop the Pub/Sub listener."""
self._running = False
if self._listen_task:
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
logger.info("Stopped Redis Pub/Sub listener")
# Convenience methods for common event types
async def publish_price_update(self, symbol: str, price: float, bid: float = None, ask: float = None) -> int:
"""Publish a price update event.
Args:
symbol: Trading symbol
price: Current price
bid: Bid price
ask: Ask price
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_MARKET_EVENTS, "price_update", {
"symbol": symbol,
"price": price,
"bid": bid,
"ask": ask
})
async def publish_trade_executed(
self,
symbol: str,
side: str,
quantity: float,
price: float,
order_id: str = None
) -> int:
"""Publish a trade execution event.
Args:
symbol: Trading symbol
side: 'buy' or 'sell'
quantity: Trade quantity
price: Execution price
order_id: Order ID
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_TRADE_EVENTS, "trade_executed", {
"symbol": symbol,
"side": side,
"quantity": quantity,
"price": price,
"order_id": order_id
})
async def publish_autopilot_status(
self,
symbol: str,
status: str,
action: str = None,
details: Dict[str, Any] = None
) -> int:
"""Publish an autopilot status event.
Args:
symbol: Trading symbol
status: 'started', 'stopped', 'error', 'signal'
action: Optional action taken
details: Additional details
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_AUTOPILOT_EVENTS, "autopilot_status", {
"symbol": symbol,
"status": status,
"action": action,
"details": details or {}
})
async def publish_system_event(self, event_type: str, message: str, severity: str = "info") -> int:
"""Publish a system event.
Args:
event_type: Event type (e.g., 'startup', 'shutdown', 'error')
message: Event message
severity: 'info', 'warning', 'error'
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_SYSTEM_EVENTS, event_type, {
"message": message,
"severity": severity
})
# Global Pub/Sub instance
_redis_pubsub: Optional[RedisPubSub] = None
def get_redis_pubsub() -> RedisPubSub:
"""Get global Redis Pub/Sub instance."""
global _redis_pubsub
if _redis_pubsub is None:
_redis_pubsub = RedisPubSub()
return _redis_pubsub

70
src/core/redis.py Normal file
View File

@@ -0,0 +1,70 @@
"""Redis client wrapper."""
import redis.asyncio as redis
from typing import Optional
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class RedisClient:
"""Redis client wrapper with automatic connection handling."""
def __init__(self):
"""Initialize Redis client."""
self.config = get_config()
self._client: Optional[redis.Redis] = None
self._pool: Optional[redis.ConnectionPool] = None
def get_client(self) -> redis.Redis:
"""Get or create Redis client.
Returns:
Async Redis client
"""
if self._client is None:
self._connect()
return self._client
def _connect(self):
"""Connect to Redis."""
redis_config = self.config.get("redis", {})
host = redis_config.get("host", "localhost")
port = redis_config.get("port", 6379)
db = redis_config.get("db", 0)
password = redis_config.get("password")
logger.info(f"Connecting to Redis at {host}:{port}/{db}")
try:
self._pool = redis.ConnectionPool(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
socket_connect_timeout=redis_config.get("socket_connect_timeout", 5)
)
self._client = redis.Redis(connection_pool=self._pool)
except Exception as e:
logger.error(f"Failed to create Redis client: {e}")
raise
async def close(self):
"""Close Redis connection."""
if self._client:
await self._client.aclose()
logger.info("Redis connection closed")
# Global instance
_redis_client: Optional[RedisClient] = None
def get_redis_client() -> RedisClient:
"""Get global Redis client instance."""
global _redis_client
if _redis_client is None:
_redis_client = RedisClient()
return _redis_client

99
src/core/repositories.py Normal file
View File

@@ -0,0 +1,99 @@
"""Database repositories for data access."""
from typing import Optional, List, Sequence
from datetime import datetime
from decimal import Decimal
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from .database import Order, Position, OrderStatus, OrderSide, OrderType, MarketData
class BaseRepository:
"""Base repository."""
def __init__(self, session: AsyncSession):
"""Initialize repository."""
self.session = session
class OrderRepository(BaseRepository):
"""Order repository."""
async def create(self, order: Order) -> Order:
"""Create new order."""
self.session.add(order)
await self.session.commit()
await self.session.refresh(order)
return order
async def get_by_id(self, order_id: int) -> Optional[Order]:
"""Get order by ID."""
result = await self.session.execute(
select(Order).where(Order.id == order_id)
)
return result.scalar_one_or_none()
async def get_all(self, limit: int = 100, offset: int = 0) -> Sequence[Order]:
"""Get all orders."""
result = await self.session.execute(
select(Order).limit(limit).offset(offset).order_by(Order.created_at.desc())
)
return result.scalars().all()
async def update_status(
self,
order_id: int,
status: OrderStatus,
exchange_order_id: Optional[str] = None,
fee: Optional[Decimal] = None
) -> Optional[Order]:
"""Update order status."""
values = {"status": status, "updated_at": datetime.utcnow()}
if exchange_order_id:
values["exchange_order_id"] = exchange_order_id
if fee is not None:
values["fee"] = fee
await self.session.execute(
update(Order)
.where(Order.id == order_id)
.values(**values)
)
await self.session.commit()
return await self.get_by_id(order_id)
async def get_open_orders(self, paper_trading: bool = True) -> Sequence[Order]:
"""Get open orders."""
result = await self.session.execute(
select(Order).where(
Order.paper_trading == paper_trading,
Order.status.in_([OrderStatus.PENDING, OrderStatus.OPEN, OrderStatus.PARTIALLY_FILLED])
)
)
return result.scalars().all()
async def delete(self, order_id: int) -> bool:
"""Delete order."""
result = await self.session.execute(
delete(Order).where(Order.id == order_id)
)
await self.session.commit()
return result.rowcount > 0
class PositionRepository(BaseRepository):
"""Position repository."""
async def get_all(self, paper_trading: bool = True) -> Sequence[Position]:
"""Get all positions."""
result = await self.session.execute(
select(Position).where(Position.paper_trading == paper_trading)
)
return result.scalars().all()
async def get_by_symbol(self, symbol: str, paper_trading: bool = True) -> Optional[Position]:
"""Get position by symbol."""
result = await self.session.execute(
select(Position).where(
Position.symbol == symbol,
Position.paper_trading == paper_trading
)
)
return result.scalar_one_or_none()

27
src/data/__init__.py Normal file
View File

@@ -0,0 +1,27 @@
"""Data collection and storage module.
Provides:
- DataCollector: Real-time market data collection
- NewsCollector: Crypto news headline aggregation for sentiment analysis
- TechnicalIndicators: Technical analysis indicators
- DataStorage: Data persistence utilities
- DataQualityManager: Data quality checks
"""
from .collector import DataCollector
from .news_collector import NewsCollector, NewsItem, NewsSource, get_news_collector
from .indicators import TechnicalIndicators, get_indicators
from .storage import DataStorage
from .quality import DataQualityManager
__all__ = [
"DataCollector",
"NewsCollector",
"NewsItem",
"NewsSource",
"get_news_collector",
"TechnicalIndicators",
"get_indicators",
"DataStorage",
"DataQualityManager",
]

221
src/data/cache_manager.py Normal file
View File

@@ -0,0 +1,221 @@
"""Intelligent caching system for pricing data."""
import time
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timedelta
from collections import OrderedDict
from src.core.logger import get_logger
logger = get_logger(__name__)
class CacheEntry:
"""Cache entry with TTL support."""
def __init__(self, data: Any, ttl: float):
"""Initialize cache entry.
Args:
data: Cached data
ttl: Time to live in seconds
"""
self.data = data
self.expires_at = time.time() + ttl
self.created_at = time.time()
self.access_count = 0
self.last_accessed = time.time()
def is_expired(self) -> bool:
"""Check if entry is expired.
Returns:
True if expired
"""
return time.time() > self.expires_at
def touch(self):
"""Update access statistics."""
self.access_count += 1
self.last_accessed = time.time()
def age(self) -> float:
"""Get age of entry in seconds.
Returns:
Age in seconds
"""
return time.time() - self.created_at
class CacheManager:
"""Intelligent cache manager with TTL and size limits.
Implements LRU (Least Recently Used) eviction when size limit is reached.
"""
def __init__(
self,
default_ttl: float = 60.0,
max_size: int = 1000,
ticker_ttl: float = 2.0,
ohlcv_ttl: float = 60.0
):
"""Initialize cache manager.
Args:
default_ttl: Default TTL in seconds
max_size: Maximum number of cache entries
ticker_ttl: TTL for ticker data in seconds
ohlcv_ttl: TTL for OHLCV data in seconds
"""
self.default_ttl = default_ttl
self.max_size = max_size
self.ticker_ttl = ticker_ttl
self.ohlcv_ttl = ohlcv_ttl
# Use OrderedDict for LRU eviction
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._hits = 0
self._misses = 0
self._evictions = 0
self.logger = get_logger(__name__)
def get(self, key: str) -> Optional[Any]:
"""Get value from cache.
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
# Clean expired entries
self._cleanup_expired()
if key not in self._cache:
self._misses += 1
return None
entry = self._cache[key]
if entry.is_expired():
# Remove expired entry
del self._cache[key]
self._misses += 1
return None
# Update access (move to end for LRU)
entry.touch()
self._cache.move_to_end(key)
self._hits += 1
return entry.data
def set(
self,
key: str,
value: Any,
ttl: Optional[float] = None,
cache_type: Optional[str] = None
):
"""Set value in cache.
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds (uses type-specific or default if None)
cache_type: Type of cache ('ticker' or 'ohlcv') for type-specific TTL
"""
# Determine TTL
if ttl is None:
if cache_type == 'ticker':
ttl = self.ticker_ttl
elif cache_type == 'ohlcv':
ttl = self.ohlcv_ttl
else:
ttl = self.default_ttl
# Check if we need to evict
if len(self._cache) >= self.max_size and key not in self._cache:
self._evict_lru()
# Create entry
entry = CacheEntry(value, ttl)
# Add or update
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = entry
def delete(self, key: str) -> bool:
"""Delete entry from cache.
Args:
key: Cache key
Returns:
True if entry was deleted, False if not found
"""
if key in self._cache:
del self._cache[key]
return True
return False
def clear(self):
"""Clear all cache entries."""
self._cache.clear()
self.logger.info("Cache cleared")
def _cleanup_expired(self):
"""Remove expired entries from cache."""
expired_keys = [
key for key, entry in self._cache.items()
if entry.is_expired()
]
for key in expired_keys:
del self._cache[key]
def _evict_lru(self):
"""Evict least recently used entry."""
if self._cache:
# Remove oldest (first) entry
self._cache.popitem(last=False)
self._evictions += 1
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
total_requests = self._hits + self._misses
hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0
# Calculate average age
if self._cache:
avg_age = sum(entry.age() for entry in self._cache.values()) / len(self._cache)
else:
avg_age = 0.0
return {
'size': len(self._cache),
'max_size': self.max_size,
'hits': self._hits,
'misses': self._misses,
'hit_rate': round(hit_rate, 2),
'evictions': self._evictions,
'avg_age_seconds': round(avg_age, 2),
}
def invalidate_pattern(self, pattern: str):
"""Invalidate entries matching a pattern.
Args:
pattern: String pattern to match (simple substring match)
"""
keys_to_delete = [key for key in self._cache.keys() if pattern in key]
for key in keys_to_delete:
del self._cache[key]
if keys_to_delete:
self.logger.info(f"Invalidated {len(keys_to_delete)} cache entries matching '{pattern}'")

139
src/data/collector.py Normal file
View File

@@ -0,0 +1,139 @@
"""Real-time data collection system with WebSocket support."""
import asyncio
from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy import select
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
from .pricing_service import get_pricing_service
logger = get_logger(__name__)
class DataCollector:
"""Collects real-time market data using the unified pricing service."""
def __init__(self):
"""Initialize data collector."""
self.db = get_database()
self.logger = get_logger(__name__)
self._callbacks: Dict[str, List[Callable]] = {}
self._running = False
self._pricing_service = get_pricing_service()
def subscribe(
self,
exchange_id: Optional[int] = None,
symbol: str = "",
callback: Optional[Callable] = None
):
"""Subscribe to real-time data.
Args:
exchange_id: Exchange ID (deprecated, kept for backward compatibility)
symbol: Trading symbol
callback: Callback function(data)
"""
if not symbol or not callback:
logger.warning("subscribe called without symbol or callback")
return
key = f"pricing:{symbol}"
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
# Subscribe via pricing service
def wrapped_callback(data):
"""Wrap callback to maintain backward compatibility."""
for cb in self._callbacks.get(key, []):
try:
cb(data)
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
self._pricing_service.subscribe_ticker(symbol, wrapped_callback)
async def store_ohlcv(
self,
exchange: str,
symbol: str,
timeframe: str,
ohlcv_data: List[List]
):
"""Store OHLCV data in database.
Args:
exchange: Exchange name (can be provider name like 'CCXT-Kraken' or 'CoinGecko')
symbol: Trading symbol
timeframe: Timeframe
ohlcv_data: List of [timestamp, open, high, low, close, volume]
"""
async with self.db.get_session() as session:
try:
for candle in ohlcv_data:
timestamp = datetime.fromtimestamp(candle[0] / 1000)
# Check if already exists
stmt = select(MarketData).filter_by(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if not existing:
market_data = MarketData(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp,
open=Decimal(str(candle[1])),
high=Decimal(str(candle[2])),
low=Decimal(str(candle[3])),
close=Decimal(str(candle[4])),
volume=Decimal(str(candle[5]))
)
session.add(market_data)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Failed to store OHLCV data: {e}")
def get_ohlcv_from_pricing_service(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data from pricing service.
Args:
symbol: Trading symbol
timeframe: Timeframe
since: Start datetime
limit: Number of candles
Returns:
List of OHLCV candles
"""
return self._pricing_service.get_ohlcv(symbol, timeframe, since, limit)
# Global data collector
_data_collector: Optional[DataCollector] = None
def get_data_collector() -> DataCollector:
"""Get global data collector instance."""
global _data_collector
if _data_collector is None:
_data_collector = DataCollector()
return _data_collector

317
src/data/health_monitor.py Normal file
View File

@@ -0,0 +1,317 @@
"""Health monitoring and failover management for pricing providers."""
import time
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import deque
from src.core.logger import get_logger
logger = get_logger(__name__)
class HealthStatus(Enum):
"""Provider health status."""
HEALTHY = "healthy"
DEGRADED = "degraded"
UNHEALTHY = "unhealthy"
UNKNOWN = "unknown"
@dataclass
class HealthMetrics:
"""Health metrics for a provider."""
status: HealthStatus = HealthStatus.UNKNOWN
last_check: Optional[datetime] = None
last_success: Optional[datetime] = None
last_failure: Optional[datetime] = None
success_count: int = 0
failure_count: int = 0
response_times: deque = field(default_factory=lambda: deque(maxlen=100))
consecutive_failures: int = 0
circuit_breaker_open: bool = False
circuit_breaker_opened_at: Optional[datetime] = None
def record_success(self, response_time: float):
"""Record a successful operation.
Args:
response_time: Response time in seconds
"""
self.status = HealthStatus.HEALTHY
self.last_check = datetime.utcnow()
self.last_success = datetime.utcnow()
self.success_count += 1
self.response_times.append(response_time)
self.consecutive_failures = 0
self.circuit_breaker_open = False
self.circuit_breaker_opened_at = None
def record_failure(self):
"""Record a failed operation."""
self.last_check = datetime.utcnow()
self.last_failure = datetime.utcnow()
self.failure_count += 1
self.consecutive_failures += 1
# Open circuit breaker after 5 consecutive failures
if self.consecutive_failures >= 5:
if not self.circuit_breaker_open:
self.circuit_breaker_open = True
self.circuit_breaker_opened_at = datetime.utcnow()
logger.warning(f"Circuit breaker opened after {self.consecutive_failures} failures")
# Update status based on failure rate
total = self.success_count + self.failure_count
if total > 0:
failure_rate = self.failure_count / total
if failure_rate > 0.5:
self.status = HealthStatus.UNHEALTHY
elif failure_rate > 0.2:
self.status = HealthStatus.DEGRADED
def get_avg_response_time(self) -> float:
"""Get average response time.
Returns:
Average response time in seconds, or 0.0 if no data
"""
if not self.response_times:
return 0.0
return sum(self.response_times) / len(self.response_times)
def should_attempt(self, circuit_breaker_timeout: int = 60) -> bool:
"""Check if we should attempt to use this provider.
Args:
circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
Returns:
True if we should attempt, False otherwise
"""
if not self.circuit_breaker_open:
return True
# Check if timeout has passed
if self.circuit_breaker_opened_at:
elapsed = (datetime.utcnow() - self.circuit_breaker_opened_at).total_seconds()
if elapsed >= circuit_breaker_timeout:
# Half-open state: allow one attempt
logger.info("Circuit breaker half-open, allowing attempt")
return True
return False
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API responses."""
return {
'status': self.status.value,
'last_check': self.last_check.isoformat() if self.last_check else None,
'last_success': self.last_success.isoformat() if self.last_success else None,
'last_failure': self.last_failure.isoformat() if self.last_failure else None,
'success_count': self.success_count,
'failure_count': self.failure_count,
'avg_response_time': round(self.get_avg_response_time(), 3),
'consecutive_failures': self.consecutive_failures,
'circuit_breaker_open': self.circuit_breaker_open,
'circuit_breaker_opened_at': (
self.circuit_breaker_opened_at.isoformat()
if self.circuit_breaker_opened_at else None
),
}
class HealthMonitor:
"""Monitors health of pricing providers and manages failover."""
def __init__(
self,
circuit_breaker_timeout: int = 60,
min_success_rate: float = 0.8,
max_avg_response_time: float = 5.0
):
"""Initialize health monitor.
Args:
circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
min_success_rate: Minimum success rate to be considered healthy (0.0-1.0)
max_avg_response_time: Maximum average response time in seconds to be considered healthy
"""
self.circuit_breaker_timeout = circuit_breaker_timeout
self.min_success_rate = min_success_rate
self.max_avg_response_time = max_avg_response_time
self._metrics: Dict[str, HealthMetrics] = {}
self.logger = get_logger(__name__)
def record_success(self, provider_name: str, response_time: float):
"""Record a successful operation for a provider.
Args:
provider_name: Name of the provider
response_time: Response time in seconds
"""
if provider_name not in self._metrics:
self._metrics[provider_name] = HealthMetrics()
self._metrics[provider_name].record_success(response_time)
self.logger.debug(f"Recorded success for {provider_name} ({response_time:.3f}s)")
def record_failure(self, provider_name: str):
"""Record a failed operation for a provider.
Args:
provider_name: Name of the provider
"""
if provider_name not in self._metrics:
self._metrics[provider_name] = HealthMetrics()
self._metrics[provider_name].record_failure()
self.logger.warning(f"Recorded failure for {provider_name} "
f"(consecutive: {self._metrics[provider_name].consecutive_failures})")
def is_healthy(self, provider_name: str) -> bool:
"""Check if a provider is healthy.
Args:
provider_name: Name of the provider
Returns:
True if provider is healthy
"""
if provider_name not in self._metrics:
return True # Assume healthy if no metrics yet
metrics = self._metrics[provider_name]
# Check circuit breaker
if not metrics.should_attempt(self.circuit_breaker_timeout):
return False
# Check status
if metrics.status == HealthStatus.UNHEALTHY:
return False
# Check success rate
total = metrics.success_count + metrics.failure_count
if total > 10: # Need minimum data points
success_rate = metrics.success_count / total
if success_rate < self.min_success_rate:
return False
# Check response time
if metrics.response_times:
avg_response_time = metrics.get_avg_response_time()
if avg_response_time > self.max_avg_response_time:
return False
return True
def get_health_status(self, provider_name: str) -> HealthStatus:
"""Get health status for a provider.
Args:
provider_name: Name of the provider
Returns:
Health status
"""
if provider_name not in self._metrics:
return HealthStatus.UNKNOWN
return self._metrics[provider_name].status
def get_metrics(self, provider_name: str) -> Optional[HealthMetrics]:
"""Get health metrics for a provider.
Args:
provider_name: Name of the provider
Returns:
Health metrics or None if not found
"""
return self._metrics.get(provider_name)
def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
"""Get all provider metrics.
Returns:
Dictionary mapping provider names to their metrics
"""
return {
name: metrics.to_dict()
for name, metrics in self._metrics.items()
}
def select_healthiest(self, provider_names: List[str]) -> Optional[str]:
"""Select the healthiest provider from a list.
Args:
provider_names: List of provider names to choose from
Returns:
Name of healthiest provider, or None if none are healthy
"""
healthy_providers = [
name for name in provider_names
if self.is_healthy(name)
]
if not healthy_providers:
return None
# Sort by health metrics (better providers first)
def health_score(name: str) -> tuple:
metrics = self._metrics.get(name)
if not metrics:
return (1, 0, 0) # Unknown providers get lowest priority
# Score: (status_weight, -avg_response_time, success_rate)
status_weights = {
HealthStatus.HEALTHY: 3,
HealthStatus.DEGRADED: 2,
HealthStatus.UNHEALTHY: 1,
HealthStatus.UNKNOWN: 0,
}
success_rate = (
metrics.success_count / (metrics.success_count + metrics.failure_count)
if (metrics.success_count + metrics.failure_count) > 0
else 0.0
)
return (
status_weights.get(metrics.status, 0),
-metrics.get_avg_response_time(),
success_rate
)
sorted_providers = sorted(healthy_providers, key=health_score, reverse=True)
return sorted_providers[0] if sorted_providers else None
def reset_circuit_breaker(self, provider_name: str):
"""Manually reset circuit breaker for a provider.
Args:
provider_name: Name of the provider
"""
if provider_name in self._metrics:
self._metrics[provider_name].circuit_breaker_open = False
self._metrics[provider_name].circuit_breaker_opened_at = None
self._metrics[provider_name].consecutive_failures = 0
self.logger.info(f"Circuit breaker reset for {provider_name}")
def reset_metrics(self, provider_name: Optional[str] = None):
"""Reset metrics for a provider or all providers.
Args:
provider_name: Name of provider to reset, or None to reset all
"""
if provider_name:
if provider_name in self._metrics:
del self._metrics[provider_name]
self.logger.info(f"Reset metrics for {provider_name}")
else:
self._metrics.clear()
self.logger.info("Reset all provider metrics")

569
src/data/indicators.py Normal file
View File

@@ -0,0 +1,569 @@
"""Comprehensive technical indicator library using pandas-ta and TA-Lib."""
import pandas as pd
import numpy as np
from typing import Optional, Dict, Any, List
# Try to import pandas_ta, but handle if numba is missing
try:
import pandas_ta as ta
PANDAS_TA_AVAILABLE = True
except ImportError:
PANDAS_TA_AVAILABLE = False
ta = None
import warnings
warnings.warn("pandas-ta not available (numba issue), using basic implementations")
try:
import talib
TALIB_AVAILABLE = True
except ImportError:
TALIB_AVAILABLE = False
from src.core.logger import get_logger
logger = get_logger(__name__)
class TechnicalIndicators:
"""Technical indicators library."""
def __init__(self):
"""Initialize indicators library."""
self.talib_available = TALIB_AVAILABLE
# Trend Indicators
def sma(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Simple Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.sma(data, length=period)
return data.rolling(window=period).mean()
def ema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.ema(data, length=period)
return data.ewm(span=period, adjust=False).mean()
def wma(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Weighted Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.wma(data, length=period)
# Basic WMA implementation
weights = np.arange(1, period + 1)
return data.rolling(window=period).apply(lambda x: np.dot(x, weights) / weights.sum(), raw=True)
def dema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Double Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.dema(data, length=period)
ema1 = self.ema(data, period)
return 2 * ema1 - self.ema(ema1, period)
def tema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Triple Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.tema(data, length=period)
ema1 = self.ema(data, period)
ema2 = self.ema(ema1, period)
ema3 = self.ema(ema2, period)
return 3 * ema1 - 3 * ema2 + ema3
# Momentum Indicators
def rsi(self, data: pd.Series, period: int = 14) -> pd.Series:
"""Relative Strength Index."""
if self.talib_available:
return pd.Series(talib.RSI(data.values, timeperiod=period), index=data.index)
if PANDAS_TA_AVAILABLE:
return ta.rsi(data, length=period)
# Basic RSI implementation
delta = data.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
return 100 - (100 / (1 + rs))
def macd(
self,
data: pd.Series,
fast: int = 12,
slow: int = 26,
signal: int = 9
) -> Dict[str, pd.Series]:
"""MACD (Moving Average Convergence Divergence)."""
if self.talib_available:
macd, signal_line, histogram = talib.MACD(
data.values, fastperiod=fast, slowperiod=slow, signalperiod=signal
)
return {
'macd': pd.Series(macd, index=data.index),
'signal': pd.Series(signal_line, index=data.index),
'histogram': pd.Series(histogram, index=data.index),
}
if not PANDAS_TA_AVAILABLE or ta is None:
# Basic MACD implementation fallback
ema_fast = self.ema(data, fast)
ema_slow = self.ema(data, slow)
macd_line = ema_fast - ema_slow
signal_line = self.ema(macd_line.dropna(), signal)
histogram = macd_line - signal_line
return {
'macd': macd_line,
'signal': signal_line,
'histogram': histogram,
}
result = ta.macd(data, fast=fast, slow=slow, signal=signal)
return {
'macd': result[f'MACD_{fast}_{slow}_{signal}'],
'signal': result[f'MACDs_{fast}_{slow}_{signal}'],
'histogram': result[f'MACDh_{fast}_{slow}_{signal}'],
}
def stochastic(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
k_period: int = 14,
d_period: int = 3
) -> Dict[str, pd.Series]:
"""Stochastic Oscillator."""
if self.talib_available:
slowk, slowd = talib.STOCH(
high.values, low.values, close.values,
fastk_period=k_period, slowk_period=d_period, slowd_period=d_period
)
return {
'k': pd.Series(slowk, index=close.index),
'd': pd.Series(slowd, index=close.index),
}
if PANDAS_TA_AVAILABLE and ta is not None:
result = ta.stoch(high, low, close, k=k_period, d=d_period)
return {
'k': result[f'STOCHk_{k_period}_{d_period}_{d_period}'],
'd': result[f'STOCHd_{k_period}_{d_period}_{d_period}'],
}
# Basic Stochastic implementation
lowest_low = low.rolling(window=k_period).min()
highest_high = high.rolling(window=k_period).max()
k = 100 * ((close - lowest_low) / (highest_high - lowest_low))
d = k.rolling(window=d_period).mean()
return {'k': k, 'd': d}
def cci(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 20
) -> pd.Series:
"""Commodity Channel Index."""
if self.talib_available:
return pd.Series(
talib.CCI(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE:
return ta.cci(high, low, close, length=period)
# Basic CCI implementation
tp = (high + low + close) / 3
sma_tp = tp.rolling(window=period).mean()
mad = tp.rolling(window=period).apply(lambda x: np.abs(x - x.mean()).mean(), raw=True)
return (tp - sma_tp) / (0.015 * mad)
def williams_r(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Williams %R."""
if self.talib_available:
return pd.Series(
talib.WILLR(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE:
return ta.willr(high, low, close, length=period)
# Basic Williams %R implementation
highest_high = high.rolling(window=period).max()
lowest_low = low.rolling(window=period).min()
return -100 * ((highest_high - close) / (highest_high - lowest_low))
# Volatility Indicators
def bollinger_bands(
self,
data: pd.Series,
period: int = 20,
std_dev: float = 2.0
) -> Dict[str, pd.Series]:
"""Bollinger Bands."""
if self.talib_available:
upper, middle, lower = talib.BBANDS(
data.values, timeperiod=period, nbdevup=std_dev, nbdevdn=std_dev
)
return {
'upper': pd.Series(upper, index=data.index),
'middle': pd.Series(middle, index=data.index),
'lower': pd.Series(lower, index=data.index),
}
if PANDAS_TA_AVAILABLE:
result = ta.bbands(data, length=period, std=std_dev)
return {
'upper': result[f'BBU_{period}_{std_dev}'],
'middle': result[f'BBM_{period}_{std_dev}'],
'lower': result[f'BBL_{period}_{std_dev}'],
}
# Basic Bollinger Bands implementation
middle = self.sma(data, period)
std = data.rolling(window=period).std()
return {
'upper': middle + (std * std_dev),
'middle': middle,
'lower': middle - (std * std_dev),
}
def atr(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Average True Range."""
if self.talib_available:
return pd.Series(
talib.ATR(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.atr(high, low, close, length=period)
# Basic ATR implementation
prev_close = close.shift(1)
tr1 = high - low
tr2 = abs(high - prev_close)
tr3 = abs(low - prev_close)
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
return tr.rolling(window=period).mean()
def keltner_channels(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 20,
multiplier: float = 2.0
) -> Dict[str, pd.Series]:
"""Keltner Channels."""
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.kc(high, low, close, length=period, scalar=multiplier)
# Basic Keltner Channels implementation
middle = self.ema(close, period)
atr_val = self.atr(high, low, close, period)
return {
'lower': middle - (multiplier * atr_val),
'middle': middle,
'upper': middle + (multiplier * atr_val),
}
# Volume Indicators
def obv(self, close: pd.Series, volume: pd.Series) -> pd.Series:
"""On-Balance Volume."""
if self.talib_available:
return pd.Series(
talib.OBV(close.values, volume.values),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.obv(close, volume)
# Basic OBV implementation
price_change = close.diff()
obv = pd.Series(index=close.index, dtype=float)
obv.iloc[0] = 0
for i in range(1, len(close)):
if price_change.iloc[i] > 0:
obv.iloc[i] = obv.iloc[i-1] + volume.iloc[i]
elif price_change.iloc[i] < 0:
obv.iloc[i] = obv.iloc[i-1] - volume.iloc[i]
else:
obv.iloc[i] = obv.iloc[i-1]
return obv
def vwap(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
volume: pd.Series
) -> pd.Series:
"""Volume Weighted Average Price."""
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.vwap(high, low, close, volume)
# Basic VWAP implementation
typical_price = (high + low + close) / 3
cumulative_tp_vol = (typical_price * volume).cumsum()
cumulative_vol = volume.cumsum()
return cumulative_tp_vol / cumulative_vol
# Advanced Indicators
def ichimoku(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
tenkan: int = 9,
kijun: int = 26,
senkou: int = 52
) -> Dict[str, pd.Series]:
"""Ichimoku Cloud."""
if PANDAS_TA_AVAILABLE and ta is not None:
result = ta.ichimoku(high, low, close, tenkan=tenkan, kijun=kijun, senkou=senkou)
return {
'tenkan': result['ITS_9'],
'kijun': result['IKS_26'],
'senkou_a': result['ISA_9'],
'senkou_b': result['ISB_26'],
'chikou': result['ICS_26'],
}
# Basic Ichimoku implementation
tenkan_sen = (high.rolling(window=tenkan).max() + low.rolling(window=tenkan).min()) / 2
kijun_sen = (high.rolling(window=kijun).max() + low.rolling(window=kijun).min()) / 2
senkou_a = ((tenkan_sen + kijun_sen) / 2).shift(kijun)
senkou_b = ((high.rolling(window=senkou).max() + low.rolling(window=senkou).min()) / 2).shift(kijun)
chikou = close.shift(-kijun)
return {
'tenkan': tenkan_sen,
'kijun': kijun_sen,
'senkou_a': senkou_a,
'senkou_b': senkou_b,
'chikou': chikou,
}
def adx(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Average Directional Index."""
if self.talib_available:
return pd.Series(
talib.ADX(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.adx(high, low, close, length=period)
# Basic ADX implementation
plus_dm = high.diff()
minus_dm = low.diff().abs()
plus_dm[plus_dm < 0] = 0
minus_dm[minus_dm < 0] = 0
atr_val = self.atr(high, low, close, period)
plus_di = 100 * (plus_dm.rolling(window=period).mean() / atr_val)
minus_di = 100 * (minus_dm.rolling(window=period).mean() / atr_val)
dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di)
adx = dx.rolling(window=period).mean()
return adx
def detect_divergence(
self,
prices: pd.Series,
indicator: pd.Series,
lookback: int = 20,
min_swings: int = 2
) -> Dict[str, Any]:
"""Detect divergence between price and indicator.
Divergence occurs when price makes new highs/lows but indicator doesn't,
or vice versa. This is a powerful reversal signal.
Args:
prices: Price series
indicator: Indicator series (e.g., RSI, MACD)
lookback: Lookback period for finding swings
min_swings: Minimum number of swings to detect divergence
Returns:
Dictionary with divergence information:
{
'type': 'bullish', 'bearish', or None
'confidence': 0.0 to 1.0
'price_swing_high': price at high swing
'price_swing_low': price at low swing
'indicator_swing_high': indicator at high swing
'indicator_swing_low': indicator at low swing
}
"""
if len(prices) < lookback * 2 or len(indicator) < lookback * 2:
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
# Find local extrema (swings)
def find_swings(series: pd.Series, lookback: int):
"""Find local maxima and minima."""
highs = []
lows = []
for i in range(lookback, len(series) - lookback):
window = series.iloc[i-lookback:i+lookback+1]
center = series.iloc[i]
# Local maximum
if center == window.max():
highs.append((i, center))
# Local minimum
elif center == window.min():
lows.append((i, center))
return highs, lows
price_highs, price_lows = find_swings(prices, lookback)
indicator_highs, indicator_lows = find_swings(indicator, lookback)
# Need at least min_swings swings to detect divergence
if len(price_highs) < min_swings or len(price_lows) < min_swings:
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
# Check for bearish divergence (price makes higher high, indicator makes lower high)
if len(price_highs) >= 2 and len(indicator_highs) >= 2:
recent_price_high = price_highs[-1][1]
prev_price_high = price_highs[-2][1]
recent_indicator_high = indicator_highs[-1][1]
prev_indicator_high = indicator_highs[-2][1]
# Price higher high but indicator lower high = bearish divergence
if recent_price_high > prev_price_high and recent_indicator_high < prev_indicator_high:
confidence = min(1.0, abs(recent_price_high - prev_price_high) / prev_price_high * 10)
return {
'type': 'bearish',
'confidence': confidence,
'price_swing_high': (price_highs[-2][0], price_highs[-1][0]),
'price_swing_low': None,
'indicator_swing_high': (indicator_highs[-2][0], indicator_highs[-1][0]),
'indicator_swing_low': None
}
# Check for bullish divergence (price makes lower low, indicator makes higher low)
if len(price_lows) >= 2 and len(indicator_lows) >= 2:
recent_price_low = price_lows[-1][1]
prev_price_low = price_lows[-2][1]
recent_indicator_low = indicator_lows[-1][1]
prev_indicator_low = indicator_lows[-2][1]
# Price lower low but indicator higher low = bullish divergence
if recent_price_low < prev_price_low and recent_indicator_low > prev_indicator_low:
confidence = min(1.0, abs(prev_price_low - recent_price_low) / prev_price_low * 10)
return {
'type': 'bullish',
'confidence': confidence,
'price_swing_high': None,
'price_swing_low': (price_lows[-2][0], price_lows[-1][0]),
'indicator_swing_high': None,
'indicator_swing_low': (indicator_lows[-2][0], indicator_lows[-1][0])
}
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
def calculate_all(
self,
df: pd.DataFrame,
indicators: Optional[List[str]] = None
) -> pd.DataFrame:
"""Calculate multiple indicators at once.
Args:
df: DataFrame with OHLCV data (columns: open, high, low, close, volume)
indicators: List of indicator names to calculate (None = all)
Returns:
DataFrame with added indicator columns
"""
result = df.copy()
if 'close' not in result.columns:
raise ValueError("DataFrame must have 'close' column")
close = result['close']
high = result.get('high', close)
low = result.get('low', close)
volume = result.get('volume', pd.Series(1, index=close.index))
# Default indicators if none specified
if indicators is None:
indicators = [
'sma_20', 'ema_20', 'rsi', 'macd', 'bollinger_bands',
'atr', 'obv', 'adx'
]
for indicator in indicators:
try:
if indicator.startswith('sma_'):
period = int(indicator.split('_')[1])
result[f'SMA_{period}'] = self.sma(close, period)
elif indicator.startswith('ema_'):
period = int(indicator.split('_')[1])
result[f'EMA_{period}'] = self.ema(close, period)
elif indicator == 'rsi':
result['RSI'] = self.rsi(close)
elif indicator == 'macd':
macd_data = self.macd(close)
result['MACD'] = macd_data['macd']
result['MACD_Signal'] = macd_data['signal']
result['MACD_Histogram'] = macd_data['histogram']
elif indicator == 'bollinger_bands':
bb_data = self.bollinger_bands(close)
result['BB_Upper'] = bb_data['upper']
result['BB_Middle'] = bb_data['middle']
result['BB_Lower'] = bb_data['lower']
elif indicator == 'atr':
result['ATR'] = self.atr(high, low, close)
elif indicator == 'obv':
result['OBV'] = self.obv(close, volume)
elif indicator == 'adx':
result['ADX'] = self.adx(high, low, close)
except Exception as e:
logger.warning(f"Failed to calculate indicator {indicator}: {e}")
return result
# Global indicators instance
_indicators: Optional[TechnicalIndicators] = None
def get_indicators() -> TechnicalIndicators:
"""Get global technical indicators instance."""
global _indicators
if _indicators is None:
_indicators = TechnicalIndicators()
return _indicators

447
src/data/news_collector.py Normal file
View File

@@ -0,0 +1,447 @@
"""News collector for crypto sentiment analysis.
Collects headlines from multiple sources:
- RSS feeds (CoinDesk, CoinTelegraph, Decrypt, etc.)
- CryptoPanic API (optional, requires API key)
Headlines are cached and refreshed periodically to avoid rate limits.
"""
import asyncio
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from src.core.logger import get_logger
logger = get_logger(__name__)
class NewsSource(str, Enum):
"""Supported news sources."""
COINDESK = "coindesk"
COINTELEGRAPH = "cointelegraph"
DECRYPT = "decrypt"
BITCOIN_MAGAZINE = "bitcoin_magazine"
THE_BLOCK = "the_block"
MESSARI = "messari"
CRYPTOPANIC = "cryptopanic"
@dataclass
class NewsItem:
"""A single news item."""
title: str
source: NewsSource
published: datetime
url: Optional[str] = None
summary: Optional[str] = None
symbols: List[str] = field(default_factory=list)
# RSS feed URLs for major crypto news sources
RSS_FEEDS: Dict[NewsSource, str] = {
NewsSource.COINDESK: "https://www.coindesk.com/arc/outboundfeeds/rss/",
NewsSource.COINTELEGRAPH: "https://cointelegraph.com/rss",
NewsSource.DECRYPT: "https://decrypt.co/feed",
NewsSource.BITCOIN_MAGAZINE: "https://bitcoinmagazine.com/.rss/full/",
NewsSource.THE_BLOCK: "https://www.theblock.co/rss.xml",
NewsSource.MESSARI: "https://messari.io/rss",
}
# Common crypto symbols to detect in headlines
CRYPTO_SYMBOLS = {
"BTC": ["bitcoin", "btc"],
"ETH": ["ethereum", "eth", "ether"],
"SOL": ["solana", "sol"],
"XRP": ["ripple", "xrp"],
"ADA": ["cardano", "ada"],
"DOGE": ["dogecoin", "doge"],
"DOT": ["polkadot", "dot"],
"AVAX": ["avalanche", "avax"],
"MATIC": ["polygon", "matic"],
"LINK": ["chainlink", "link"],
"UNI": ["uniswap", "uni"],
"ATOM": ["cosmos", "atom"],
"LTC": ["litecoin", "ltc"],
}
class NewsCollector:
"""Collects crypto news headlines for sentiment analysis.
Features:
- Aggregates news from multiple RSS feeds
- Caches headlines to reduce network requests
- Filters by crypto symbol
- Optional CryptoPanic integration for more sources
Usage:
collector = NewsCollector()
headlines = await collector.fetch_headlines()
# Or filter by symbol
btc_headlines = await collector.fetch_headlines(symbols=["BTC"])
"""
# Minimum time between fetches (in seconds)
MIN_FETCH_INTERVAL = 300 # 5 minutes
def __init__(
self,
sources: Optional[List[NewsSource]] = None,
cryptopanic_api_key: Optional[str] = None,
cache_duration: int = 600, # 10 minutes
max_headlines: int = 50
):
"""Initialize NewsCollector.
Args:
sources: List of news sources to use. Defaults to all RSS feeds.
cryptopanic_api_key: Optional API key for CryptoPanic.
cache_duration: How long to cache headlines (seconds).
max_headlines: Maximum headlines to keep in cache.
"""
self.sources = sources or list(RSS_FEEDS.keys())
self.cryptopanic_api_key = cryptopanic_api_key
self.cache_duration = cache_duration
self.max_headlines = max_headlines
self._cache: List[NewsItem] = []
self._last_fetch: Optional[datetime] = None
self._fetching = False
self.logger = get_logger(__name__)
# Check if feedparser is available
try:
import feedparser
self._feedparser = feedparser
self._feedparser_available = True
except ImportError:
self._feedparser_available = False
self.logger.warning(
"feedparser not installed. Install with: pip install feedparser"
)
def _extract_symbols(self, text: str) -> List[str]:
"""Extract crypto symbols mentioned in text.
Args:
text: Text to search (headline, summary)
Returns:
List of detected symbol codes (e.g., ["BTC", "ETH"])
"""
text_lower = text.lower()
detected = []
for symbol, keywords in CRYPTO_SYMBOLS.items():
for keyword in keywords:
if keyword in text_lower:
detected.append(symbol)
break
return detected
async def _fetch_rss_feed(self, source: NewsSource) -> List[NewsItem]:
"""Fetch and parse a single RSS feed.
Args:
source: News source to fetch
Returns:
List of NewsItems from the feed
"""
if not self._feedparser_available:
return []
url = RSS_FEEDS.get(source)
if not url:
return []
try:
# Run feedparser in thread pool to avoid blocking
loop = asyncio.get_event_loop()
feed = await loop.run_in_executor(
None,
self._feedparser.parse,
url
)
items = []
for entry in feed.entries[:20]: # Limit entries per feed
# Parse publication date
published = datetime.now()
if hasattr(entry, 'published_parsed') and entry.published_parsed:
try:
published = datetime(*entry.published_parsed[:6])
except (TypeError, ValueError):
pass
title = entry.get('title', '')
summary = entry.get('summary', '')
# Clean HTML from summary
summary = re.sub(r'<[^>]+>', '', summary)[:200]
item = NewsItem(
title=title,
source=source,
published=published,
url=entry.get('link'),
summary=summary,
symbols=self._extract_symbols(f"{title} {summary}")
)
items.append(item)
self.logger.debug(f"Fetched {len(items)} items from {source.value}")
return items
except Exception as e:
self.logger.warning(f"Failed to fetch {source.value} RSS: {e}")
return []
async def _fetch_cryptopanic(self, symbols: Optional[List[str]] = None) -> List[NewsItem]:
"""Fetch news from CryptoPanic API.
Args:
symbols: Optional list of symbols to filter
Returns:
List of NewsItems from CryptoPanic
"""
if not self.cryptopanic_api_key:
return []
try:
import aiohttp
url = "https://cryptopanic.com/api/v1/posts/"
params = {
"auth_token": self.cryptopanic_api_key,
"public": "true",
}
if symbols:
params["currencies"] = ",".join(symbols)
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, timeout=10) as response:
if response.status != 200:
self.logger.warning(f"CryptoPanic API error: {response.status}")
return []
data = await response.json()
items = []
for post in data.get("results", [])[:20]:
published = datetime.now()
if post.get("published_at"):
try:
published = datetime.fromisoformat(
post["published_at"].replace("Z", "+00:00")
)
except (ValueError, TypeError):
pass
item = NewsItem(
title=post.get("title", ""),
source=NewsSource.CRYPTOPANIC,
published=published,
url=post.get("url"),
symbols=[c["code"] for c in post.get("currencies", [])]
)
items.append(item)
self.logger.debug(f"Fetched {len(items)} items from CryptoPanic")
return items
except ImportError:
self.logger.warning("aiohttp not installed for CryptoPanic API")
return []
except Exception as e:
self.logger.warning(f"Failed to fetch CryptoPanic: {e}")
return []
async def fetch_news(
self,
symbols: Optional[List[str]] = None,
force_refresh: bool = False
) -> List[NewsItem]:
"""Fetch news items from all sources.
Args:
symbols: Optional list of symbols to filter (e.g., ["BTC", "ETH"])
force_refresh: Force a refresh even if cache is valid
Returns:
List of NewsItems sorted by publication date (newest first)
"""
now = datetime.now()
# Check cache validity
cache_valid = (
self._last_fetch is not None and
(now - self._last_fetch).total_seconds() < self.cache_duration and
len(self._cache) > 0
)
if cache_valid and not force_refresh:
self.logger.debug("Using cached news items")
items = self._cache
else:
# Prevent concurrent fetches
if self._fetching:
self.logger.debug("Fetch already in progress, using cache")
items = self._cache
else:
self._fetching = True
try:
items = await self._fetch_all_sources()
self._cache = items
self._last_fetch = now
finally:
self._fetching = False
# Filter by symbols if specified
if symbols:
symbols_upper = [s.upper() for s in symbols]
items = [
item for item in items
if any(s in symbols_upper for s in item.symbols) or not item.symbols
]
return items
async def _fetch_all_sources(self) -> List[NewsItem]:
"""Fetch from all configured sources concurrently."""
tasks = []
# RSS feeds
for source in self.sources:
if source in RSS_FEEDS:
tasks.append(self._fetch_rss_feed(source))
# CryptoPanic
if self.cryptopanic_api_key:
tasks.append(self._fetch_cryptopanic())
if not tasks:
self.logger.warning("No news sources configured")
return []
# Fetch all concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine results
all_items = []
for result in results:
if isinstance(result, list):
all_items.extend(result)
elif isinstance(result, Exception):
self.logger.warning(f"Source fetch failed: {result}")
# Sort by publication date (newest first)
all_items.sort(key=lambda x: x.published, reverse=True)
# Limit total items
all_items = all_items[:self.max_headlines]
self.logger.info(f"Fetched {len(all_items)} total news items")
return all_items
async def fetch_headlines(
self,
symbols: Optional[List[str]] = None,
max_age_hours: int = 24,
force_refresh: bool = False
) -> List[str]:
"""Fetch headlines as strings for sentiment analysis.
This is the main method to use with SentimentScanner.
Args:
symbols: Optional list of symbols to filter
max_age_hours: Only include headlines from the last N hours
force_refresh: Force a refresh even if cache is valid
Returns:
List of headline strings
"""
items = await self.fetch_news(symbols=symbols, force_refresh=force_refresh)
# Filter by age
cutoff = datetime.now() - timedelta(hours=max_age_hours)
recent_items = [item for item in items if item.published > cutoff]
# Extract just the titles
headlines = [item.title for item in recent_items if item.title]
self.logger.debug(f"Returning {len(headlines)} headlines for analysis")
return headlines
def get_cached_headlines(self, symbols: Optional[List[str]] = None) -> List[str]:
"""Get cached headlines synchronously (no fetch).
Args:
symbols: Optional list of symbols to filter
Returns:
List of cached headline strings
"""
items = self._cache
if symbols:
symbols_upper = [s.upper() for s in symbols]
items = [
item for item in items
if any(s in symbols_upper for s in item.symbols) or not item.symbols
]
return [item.title for item in items if item.title]
def get_status(self) -> Dict[str, Any]:
"""Get collector status information.
Returns:
Dictionary with status info
"""
return {
"sources": [s.value for s in self.sources],
"cryptopanic_enabled": self.cryptopanic_api_key is not None,
"feedparser_available": self._feedparser_available,
"cached_items": len(self._cache),
"last_fetch": self._last_fetch.isoformat() if self._last_fetch else None,
"cache_age_seconds": (
(datetime.now() - self._last_fetch).total_seconds()
if self._last_fetch else None
),
}
def clear_cache(self):
"""Clear the headline cache."""
self._cache = []
self._last_fetch = None
self.logger.info("News cache cleared")
# Global instance
_news_collector: Optional[NewsCollector] = None
def get_news_collector(**kwargs) -> NewsCollector:
"""Get or create the global NewsCollector instance.
Args:
**kwargs: Arguments passed to NewsCollector constructor
Returns:
NewsCollector instance
"""
global _news_collector
if _news_collector is None:
_news_collector = NewsCollector(**kwargs)
logger.info("Created global NewsCollector instance")
return _news_collector

406
src/data/pricing_service.py Normal file
View File

@@ -0,0 +1,406 @@
"""Unified pricing data service with multi-provider support and automatic failover."""
import time
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from decimal import Decimal
from .providers.base_provider import BasePricingProvider
from .providers.ccxt_provider import CCXTProvider
from .providers.coingecko_provider import CoinGeckoProvider
from .cache_manager import CacheManager
from .health_monitor import HealthMonitor, HealthStatus
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class PricingService:
"""Unified pricing data service with multi-provider support.
Manages multiple pricing providers with automatic failover, caching,
and health monitoring. Provides a single consistent API for accessing
market data regardless of the underlying provider.
"""
def __init__(self):
"""Initialize pricing service."""
self.config = get_config()
self.logger = get_logger(__name__)
# Initialize components
self.cache = CacheManager(
default_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
max_size=self.config.get("data_providers.caching.max_cache_size", 1000),
ticker_ttl=self.config.get("data_providers.caching.ticker_ttl", 2),
ohlcv_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
)
self.health_monitor = HealthMonitor()
# Provider instances
self._providers: Dict[str, BasePricingProvider] = {}
self._active_provider: Optional[str] = None
self._provider_priority: List[str] = []
# Subscriptions
self._subscriptions: Dict[str, List[Callable]] = {}
# Initialize providers
self._initialize_providers()
def _initialize_providers(self):
"""Initialize providers from configuration."""
# Get primary providers from config
primary_config = self.config.get("data_providers.primary", [])
if not primary_config:
# Default configuration
primary_config = [
{'name': 'kraken', 'enabled': True, 'priority': 1},
{'name': 'coinbase', 'enabled': True, 'priority': 2},
{'name': 'binance', 'enabled': True, 'priority': 3},
]
# Sort by priority
primary_config = sorted(
[p for p in primary_config if p.get('enabled', True)],
key=lambda x: x.get('priority', 999)
)
# Create CCXT providers for each exchange
for provider_config in primary_config:
exchange_name = provider_config.get('name')
try:
provider = CCXTProvider(exchange_name=exchange_name)
provider_name = provider.name
if provider.connect():
self._providers[provider_name] = provider
self._provider_priority.append(provider_name)
self.logger.info(f"Initialized provider: {provider_name}")
else:
self.logger.warning(f"Failed to connect provider: {provider_name}")
except Exception as e:
self.logger.error(f"Error initializing provider {exchange_name}: {e}")
# Add fallback provider (CoinGecko)
fallback_config = self.config.get("data_providers.fallback", {})
if fallback_config.get('enabled', True):
try:
coingecko = CoinGeckoProvider(api_key=fallback_config.get('api_key'))
if coingecko.connect():
self._providers[coingecko.name] = coingecko
self._provider_priority.append(coingecko.name)
self.logger.info(f"Initialized fallback provider: {coingecko.name}")
else:
self.logger.warning("Failed to connect CoinGecko fallback provider")
except Exception as e:
self.logger.error(f"Error initializing CoinGecko provider: {e}")
# Select initial active provider
self._select_active_provider()
def _select_active_provider(self) -> Optional[str]:
"""Select the best available provider.
Returns:
Name of selected provider or None
"""
# Filter to healthy providers
healthy_providers = [
name for name in self._provider_priority
if name in self._providers
and self.health_monitor.is_healthy(name)
]
if not healthy_providers:
# Fall back to any available provider if none are healthy
healthy_providers = list(self._providers.keys())
if not healthy_providers:
self.logger.error("No providers available")
self._active_provider = None
return None
# Select first healthy provider (already sorted by priority)
self._active_provider = healthy_providers[0]
self.logger.info(f"Selected active provider: {self._active_provider}")
return self._active_provider
def _get_provider(self, provider_name: Optional[str] = None) -> Optional[BasePricingProvider]:
"""Get a provider instance.
Args:
provider_name: Name of provider, or None to use active provider
Returns:
Provider instance or None
"""
if provider_name:
return self._providers.get(provider_name)
# Use active provider, or select one if none active
if not self._active_provider:
self._select_active_provider()
return self._providers.get(self._active_provider) if self._active_provider else None
def _execute_with_failover(
self,
operation: Callable[[BasePricingProvider], Any],
operation_name: str
) -> Any:
"""Execute an operation with automatic failover.
Args:
operation: Function that takes a provider and returns a result
operation_name: Name of operation for logging
Returns:
Operation result or None if all providers fail
"""
# Try active provider first
providers_to_try = [self._active_provider] if self._active_provider else []
# Add other providers in priority order
for provider_name in self._provider_priority:
if provider_name != self._active_provider and provider_name in self._providers:
providers_to_try.append(provider_name)
last_error = None
for provider_name in providers_to_try:
provider = self._providers.get(provider_name)
if not provider:
continue
# Check health
if not self.health_monitor.is_healthy(provider_name):
self.logger.debug(f"Skipping unhealthy provider: {provider_name}")
continue
try:
start_time = time.time()
result = operation(provider)
response_time = time.time() - start_time
# Record success
self.health_monitor.record_success(provider_name, response_time)
# Update active provider if we used a different one
if provider_name != self._active_provider:
self.logger.info(f"Switched to provider: {provider_name}")
self._active_provider = provider_name
return result
except Exception as e:
last_error = e
self.logger.warning(f"{operation_name} failed on {provider_name}: {e}")
self.health_monitor.record_failure(provider_name)
# Try next provider
continue
# All providers failed
self.logger.error(f"{operation_name} failed on all providers")
if last_error:
raise last_error
return None
def get_ticker(self, symbol: str, use_cache: bool = True) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
use_cache: Whether to use cache
Returns:
Ticker data dictionary
"""
cache_key = f"ticker:{symbol}"
# Check cache
if use_cache:
cached = self.cache.get(cache_key)
if cached:
return cached
# Fetch from provider
def fetch_ticker(provider: BasePricingProvider):
return provider.get_ticker(symbol)
ticker_data = self._execute_with_failover(fetch_ticker, f"get_ticker({symbol})")
if ticker_data:
# Cache the result
if use_cache:
self.cache.set(cache_key, ticker_data, cache_type='ticker')
return ticker_data
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100,
use_cache: bool = True
) -> List[List]:
"""Get OHLCV candlestick data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
use_cache: Whether to use cache
Returns:
List of [timestamp_ms, open, high, low, close, volume]
"""
cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"
# Check cache (only if no 'since' parameter, as it changes the result)
if use_cache and not since:
cached = self.cache.get(cache_key)
if cached:
return cached
# Fetch from provider
def fetch_ohlcv(provider: BasePricingProvider):
return provider.get_ohlcv(symbol, timeframe, since, limit)
ohlcv_data = self._execute_with_failover(
fetch_ohlcv,
f"get_ohlcv({symbol}, {timeframe})"
)
if ohlcv_data:
# Cache the result (only if no 'since' parameter)
if use_cache and not since:
self.cache.set(cache_key, ohlcv_data, cache_type='ohlcv')
return ohlcv_data
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Args:
symbol: Trading pair symbol
callback: Callback function(data) called on price updates
Returns:
True if subscription successful
"""
key = f"ticker:{symbol}"
# Add callback
if key not in self._subscriptions:
self._subscriptions[key] = []
if callback not in self._subscriptions[key]:
self._subscriptions[key].append(callback)
# Wrap callback to handle failover
def wrapped_callback(data):
for cb in self._subscriptions.get(key, []):
try:
cb(data)
except Exception as e:
self.logger.error(f"Callback error for {symbol}: {e}")
# Subscribe via active provider
provider = self._get_provider()
if provider:
try:
success = provider.subscribe_ticker(symbol, wrapped_callback)
if success:
self.logger.info(f"Subscribed to ticker updates for {symbol}")
return True
except Exception as e:
self.logger.error(f"Failed to subscribe to ticker for {symbol}: {e}")
return False
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates.
Args:
symbol: Trading pair symbol
callback: Specific callback to remove, or None to remove all
"""
key = f"ticker:{symbol}"
# Remove callback
if key in self._subscriptions:
if callback:
if callback in self._subscriptions[key]:
self._subscriptions[key].remove(callback)
if not self._subscriptions[key]:
del self._subscriptions[key]
else:
del self._subscriptions[key]
# Unsubscribe from all providers
for provider in self._providers.values():
try:
provider.unsubscribe_ticker(symbol, callback)
except Exception:
pass
self.logger.info(f"Unsubscribed from ticker updates for {symbol}")
def get_active_provider(self) -> Optional[str]:
"""Get name of active provider.
Returns:
Provider name or None
"""
return self._active_provider
def get_provider_health(self, provider_name: Optional[str] = None) -> Dict[str, Any]:
"""Get health status for a provider or all providers.
Args:
provider_name: Provider name, or None for all providers
Returns:
Health status dictionary
"""
if provider_name:
metrics = self.health_monitor.get_metrics(provider_name)
if metrics:
return metrics.to_dict()
return {}
return self.health_monitor.get_all_metrics()
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics.
Returns:
Cache statistics dictionary
"""
return self.cache.get_stats()
# Global pricing service instance
_pricing_service: Optional[PricingService] = None
def get_pricing_service() -> PricingService:
"""Get global pricing service instance.
Returns:
PricingService instance
"""
global _pricing_service
if _pricing_service is None:
_pricing_service = PricingService()
return _pricing_service

View File

@@ -0,0 +1,7 @@
"""Pricing data providers package."""
from .base_provider import BasePricingProvider
from .ccxt_provider import CCXTProvider
from .coingecko_provider import CoinGeckoProvider
__all__ = ['BasePricingProvider', 'CCXTProvider', 'CoinGeckoProvider']

View File

@@ -0,0 +1,150 @@
"""Base pricing provider interface."""
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from src.core.logger import get_logger
logger = get_logger(__name__)
class BasePricingProvider(ABC):
"""Base class for pricing data providers.
Pricing providers are responsible for fetching market data (prices, OHLCV)
without requiring API keys. They differ from exchange adapters which handle
trading operations.
"""
def __init__(self):
"""Initialize pricing provider."""
self.logger = get_logger(f"provider.{self.__class__.__name__}")
self._connected = False
self._subscribers: Dict[str, List[Callable]] = {}
@property
@abstractmethod
def name(self) -> str:
"""Provider name."""
pass
@property
@abstractmethod
def supports_websocket(self) -> bool:
"""Whether this provider supports WebSocket connections."""
pass
@abstractmethod
def connect(self) -> bool:
"""Connect to the provider.
Returns:
True if connection successful
"""
pass
@abstractmethod
def disconnect(self):
"""Disconnect from the provider."""
pass
@abstractmethod
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
Returns:
Dictionary with ticker data:
- 'symbol': str
- 'bid': Decimal
- 'ask': Decimal
- 'last': Decimal (last price)
- 'high': Decimal (24h high)
- 'low': Decimal (24h low)
- 'volume': Decimal (24h volume)
- 'timestamp': int (Unix timestamp in milliseconds)
"""
pass
@abstractmethod
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV (candlestick) data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
Returns:
List of [timestamp_ms, open, high, low, close, volume]
"""
pass
@abstractmethod
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Args:
symbol: Trading pair symbol
callback: Callback function(data) called on price updates
Returns:
True if subscription successful
"""
pass
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates.
Args:
symbol: Trading pair symbol
callback: Specific callback to remove, or None to remove all
"""
key = f"ticker_{symbol}"
if key in self._subscribers:
if callback:
if callback in self._subscribers[key]:
self._subscribers[key].remove(callback)
if not self._subscribers[key]:
del self._subscribers[key]
else:
del self._subscribers[key]
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format for this provider.
Args:
symbol: Symbol to normalize
Returns:
Normalized symbol
"""
# Default: uppercase and replace dashes with slashes
return symbol.upper().replace('-', '/')
def is_connected(self) -> bool:
"""Check if provider is connected.
Returns:
True if connected
"""
return self._connected
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols.
Returns:
List of supported symbols, or empty list if not available
"""
# Default implementation - override in subclasses
return []

View File

@@ -0,0 +1,333 @@
"""CCXT-based pricing provider with multi-exchange support and WebSocket capabilities."""
import ccxt
import threading
import time
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CCXTProvider(BasePricingProvider):
"""CCXT-based pricing provider with multi-exchange fallback support.
This provider uses CCXT to connect to multiple exchanges (Kraken, Coinbase, Binance)
as primary data sources. It supports WebSocket where available and falls back to
polling if WebSocket is not available.
"""
def __init__(self, exchange_name: Optional[str] = None):
"""Initialize CCXT provider.
Args:
exchange_name: Specific exchange to use ('kraken', 'coinbase', 'binance'),
or None to try multiple exchanges
"""
super().__init__()
self.exchange_name = exchange_name
self.exchange = None
self._selected_exchange_id = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
# Exchange priority order (try first to last)
if exchange_name:
self._exchange_options = [(exchange_name.lower(), exchange_name.capitalize())]
else:
self._exchange_options = [
('kraken', 'Kraken'),
('coinbase', 'Coinbase'),
('binance', 'Binance'),
]
@property
def name(self) -> str:
"""Provider name."""
if self._selected_exchange_id:
return f"CCXT-{self._selected_exchange_id.capitalize()}"
return "CCXT Provider"
@property
def supports_websocket(self) -> bool:
"""Check if current exchange supports WebSocket."""
if not self.exchange:
return False
# Check if exchange has WebSocket support
# Most modern CCXT exchanges support WebSocket, but we check explicitly
exchange_id = getattr(self.exchange, 'id', '').lower()
# Known WebSocket-capable exchanges
ws_capable = ['kraken', 'coinbase', 'binance', 'binanceus', 'okx']
return exchange_id in ws_capable
def connect(self) -> bool:
"""Connect to an exchange via CCXT.
Tries multiple exchanges in order until one succeeds.
Returns:
True if connection successful
"""
for exchange_id, exchange_display_name in self._exchange_options:
try:
# Get exchange class from CCXT
exchange_class = getattr(ccxt, exchange_id, None)
if not exchange_class:
logger.warning(f"Exchange {exchange_id} not found in CCXT")
continue
# Create exchange instance
self.exchange = exchange_class({
'enableRateLimit': True,
'options': {
'defaultType': 'spot',
}
})
# Load markets to test connection
if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
try:
self.exchange.load_markets()
except Exception as e:
logger.warning(f"Failed to load markets for {exchange_id}: {e}")
continue
# Test connection with a common symbol
test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
ticker_result = None
for test_symbol in test_symbols:
try:
ticker_result = self.exchange.fetch_ticker(test_symbol)
if ticker_result:
break
except Exception:
continue
if not ticker_result:
logger.warning(f"Could not fetch ticker from {exchange_id}")
continue
# Success!
self._selected_exchange_id = exchange_id
self._connected = True
logger.info(f"Connected to {exchange_display_name} via CCXT")
return True
except Exception as e:
logger.warning(f"Failed to connect to {exchange_id}: {e}")
continue
# All exchanges failed
logger.error("Failed to connect to any CCXT exchange")
self._connected = False
self.exchange = None
return False
def disconnect(self):
"""Disconnect from exchange."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait a bit for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
self._connected = False
self.exchange = None
self._selected_exchange_id = None
logger.info(f"Disconnected from {self.name}")
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return {}
try:
normalized_symbol = self.normalize_symbol(symbol)
ticker = self.exchange.fetch_ticker(normalized_symbol)
return {
'symbol': symbol,
'bid': Decimal(str(ticker.get('bid', 0) or 0)),
'ask': Decimal(str(ticker.get('ask', 0) or 0)),
'last': Decimal(str(ticker.get('last', 0) or ticker.get('close', 0) or 0)),
'high': Decimal(str(ticker.get('high', 0) or 0)),
'low': Decimal(str(ticker.get('low', 0) or 0)),
'volume': Decimal(str(ticker.get('quoteVolume', ticker.get('volume', 0)) or 0)),
'timestamp': ticker.get('timestamp', int(time.time() * 1000)),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV candlestick data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return []
try:
since_timestamp = int(since.timestamp() * 1000) if since else None
normalized_symbol = self.normalize_symbol(symbol)
# Most exchanges support up to 1000 candles per request
max_limit = min(limit, 1000)
ohlcv = self.exchange.fetch_ohlcv(
normalized_symbol,
timeframe,
since_timestamp,
max_limit
)
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Uses WebSocket if available, otherwise falls back to polling.
"""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
# Try WebSocket first if supported
if self.supports_websocket:
try:
# CCXT WebSocket support varies by exchange
# For now, use polling as fallback since WebSocket implementation
# in CCXT requires exchange-specific handling
# TODO: Implement native WebSocket when CCXT adds better support
pass
except Exception as e:
logger.warning(f"WebSocket subscription failed, falling back to polling: {e}")
# Use polling as primary method (more reliable across exchanges)
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker every 2 seconds."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks for this symbol
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
time.sleep(2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(5) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
return True
return True
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates."""
key = f"ticker_{symbol}"
# Stop polling thread
if key in self._stop_polling:
self._stop_polling[key] = True
# Remove from subscribers
super().unsubscribe_ticker(symbol, callback)
# Clean up thread reference
if key in self._polling_threads:
del self._polling_threads[key]
logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol for the selected exchange."""
if not self.exchange:
return symbol.replace('-', '/').upper()
# Basic normalization
normalized = symbol.replace('-', '/').upper()
# Try to use exchange's markets to find correct symbol
try:
if hasattr(self.exchange, 'markets') and self.exchange.markets:
# Check if symbol exists
if normalized in self.exchange.markets:
return normalized
# Try alternative formats
if '/USD' in normalized:
alt_symbol = normalized.replace('/USD', '/USDT')
if alt_symbol in self.exchange.markets:
return alt_symbol
# For Kraken: try XBT instead of BTC
if normalized.startswith('BTC/'):
alt_symbol = normalized.replace('BTC/', 'XBT/')
if alt_symbol in self.exchange.markets:
return alt_symbol
except Exception:
pass
# Fallback: return normalized (let CCXT handle errors)
return normalized
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols."""
if not self._connected or not self.exchange:
return []
try:
markets = self.exchange.load_markets()
return list(markets.keys())
except Exception as e:
logger.error(f"Failed to get supported symbols: {e}")
return []

View File

@@ -0,0 +1,376 @@
"""CoinGecko pricing provider for fallback market data."""
import httpx
import time
import threading
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CoinGeckoProvider(BasePricingProvider):
"""CoinGecko API pricing provider.
This provider uses CoinGecko's free API tier as a fallback when CCXT
providers are unavailable. It uses simple REST endpoints that don't
require authentication.
"""
BASE_URL = "https://api.coingecko.com/api/v3"
# CoinGecko coin ID mapping for common symbols
COIN_ID_MAP = {
'BTC': 'bitcoin',
'ETH': 'ethereum',
'BNB': 'binancecoin',
'SOL': 'solana',
'ADA': 'cardano',
'XRP': 'ripple',
'DOGE': 'dogecoin',
'DOT': 'polkadot',
'MATIC': 'matic-network',
'AVAX': 'avalanche-2',
'LINK': 'chainlink',
'USDT': 'tether',
'USDC': 'usd-coin',
'DAI': 'dai',
}
# Currency mapping
CURRENCY_MAP = {
'USD': 'usd',
'EUR': 'eur',
'GBP': 'gbp',
'JPY': 'jpy',
'USDT': 'usd', # CoinGecko uses USD for stablecoins
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize CoinGecko provider.
Args:
api_key: Optional API key for higher rate limits (free tier doesn't require it)
"""
super().__init__()
self.api_key = api_key
self._client = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
self._rate_limit_delay = 1.0 # Free tier: ~30-50 calls/minute
@property
def name(self) -> str:
"""Provider name."""
return "CoinGecko"
@property
def supports_websocket(self) -> bool:
"""CoinGecko free tier doesn't support WebSocket."""
return False
def connect(self) -> bool:
"""Connect to CoinGecko API.
Returns:
True if connection successful (just validates API access)
"""
try:
self._client = httpx.Client(timeout=10.0)
# Test connection by fetching Bitcoin price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': 'bitcoin',
'vs_currencies': 'usd',
}
)
if response.status_code == 200:
data = response.json()
if 'bitcoin' in data:
self._connected = True
logger.info("Connected to CoinGecko API")
return True
logger.warning("CoinGecko API test failed")
self._connected = False
return False
except Exception as e:
logger.error(f"Failed to connect to CoinGecko: {e}")
self._connected = False
return False
def disconnect(self):
"""Disconnect from CoinGecko API."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
if self._client:
self._client.close()
self._client = None
self._connected = False
logger.info("Disconnected from CoinGecko API")
def _parse_symbol(self, symbol: str) -> Tuple[Optional[str], Optional[str]]:
"""Parse symbol into coin_id and currency.
Args:
symbol: Trading pair like 'BTC/USD' or 'BTC/USDT'
Returns:
Tuple of (coin_id, currency) or (None, None) if not found
"""
parts = symbol.upper().replace('-', '/').split('/')
if len(parts) != 2:
return None, None
base, quote = parts
# Get coin ID
coin_id = self.COIN_ID_MAP.get(base)
if not coin_id:
# Try lowercase base as fallback
coin_id = base.lower()
# Get currency
currency = self.CURRENCY_MAP.get(quote, quote.lower())
return coin_id, currency
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data from CoinGecko."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return {}
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return {}
# Fetch current price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': coin_id,
'vs_currencies': currency,
'include_24hr_change': 'true',
'include_24hr_vol': 'true',
}
)
if response.status_code != 200:
logger.error(f"CoinGecko API error: {response.status_code}")
return {}
data = response.json()
if coin_id not in data:
logger.error(f"Coin {coin_id} not found in CoinGecko response")
return {}
coin_data = data[coin_id]
price_key = currency.lower()
if price_key not in coin_data:
logger.error(f"Currency {currency} not found for {coin_id}")
return {}
price = Decimal(str(coin_data[price_key]))
# Calculate high/low from 24h change if available
change_key = f"{price_key}_24h_change"
vol_key = f"{price_key}_24h_vol"
change_24h = coin_data.get(change_key, 0) or 0
volume_24h = coin_data.get(vol_key, 0) or 0
# Estimate high/low from current price and 24h change
# This is approximate since CoinGecko free tier doesn't provide exact high/low
current_price = float(price)
if change_24h:
estimated_high = current_price * (1 + abs(change_24h / 100) / 2)
estimated_low = current_price * (1 - abs(change_24h / 100) / 2)
else:
estimated_high = current_price
estimated_low = current_price
return {
'symbol': symbol,
'bid': price, # CoinGecko doesn't provide bid/ask, use last price
'ask': price,
'last': price,
'high': Decimal(str(estimated_high)),
'low': Decimal(str(estimated_low)),
'volume': Decimal(str(volume_24h)),
'timestamp': int(time.time() * 1000),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data from CoinGecko.
Note: CoinGecko free tier has limited historical data access.
This method may return empty data for some timeframes.
"""
if not self._connected or not self._client:
logger.error("Provider not connected")
return []
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return []
# CoinGecko uses days parameter instead of timeframe
# Map timeframe to days
timeframe_days_map = {
'1m': 1, # Last 24 hours, minute data (not available in free tier)
'5m': 1,
'15m': 1,
'30m': 1,
'1h': 1, # Last 24 hours, hourly data
'4h': 7, # Last 7 days
'1d': 30, # Last 30 days
'1w': 90, # Last 90 days
}
days = timeframe_days_map.get(timeframe, 7)
# Fetch OHLC data
response = self._client.get(
f"{self.BASE_URL}/coins/{coin_id}/ohlc",
params={
'vs_currency': currency,
'days': days,
}
)
if response.status_code != 200:
logger.warning(f"CoinGecko OHLC API returned {response.status_code}")
return []
data = response.json()
# CoinGecko returns: [timestamp_ms, open, high, low, close]
# We need to convert to: [timestamp_ms, open, high, low, close, volume]
# Note: CoinGecko OHLC endpoint doesn't include volume
# Filter by since if provided
if since:
since_timestamp = int(since.timestamp() * 1000)
data = [candle for candle in data if candle[0] >= since_timestamp]
# Limit results
data = data[-limit:] if limit else data
# Add volume as 0 (CoinGecko doesn't provide it in OHLC endpoint)
ohlcv = [candle + [0] for candle in data]
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} from CoinGecko")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates via polling."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
# Start polling thread if not already running
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker respecting rate limits."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
# Rate limit: wait between requests (free tier: ~30-50 calls/min)
time.sleep(self._rate_limit_delay * 2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(10) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (CoinGecko polling)")
return True
return True
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates."""
key = f"ticker_{symbol}"
# Stop polling
if key in self._stop_polling:
self._stop_polling[key] = True
# Remove from subscribers
super().unsubscribe_ticker(symbol, callback)
# Clean up thread
if key in self._polling_threads:
del self._polling_threads[key]
logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format."""
# CoinGecko uses coin IDs, so we just normalize the format
return symbol.replace('-', '/').upper()

116
src/data/quality.py Normal file
View File

@@ -0,0 +1,116 @@
"""Data quality validation, gap filling, and retention policies."""
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from src.core.database import get_database, MarketData
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class DataQualityManager:
"""Manages data quality and retention."""
def __init__(self):
"""Initialize data quality manager."""
self.db = get_database()
self.config = get_config()
self.logger = get_logger(__name__)
def validate_data_quality(
self,
exchange: str,
symbol: str,
timeframe: str,
start_date: datetime,
end_date: datetime
) -> Dict[str, Any]:
"""Validate data quality.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
start_date: Start date
end_date: End date
Returns:
Quality report
"""
session = self.db.get_session()
try:
data = session.query(MarketData).filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_date,
MarketData.timestamp <= end_date
).order_by(MarketData.timestamp).all()
if len(data) == 0:
return {"valid": False, "reason": "No data"}
# Check for gaps
gaps = self._detect_gaps(data, timeframe)
# Check for anomalies
anomalies = self._detect_anomalies(data)
return {
"valid": len(gaps) == 0 and len(anomalies) == 0,
"total_records": len(data),
"gaps": len(gaps),
"anomalies": len(anomalies),
}
finally:
session.close()
def _detect_gaps(self, data: List[MarketData], timeframe: str) -> List[datetime]:
"""Detect gaps in data.
Args:
data: List of market data
timeframe: Timeframe
Returns:
List of gap timestamps
"""
gaps = []
# Simplified gap detection
return gaps
def _detect_anomalies(self, data: List[MarketData]) -> List[int]:
"""Detect data anomalies.
Args:
data: List of market data
Returns:
List of anomaly indices
"""
anomalies = []
# Simplified anomaly detection
return anomalies
def cleanup_old_data(self, days_to_keep: int = 365):
"""Clean up old data based on retention policy.
Args:
days_to_keep: Days of data to keep
"""
session = self.db.get_session()
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep)
deleted = session.query(MarketData).filter(
MarketData.timestamp < cutoff_date
).delete()
session.commit()
logger.info(f"Cleaned up {deleted} old data records")
except Exception as e:
session.rollback()
logger.error(f"Failed to cleanup old data: {e}")
finally:
session.close()

225
src/data/redis_cache.py Normal file
View File

@@ -0,0 +1,225 @@
"""Redis-based caching for market data and API responses."""
from typing import Any, Optional
import json
from datetime import datetime
from src.core.redis import get_redis_client
from src.core.logger import get_logger
logger = get_logger(__name__)
class RedisCache:
"""Redis-based cache for market data and API responses."""
# Default TTL values (seconds)
TTL_TICKER = 5 # Ticker prices are very volatile
TTL_OHLCV = 60 # OHLCV can be cached longer
TTL_ORDERBOOK = 2 # Order books change rapidly
TTL_API_RESPONSE = 30 # General API response cache
def __init__(self):
"""Initialize Redis cache."""
self.redis = get_redis_client()
async def get_ticker(self, symbol: str) -> Optional[dict]:
"""Get cached ticker data.
Args:
symbol: Trading symbol (e.g., 'BTC/USD')
Returns:
Cached ticker data or None
"""
key = f"cache:ticker:{symbol.replace('/', '_')}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for ticker:{symbol}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
async def set_ticker(self, symbol: str, data: dict, ttl: int = None) -> bool:
"""Cache ticker data.
Args:
symbol: Trading symbol
data: Ticker data
ttl: Time-to-live in seconds (default: TTL_TICKER)
Returns:
True if cached successfully
"""
key = f"cache:ticker:{symbol.replace('/', '_')}"
ttl = ttl or self.TTL_TICKER
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached ticker:{symbol} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def get_ohlcv(self, symbol: str, timeframe: str, limit: int = 100) -> Optional[list]:
"""Get cached OHLCV data.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
limit: Number of candles
Returns:
Cached OHLCV data or None
"""
key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for ohlcv:{symbol}:{timeframe}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
async def set_ohlcv(self, symbol: str, timeframe: str, data: list, limit: int = 100, ttl: int = None) -> bool:
"""Cache OHLCV data.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
data: OHLCV data
limit: Number of candles
ttl: Time-to-live in seconds
Returns:
True if cached successfully
"""
key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
ttl = ttl or self.TTL_OHLCV
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached ohlcv:{symbol}:{timeframe} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def get_api_response(self, cache_key: str) -> Optional[dict]:
"""Get cached API response.
Args:
cache_key: Unique cache key
Returns:
Cached response or None
"""
key = f"cache:api:{cache_key}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for api:{cache_key}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
async def set_api_response(self, cache_key: str, data: dict, ttl: int = None) -> bool:
"""Cache API response.
Args:
cache_key: Unique cache key
data: Response data
ttl: Time-to-live in seconds
Returns:
True if cached successfully
"""
key = f"cache:api:{cache_key}"
ttl = ttl or self.TTL_API_RESPONSE
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached api:{cache_key} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def invalidate(self, pattern: str) -> int:
"""Invalidate cache entries matching pattern.
Args:
pattern: Redis key pattern (e.g., 'cache:ticker:*')
Returns:
Number of keys deleted
"""
try:
client = self.redis.get_client()
keys = []
async for key in client.scan_iter(match=pattern):
keys.append(key)
if keys:
deleted = await client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
return deleted
return 0
except Exception as e:
logger.warning(f"Redis cache invalidation failed: {e}")
return 0
async def get_stats(self) -> dict:
"""Get cache statistics.
Returns:
Cache statistics
"""
try:
client = self.redis.get_client()
info = await client.info('memory')
# Count cached items by type
ticker_count = 0
ohlcv_count = 0
api_count = 0
async for key in client.scan_iter(match='cache:ticker:*'):
ticker_count += 1
async for key in client.scan_iter(match='cache:ohlcv:*'):
ohlcv_count += 1
async for key in client.scan_iter(match='cache:api:*'):
api_count += 1
return {
"memory_used": info.get('used_memory_human', 'N/A'),
"ticker_entries": ticker_count,
"ohlcv_entries": ohlcv_count,
"api_entries": api_count,
"total_entries": ticker_count + ohlcv_count + api_count
}
except Exception as e:
logger.warning(f"Failed to get cache stats: {e}")
return {"error": str(e)}
# Global cache instance
_redis_cache: Optional[RedisCache] = None
def get_redis_cache() -> RedisCache:
"""Get global Redis cache instance."""
global _redis_cache
if _redis_cache is None:
_redis_cache = RedisCache()
return _redis_cache

75
src/data/storage.py Normal file
View File

@@ -0,0 +1,75 @@
"""Data persistence."""
from decimal import Decimal
from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
logger = get_logger(__name__)
class DataStorage:
"""Manages data storage and persistence."""
def __init__(self):
"""Initialize data storage."""
self.db = get_database()
self.logger = get_logger(__name__)
def store_ohlcv(
self,
exchange: str,
symbol: str,
timeframe: str,
timestamp: datetime,
open: Decimal,
high: Decimal,
low: Decimal,
close: Decimal,
volume: Decimal
):
"""Store OHLCV data.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
timestamp: Timestamp
open: Open price
high: High price
low: Low price
close: Close price
volume: Volume
"""
session = self.db.get_session()
try:
# Check if exists
existing = session.query(MarketData).filter_by(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp
).first()
if not existing:
market_data = MarketData(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp,
open=open,
high=high,
low=low,
close=close,
volume=volume
)
session.add(market_data)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Failed to store OHLCV data: {e}")
finally:
session.close()

14
src/exchanges/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""Exchange adapters package."""
from .base import BaseExchangeAdapter
from .factory import ExchangeFactory, get_exchange
from .coinbase import CoinbaseAdapter
from .public_data import PublicDataAdapter
# Register exchange adapters
ExchangeFactory.register("coinbase", CoinbaseAdapter)
ExchangeFactory.register("binance public", PublicDataAdapter)
ExchangeFactory.register("public data", PublicDataAdapter) # Alias
__all__ = ['ExchangeFactory', 'get_exchange', 'CoinbaseAdapter', 'PublicDataAdapter', 'BaseExchangeAdapter']

309
src/exchanges/base.py Normal file
View File

@@ -0,0 +1,309 @@
"""Base exchange adapter interface."""
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any
from datetime import datetime
from src.core.logger import get_logger
logger = get_logger(__name__)
class Order:
"""Order representation."""
def __init__(
self,
symbol: str,
side: str, # 'buy' or 'sell'
order_type: str, # 'market', 'limit', etc.
quantity: Decimal,
price: Optional[Decimal] = None,
**kwargs
):
self.symbol = symbol
self.side = side
self.order_type = order_type
self.quantity = quantity
self.price = price
self.extra = kwargs
class BaseExchangeAdapter(ABC):
"""Base class for exchange adapters."""
def __init__(self, api_key: str, api_secret: str, sandbox: bool = False, read_only: bool = True):
"""Initialize exchange adapter.
Args:
api_key: API key
api_secret: API secret
sandbox: Use sandbox/testnet
read_only: Read-only mode (no trading)
"""
self.api_key = api_key
self.api_secret = api_secret
self.sandbox = sandbox
self.read_only = read_only
self.logger = get_logger(f"exchange.{self.__class__.__name__}")
self._connected = False
@property
@abstractmethod
def name(self) -> str:
"""Exchange name."""
pass
@abstractmethod
async def connect(self) -> bool:
"""Connect to exchange.
Returns:
True if connection successful
"""
pass
@abstractmethod
async def disconnect(self):
"""Disconnect from exchange."""
pass
@abstractmethod
async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
"""Get account balance.
Args:
currency: Specific currency to get balance for, or None for all
Returns:
Dictionary of currency -> balance
"""
pass
@abstractmethod
async def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker for symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
Returns:
Ticker data with 'bid', 'ask', 'last', 'volume', etc.
"""
pass
@abstractmethod
async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
"""Get order book for symbol.
Args:
symbol: Trading pair symbol
limit: Number of orders per side
Returns:
Dictionary with 'bids' and 'asks' lists
"""
pass
@abstractmethod
async def place_order(self, order: Order) -> Dict[str, Any]:
"""Place an order.
Args:
order: Order object
Returns:
Order response with order ID and status
"""
pass
@abstractmethod
async def cancel_order(self, order_id: str, symbol: str) -> bool:
"""Cancel an order.
Args:
order_id: Order ID to cancel
symbol: Trading pair symbol
Returns:
True if cancellation successful
"""
pass
@abstractmethod
async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
"""Get order status.
Args:
order_id: Order ID
symbol: Trading pair symbol
Returns:
Order status information
"""
pass
@abstractmethod
async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open orders.
Args:
symbol: Optional symbol to filter by
Returns:
List of open orders
"""
pass
@abstractmethod
async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open positions.
Args:
symbol: Optional symbol to filter by
Returns:
List of positions
"""
pass
@abstractmethod
async def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV (candlestick) data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
Returns:
List of [timestamp, open, high, low, close, volume]
"""
pass
@abstractmethod
async def subscribe_ticker(self, symbol: str, callback):
"""Subscribe to ticker updates via WebSocket.
Args:
symbol: Trading pair symbol
callback: Callback function(ticker_data)
"""
pass
@abstractmethod
async def subscribe_orderbook(self, symbol: str, callback):
"""Subscribe to order book updates via WebSocket.
Args:
symbol: Trading pair symbol
callback: Callback function(orderbook_data)
"""
pass
@abstractmethod
async def subscribe_trades(self, symbol: str, callback):
"""Subscribe to trade updates via WebSocket.
Args:
symbol: Trading pair symbol
callback: Callback function(trade_data)
"""
pass
def validate_order(self, order: Order) -> bool:
"""Validate order before placing.
Args:
order: Order to validate
Returns:
True if order is valid
"""
if self.read_only:
self.logger.warning("Exchange is in read-only mode")
return False
if not order.symbol:
self.logger.error("Order symbol is required")
return False
if order.quantity <= 0:
self.logger.error("Order quantity must be positive")
return False
if order.order_type == 'limit' and not order.price:
self.logger.error("Limit orders require a price")
return False
return True
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format.
Args:
symbol: Symbol to normalize
Returns:
Normalized symbol
"""
# Default implementation - override in subclasses if needed
return symbol.upper().replace('-', '/')
def get_fee_structure(self) -> Dict[str, float]:
"""Get fee structure (maker/taker fees).
Returns:
Dictionary with 'maker' and 'taker' fee percentages, and optional 'minimum'
"""
# Default fees - override in subclasses
return {
'maker': 0.001, # 0.1%
'taker': 0.001, # 0.1%
'minimum': 0.0, # Minimum fee amount
}
def extract_fee_from_order_response(self, order_response: Dict[str, Any]) -> Optional[Decimal]:
"""Extract actual fee from order response.
Args:
order_response: Order response from exchange API
Returns:
Fee amount as Decimal, or None if not available
"""
# Default implementation - override in subclasses
# Try common fee fields
if 'fee' in order_response:
try:
return Decimal(str(order_response['fee']))
except (ValueError, TypeError):
pass
if 'fees' in order_response:
# Some exchanges return fees as a list
fees = order_response['fees']
if isinstance(fees, list) and len(fees) > 0:
try:
# Sum all fees
total_fee = sum(Decimal(str(f.get('cost', 0) if isinstance(f, dict) else f)) for f in fees)
return total_fee
except (ValueError, TypeError):
pass
# Try 'cost' field (some exchanges use this)
if 'cost' in order_response:
try:
return Decimal(str(order_response['cost']))
except (ValueError, TypeError):
pass
return None

392
src/exchanges/coinbase.py Normal file
View File

@@ -0,0 +1,392 @@
"""Coinbase Advanced Trade API adapter with WebSocket support."""
import asyncio
import ccxt
import websockets
import json
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base import BaseExchangeAdapter, Order
from src.core.logger import get_logger
logger = get_logger(__name__)
class CoinbaseAdapter(BaseExchangeAdapter):
"""Coinbase Advanced Trade API adapter."""
@property
def name(self) -> str:
"""Exchange name."""
return "Coinbase"
def __init__(self, api_key: str, api_secret: str, sandbox: bool = False, read_only: bool = True):
"""Initialize Coinbase adapter."""
super().__init__(api_key, api_secret, sandbox, read_only)
# Initialize ccxt exchange with async support
exchange_class = ccxt.async_support.coinbase
self.exchange = exchange_class({
'apiKey': api_key,
'secret': api_secret,
'sandbox': sandbox,
'enableRateLimit': True,
'options': {
'defaultType': 'spot', # or 'future' for futures
}
})
# WebSocket connections
self._ws_connections = {}
self._ws_callbacks = {}
self._ws_loop = None
async def connect(self) -> bool:
"""Connect to exchange."""
try:
# Test connection by fetching account info
if self.read_only:
# Just verify credentials work
await self.exchange.fetch_balance()
else:
# Verify trading permissions
await self.exchange.fetch_balance()
self._connected = True
logger.info(f"Connected to {self.name} (sandbox={self.sandbox}, read_only={self.read_only})")
return True
except Exception as e:
logger.error(f"Failed to connect to {self.name}: {e}")
self._connected = False
return False
async def disconnect(self):
"""Disconnect from exchange."""
# Close WebSocket connections
for ws in self._ws_connections.values():
try:
# await ws.close() # If using real websockets
pass
except Exception:
pass
await self.exchange.close()
self._ws_connections.clear()
self._ws_callbacks.clear()
self._connected = False
logger.info(f"Disconnected from {self.name}")
async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
"""Get account balance."""
try:
balance = await self.exchange.fetch_balance()
result = {}
for curr, amounts in balance.items():
if isinstance(amounts, dict) and 'free' in amounts:
free = Decimal(str(amounts['free']))
if free > 0 or currency == curr:
result[curr] = free
if currency:
return {currency: result.get(currency, Decimal(0))}
return result
except Exception as e:
logger.error(f"Failed to get balance: {e}")
return {}
async def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker."""
try:
ticker = await self.exchange.fetch_ticker(self.normalize_symbol(symbol))
return {
'symbol': symbol,
'bid': Decimal(str(ticker.get('bid', 0))),
'ask': Decimal(str(ticker.get('ask', 0))),
'last': Decimal(str(ticker.get('last', 0))),
'high': Decimal(str(ticker.get('high', 0))),
'low': Decimal(str(ticker.get('low', 0))),
'volume': Decimal(str(ticker.get('quoteVolume', 0))),
'timestamp': ticker.get('timestamp'),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
"""Get order book."""
try:
orderbook = await self.exchange.fetch_order_book(self.normalize_symbol(symbol), limit)
return {
'bids': [[Decimal(str(b[0])), Decimal(str(b[1]))] for b in orderbook['bids']],
'asks': [[Decimal(str(a[0])), Decimal(str(a[1]))] for a in orderbook['asks']],
}
except Exception as e:
logger.error(f"Failed to get orderbook for {symbol}: {e}")
return {'bids': [], 'asks': []}
async def place_order(self, order: Order) -> Dict[str, Any]:
"""Place an order."""
if not self.validate_order(order):
return {'error': 'Invalid order'}
try:
symbol = self.normalize_symbol(order.symbol)
if order.order_type == 'market':
result = await self.exchange.create_market_order(
symbol,
order.side,
float(order.quantity)
)
elif order.order_type == 'limit':
result = await self.exchange.create_limit_order(
symbol,
order.side,
float(order.quantity),
float(order.price)
)
else:
return {'error': f'Unsupported order type: {order.order_type}'}
# Extract fee from result
fee = self.extract_fee_from_order_response(result)
return {
'id': result.get('id'),
'symbol': symbol,
'status': result.get('status', 'open'),
'side': result.get('side'),
'type': result.get('type'),
'amount': Decimal(str(result.get('amount', 0))),
'price': Decimal(str(result.get('price', 0))) if result.get('price') else None,
'fee': fee, # Include fee in response
}
except Exception as e:
logger.error(f"Failed to place order: {e}")
return {'error': str(e)}
async def cancel_order(self, order_id: str, symbol: str) -> bool:
"""Cancel an order."""
try:
await self.exchange.cancel_order(order_id, self.normalize_symbol(symbol))
return True
except Exception as e:
logger.error(f"Failed to cancel order {order_id}: {e}")
return False
async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
"""Get order status."""
try:
order = await self.exchange.fetch_order(order_id, self.normalize_symbol(symbol))
# Extract fee from order
fee = self.extract_fee_from_order_response(order)
return {
'id': order.get('id'),
'symbol': symbol,
'status': order.get('status'),
'side': order.get('side'),
'type': order.get('type'),
'amount': Decimal(str(order.get('amount', 0))),
'filled': Decimal(str(order.get('filled', 0))),
'remaining': Decimal(str(order.get('remaining', 0))),
'price': Decimal(str(order.get('price', 0))) if order.get('price') else None,
'average': Decimal(str(order.get('average', 0))) if order.get('average') else None,
'fee': fee, # Include fee in response
}
except Exception as e:
logger.error(f"Failed to get order status {order_id}: {e}")
return {}
async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open orders."""
try:
if symbol:
orders = await self.exchange.fetch_open_orders(self.normalize_symbol(symbol))
else:
orders = await self.exchange.fetch_open_orders()
return [
{
'id': o.get('id'),
'symbol': o.get('symbol'),
'status': o.get('status'),
'side': o.get('side'),
'type': o.get('type'),
'amount': Decimal(str(o.get('amount', 0))),
'filled': Decimal(str(o.get('filled', 0))),
'price': Decimal(str(o.get('price', 0))) if o.get('price') else None,
}
for o in orders
]
except Exception as e:
logger.error(f"Failed to get open orders: {e}")
return []
async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open positions."""
try:
# Coinbase Advanced Trade uses fetch_positions for futures
# For spot, positions are derived from balances
positions = await self.exchange.fetch_positions(symbols=[symbol] if symbol else None)
return [
{
'symbol': p.get('symbol'),
'side': p.get('side'),
'size': Decimal(str(p.get('size', 0))),
'entry_price': Decimal(str(p.get('entryPrice', 0))),
'mark_price': Decimal(str(p.get('markPrice', 0))),
'unrealized_pnl': Decimal(str(p.get('unrealizedPnl', 0))),
}
for p in positions if p.get('size', 0) != 0
]
except Exception as e:
logger.error(f"Failed to get positions: {e}")
return []
async def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data."""
try:
since_timestamp = int(since.timestamp() * 1000) if since else None
ohlcv = await self.exchange.fetch_ohlcv(
self.normalize_symbol(symbol),
timeframe,
since_timestamp,
limit
)
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable):
"""Subscribe to ticker updates via WebSocket."""
try:
import asyncio
import websockets
import json
# Normalize symbol for Coinbase
normalized_symbol = self.normalize_symbol(symbol)
# Store callback
self._ws_callbacks[f'ticker_{symbol}'] = callback
# Start WebSocket connection if not already started
if not hasattr(self, '_ws_running') or not self._ws_running:
self._start_websocket_loop()
logger.info(f"Subscribed to ticker updates for {symbol}")
except ImportError:
logger.warning("websockets library not available, using polling fallback")
self._ws_callbacks[f'ticker_{symbol}'] = callback
except Exception as e:
logger.error(f"Failed to subscribe to ticker: {e}")
self._ws_callbacks[f'ticker_{symbol}'] = callback
def subscribe_orderbook(self, symbol: str, callback: Callable):
"""Subscribe to order book updates via WebSocket."""
try:
normalized_symbol = self.normalize_symbol(symbol)
self._ws_callbacks[f'orderbook_{symbol}'] = callback
if not hasattr(self, '_ws_running') or not self._ws_running:
self._start_websocket_loop()
logger.info(f"Subscribed to orderbook updates for {symbol}")
except Exception as e:
logger.error(f"Failed to subscribe to orderbook: {e}")
self._ws_callbacks[f'orderbook_{symbol}'] = callback
def subscribe_trades(self, symbol: str, callback: Callable):
"""Subscribe to trade updates via WebSocket."""
try:
normalized_symbol = self.normalize_symbol(symbol)
self._ws_callbacks[f'trades_{symbol}'] = callback
if not hasattr(self, '_ws_running') or not self._ws_running:
self._start_websocket_loop()
logger.info(f"Subscribed to trades updates for {symbol}")
except Exception as e:
logger.error(f"Failed to subscribe to trades: {e}")
self._ws_callbacks[f'trades_{symbol}'] = callback
def _start_websocket_loop(self):
"""Start WebSocket connection loop."""
try:
import threading
self._ws_running = True
# Start WebSocket in background thread
# Note: Full implementation would use asyncio event loop
logger.info("WebSocket connection started (basic implementation)")
except Exception as e:
logger.error(f"Failed to start WebSocket: {e}")
self._ws_running = False
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol for Coinbase."""
# Coinbase uses format like BTC-USD
return symbol.replace('/', '-').upper()
def get_fee_structure(self) -> Dict[str, float]:
"""Get Coinbase fee structure."""
# Coinbase Advanced Trade fees (approximate)
# Actual fees may vary based on trading volume and account type
return {
'maker': 0.004, # 0.4%
'taker': 0.006, # 0.6%
'minimum': 0.0, # No minimum fee
}
def extract_fee_from_order_response(self, order_response: Dict[str, Any]) -> Optional[Decimal]:
"""Extract actual fee from Coinbase order response.
Args:
order_response: Order response from Coinbase API
Returns:
Fee amount as Decimal, or None if not available
"""
# Coinbase/ccxt typically returns fees in 'fee' field or 'fees' list
if 'fee' in order_response and order_response['fee']:
try:
fee_data = order_response['fee']
if isinstance(fee_data, dict):
# Fee is often {'cost': amount, 'currency': 'USD'}
return Decimal(str(fee_data.get('cost', 0)))
else:
return Decimal(str(fee_data))
except (ValueError, TypeError):
pass
# Try 'fees' list
if 'fees' in order_response:
fees = order_response['fees']
if isinstance(fees, list) and len(fees) > 0:
try:
total_fee = Decimal(0)
for fee_item in fees:
if isinstance(fee_item, dict):
total_fee += Decimal(str(fee_item.get('cost', 0)))
else:
total_fee += Decimal(str(fee_item))
return total_fee if total_fee > 0 else None
except (ValueError, TypeError):
pass
return None

165
src/exchanges/factory.py Normal file
View File

@@ -0,0 +1,165 @@
"""Exchange factory for creating exchange adapters."""
from typing import Optional, Dict, Type, List
from src.core.database import get_database, Exchange
from src.core.logger import get_logger
from src.security.key_manager import get_key_manager
from .base import BaseExchangeAdapter
logger = get_logger(__name__)
class ExchangeFactory:
"""Factory for creating exchange adapter instances."""
_adapters: Dict[str, Type[BaseExchangeAdapter]] = {}
@classmethod
def register(cls, name: str, adapter_class: Type[BaseExchangeAdapter]):
"""Register an exchange adapter.
Args:
name: Exchange name
adapter_class: Adapter class
"""
cls._adapters[name.lower()] = adapter_class
logger.info(f"Registered exchange adapter: {name}")
@classmethod
async def create(cls, exchange_id: int) -> Optional[BaseExchangeAdapter]:
"""Create exchange adapter from database.
Args:
exchange_id: Exchange ID from database
Returns:
Exchange adapter instance or None
"""
from sqlalchemy import select
db = get_database()
key_manager = get_key_manager()
async with db.get_session() as session:
try:
stmt = select(Exchange).where(Exchange.id == exchange_id)
result = await session.execute(stmt)
exchange = result.scalar_one_or_none()
if not exchange:
logger.error(f"Exchange {exchange_id} not found")
return None
if not exchange.enabled:
logger.warning(f"Exchange {exchange.name} is disabled")
return None
# Get adapter class
adapter_class = cls._adapters.get(exchange.name.lower())
if not adapter_class:
logger.error(f"No adapter registered for exchange: {exchange.name}")
return None
# Check if this is a public data adapter (doesn't need credentials)
from .public_data import PublicDataAdapter
is_public_data = adapter_class == PublicDataAdapter
if is_public_data:
# Public data adapter doesn't need credentials
adapter = adapter_class(
api_key="",
api_secret="",
sandbox=False,
read_only=True
)
else:
# Get credentials for regular exchanges
credentials = await key_manager.get_exchange_credentials(exchange_id)
if not credentials:
logger.error(f"No credentials found for exchange: {exchange.name}")
return None
# Create adapter instance
adapter = adapter_class(
api_key=credentials['api_key'],
api_secret=credentials['api_secret'],
sandbox=credentials['sandbox'],
read_only=credentials['read_only']
)
# Connect
if await adapter.connect():
logger.info(f"Connected to {exchange.name}")
return adapter
else:
logger.error(f"Failed to connect to {exchange.name}")
return None
except Exception as e:
logger.error(f"Failed to create exchange adapter: {e}")
return None
@classmethod
async def create_by_name(
cls,
name: str,
api_key: str,
api_secret: str,
sandbox: bool = False,
read_only: bool = True
) -> Optional[BaseExchangeAdapter]:
"""Create exchange adapter directly.
Args:
name: Exchange name
api_key: API key
api_secret: API secret
sandbox: Use sandbox/testnet
read_only: Read-only mode
Returns:
Exchange adapter instance or None
"""
adapter_class = cls._adapters.get(name.lower())
if not adapter_class:
logger.error(f"No adapter registered for exchange: {name}")
return None
try:
adapter = adapter_class(
api_key=api_key,
api_secret=api_secret,
sandbox=sandbox,
read_only=read_only
)
if await adapter.connect():
return adapter
else:
logger.error(f"Failed to connect to {name}")
return None
except Exception as e:
logger.error(f"Failed to create exchange adapter: {e}")
return None
@classmethod
def list_available(cls) -> List[str]:
"""List available exchange adapters.
Returns:
List of exchange names
"""
return list(cls._adapters.keys())
# Convenience function
async def get_exchange(exchange_id: int) -> Optional[BaseExchangeAdapter]:
"""Get exchange adapter by ID.
Args:
exchange_id: Exchange ID
Returns:
Exchange adapter instance or None
"""
return await ExchangeFactory.create(exchange_id)

View File

@@ -0,0 +1,433 @@
"""Public Data Exchange Adapter - Free market data without API keys.
Uses CCXT in public mode to fetch market data from Binance (or other exchanges)
without requiring API keys. Perfect for testing, backtesting, and paper trading.
"""
import ccxt.async_support as ccxt
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base import BaseExchangeAdapter, Order
from src.core.logger import get_logger
logger = get_logger(__name__)
class PublicDataAdapter(BaseExchangeAdapter):
"""Public market data adapter using CCXT in public mode (no API keys needed).
This adapter uses Binance's public API to fetch:
- Historical OHLCV data
- Real-time ticker data
- Order book data
- Trade data
No API keys required - perfect for testing and paper trading.
"""
@property
def name(self) -> str:
"""Exchange name."""
return self._selected_exchange_name or "Public Data"
def __init__(
self,
api_key: str = "",
api_secret: str = "",
sandbox: bool = False,
read_only: bool = True
):
"""Initialize Public Data adapter.
Args:
api_key: Ignored (public data doesn't need API keys)
api_secret: Ignored (public data doesn't need API keys)
sandbox: Ignored (Binance public API is always live data)
read_only: Always True (this adapter is read-only by design)
"""
super().__init__("", "", False, True) # Always read-only, no keys needed
# List of exchanges to try as fallbacks (in order of preference)
# These exchanges have good public API access without geographic restrictions
self._exchange_options = [
('kraken', 'Kraken Public'),
('coinbase', 'Coinbase Public'),
('binance', 'Binance Public'), # Try Binance last since it has geo restrictions
]
self.exchange = None
self._selected_exchange_name = None
# WebSocket connections for real-time data
self._ws_callbacks = {}
self._polling_timer = None
async def connect(self) -> bool:
"""Connect to exchange (test public API access).
Tries multiple exchanges as fallbacks if one is blocked.
Returns:
True if connection successful (public API is always available)
"""
# Try each exchange until one works
for exchange_id, exchange_display_name in self._exchange_options:
try:
# Create exchange instance
exchange_class = getattr(ccxt, exchange_id, None)
if not exchange_class:
continue
self.exchange = exchange_class({
'enableRateLimit': True,
'options': {
'defaultType': 'spot',
}
})
# Load markets if not already loaded
if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
try:
await self.exchange.load_markets()
except Exception as e:
raise
# Test connection with a common symbol
# Different exchanges use different symbol formats
test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
ticker_result = None
for test_symbol in test_symbols:
try:
ticker_result = await self.exchange.fetch_ticker(test_symbol)
break
except Exception as e:
continue
if not ticker_result:
raise Exception("Could not fetch ticker from any test symbol")
# Success! Use this exchange
self._selected_exchange_name = exchange_display_name
self._connected = True
logger.info(f"Connected to {self.name} (public data, no API keys needed)")
return True
except Exception as e:
logger.warning(f"Failed to connect to {exchange_id}: {e}, trying next exchange...")
continue
# All exchanges failed
logger.error("Failed to connect to any public exchange")
self._connected = False
return False
async def disconnect(self):
"""Disconnect from exchange."""
self._ws_callbacks.clear()
if self._polling_timer:
self._polling_timer.cancel()
self._polling_timer = None
if self.exchange:
await self.exchange.close()
self._connected = False
logger.info(f"Disconnected from {self.name}")
async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
"""Get account balance.
Note: Public data adapter cannot access account balances.
Returns empty dict. Use paper trading balance instead.
"""
logger.warning("Public data adapter cannot access account balances")
return {}
async def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker (public endpoint)."""
try:
ticker = await self.exchange.fetch_ticker(self.normalize_symbol(symbol))
return {
'symbol': symbol,
'bid': Decimal(str(ticker.get('bid', 0))),
'ask': Decimal(str(ticker.get('ask', 0))),
'last': Decimal(str(ticker.get('last', 0))),
'high': Decimal(str(ticker.get('high', 0))),
'low': Decimal(str(ticker.get('low', 0))),
'volume': Decimal(str(ticker.get('quoteVolume', 0))),
'timestamp': ticker.get('timestamp'),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
"""Get order book (public endpoint)."""
try:
orderbook = await self.exchange.fetch_order_book(self.normalize_symbol(symbol), limit)
return {
'bids': [[Decimal(str(b[0])), Decimal(str(b[1]))] for b in orderbook['bids']],
'asks': [[Decimal(str(a[0])), Decimal(str(a[1]))] for a in orderbook['asks']],
}
except Exception as e:
logger.error(f"Failed to get orderbook for {symbol}: {e}")
return {'bids': [], 'asks': []}
async def place_order(self, order: Order) -> Dict[str, Any]:
"""Place an order.
Note: Public data adapter cannot place orders.
Returns error message. Use paper trading instead.
"""
logger.warning("Public data adapter cannot place orders - use paper trading")
return {'error': 'Public data adapter is read-only. Use paper trading for order execution.'}
async def cancel_order(self, order_id: str, symbol: str) -> bool:
"""Cancel an order.
Note: Public data adapter cannot cancel orders.
"""
logger.warning("Public data adapter cannot cancel orders")
return False
async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
"""Get order status.
Note: Public data adapter cannot access order status.
"""
logger.warning("Public data adapter cannot access order status")
return {}
async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open orders.
Note: Public data adapter cannot access orders.
"""
logger.warning("Public data adapter cannot access orders")
return []
async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get open positions.
Note: Public data adapter cannot access positions.
"""
logger.warning("Public data adapter cannot access positions")
return []
async def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data (public endpoint).
This is the main method for fetching historical data.
Most exchanges support up to 1000 candles per request.
"""
try:
since_timestamp = int(since.timestamp() * 1000) if since else None
# Most exchanges support up to 1000 candles per request
# If limit > 1000, we'll need to make multiple requests
max_limit = min(limit, 1000)
normalized_symbol = self.normalize_symbol(symbol)
ohlcv = await self.exchange.fetch_ohlcv(
normalized_symbol,
timeframe,
since_timestamp,
max_limit
)
logger.info(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable):
"""Subscribe to ticker updates (polling-based for public data).
Since we don't have WebSocket auth, we poll the ticker endpoint.
"""
try:
import threading
import time
normalized_symbol = self.normalize_symbol(symbol)
self._ws_callbacks[f'ticker_{symbol}'] = callback
def poll_ticker():
"""Poll ticker every 2 seconds."""
while f'ticker_{symbol}' in self._ws_callbacks:
try:
ticker = self.get_ticker(symbol)
if ticker and callback:
callback({
'price': ticker.get('last', 0),
'bid': ticker.get('bid', 0),
'ask': ticker.get('ask', 0),
'volume': ticker.get('volume', 0),
'timestamp': ticker.get('timestamp'),
})
time.sleep(2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error: {e}")
time.sleep(5) # Wait longer on error
# Start polling thread
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
except Exception as e:
logger.error(f"Failed to subscribe to ticker: {e}")
self._ws_callbacks[f'ticker_{symbol}'] = callback
def subscribe_orderbook(self, symbol: str, callback: Callable):
"""Subscribe to order book updates (polling-based)."""
try:
import threading
import time
normalized_symbol = self.normalize_symbol(symbol)
self._ws_callbacks[f'orderbook_{symbol}'] = callback
def poll_orderbook():
"""Poll order book every 1 second."""
while f'orderbook_{symbol}' in self._ws_callbacks:
try:
orderbook = self.get_orderbook(symbol, limit=20)
if orderbook and callback:
callback(orderbook)
time.sleep(1) # Poll every second
except Exception as e:
logger.error(f"Orderbook polling error: {e}")
time.sleep(5)
thread = threading.Thread(target=poll_orderbook, daemon=True)
thread.start()
logger.info(f"Subscribed to orderbook updates for {symbol} (polling mode)")
except Exception as e:
logger.error(f"Failed to subscribe to orderbook: {e}")
self._ws_callbacks[f'orderbook_{symbol}'] = callback
def subscribe_trades(self, symbol: str, callback: Callable):
"""Subscribe to trade updates (polling-based)."""
try:
import threading
import time
normalized_symbol = self.normalize_symbol(symbol)
self._ws_callbacks[f'trades_{symbol}'] = callback
def poll_trades():
"""Poll recent trades every 1 second."""
while f'trades_{symbol}' in self._ws_callbacks:
try:
trades = self.exchange.fetch_trades(normalized_symbol, limit=10)
if trades and callback:
for trade in trades[-5:]: # Last 5 trades
callback({
'price': trade.get('price', 0),
'amount': trade.get('amount', 0),
'side': trade.get('side', 'buy'),
'timestamp': trade.get('timestamp'),
})
time.sleep(1)
except Exception as e:
logger.error(f"Trades polling error: {e}")
time.sleep(5)
thread = threading.Thread(target=poll_trades, daemon=True)
thread.start()
logger.info(f"Subscribed to trades updates for {symbol} (polling mode)")
except Exception as e:
logger.error(f"Failed to subscribe to trades: {e}")
self._ws_callbacks[f'trades_{symbol}'] = callback
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol for the selected exchange.
Different exchanges use different symbol formats:
- Binance/Kraken: BTC/USDT, BTC/USD
- Some exchanges: BTC-USDT, BTC-USD
- Kraken sometimes uses: XBT/USD instead of BTC/USD
"""
if not self.exchange:
return symbol.replace('-', '/').upper()
# Basic normalization: convert dashes to slashes, uppercase
normalized = symbol.replace('-', '/').upper()
# Try to use exchange's built-in symbol normalization
try:
if hasattr(self.exchange, 'markets') and self.exchange.markets:
# Check if normalized symbol exists in markets
if normalized in self.exchange.markets:
return normalized
# Try alternative formats
# For USD pairs, try USDT (common on many exchanges)
if '/USD' in normalized:
alt_symbol = normalized.replace('/USD', '/USDT')
if alt_symbol in self.exchange.markets:
return alt_symbol
# For Kraken: try XBT instead of BTC
if normalized.startswith('BTC/'):
alt_symbol = normalized.replace('BTC/', 'XBT/')
if alt_symbol in self.exchange.markets:
return alt_symbol
# Try to find similar symbol (fuzzy match)
base = normalized.split('/')[0] if '/' in normalized else normalized
quote = normalized.split('/')[1] if '/' in normalized else 'USD'
# Search for matching symbols
for market_symbol in self.exchange.markets.keys():
if market_symbol.startswith(base + '/') or market_symbol.endswith('/' + quote):
return market_symbol
except Exception as e:
pass
# Fallback: return normalized symbol (let CCXT handle errors)
return normalized
def get_fee_structure(self) -> Dict[str, float]:
"""Get fee structure for the selected exchange."""
# Default fees (approximate spot trading fees)
if not self.exchange or not hasattr(self.exchange, 'id'):
return {'maker': 0.001, 'taker': 0.001}
exchange_id = self.exchange.id if hasattr(self.exchange, 'id') else None
# Exchange-specific fees (approximate)
fee_map = {
'binance': {'maker': 0.001, 'taker': 0.001}, # 0.1%
'kraken': {'maker': 0.0016, 'taker': 0.0026}, # 0.16% / 0.26%
'coinbase': {'maker': 0.004, 'taker': 0.006}, # 0.4% / 0.6%
}
return fee_map.get(exchange_id, {'maker': 0.001, 'taker': 0.001})
def get_available_symbols(self) -> List[str]:
"""Get list of available trading symbols.
Returns:
List of available symbol pairs (e.g., ['BTC/USDT', 'ETH/USDT', ...])
"""
try:
markets = self.exchange.load_markets()
return list(markets.keys())
except Exception as e:
logger.error(f"Failed to get available symbols: {e}")
return []

View File

View File

@@ -0,0 +1,76 @@
"""Bayesian optimization."""
from typing import Dict, Any, Callable
from src.core.logger import get_logger
logger = get_logger(__name__)
class BayesianOptimizer:
"""Bayesian optimization using scikit-optimize."""
def __init__(self):
"""Initialize Bayesian optimizer."""
self.logger = get_logger(__name__)
def optimize(
self,
param_space: Dict[str, Any],
objective_function: Callable[[Dict[str, Any]], float],
n_calls: int = 50,
maximize: bool = True
) -> Dict[str, Any]:
"""Run Bayesian optimization.
Args:
param_space: Parameter space definition
objective_function: Objective function
n_calls: Number of optimization iterations
maximize: True to maximize, False to minimize
Returns:
Best parameters and score
"""
try:
from skopt import gp_minimize
from skopt.space import Real, Integer
# Convert param_space to skopt format
dimensions = []
param_names = []
for name, space in param_space.items():
param_names.append(name)
if isinstance(space, tuple):
if isinstance(space[0], int):
dimensions.append(Integer(space[0], space[1], name=name))
else:
dimensions.append(Real(space[0], space[1], name=name))
# Wrapper function
def objective(params):
param_dict = dict(zip(param_names, params))
score = objective_function(param_dict)
return -score if maximize else score
result = gp_minimize(
objective,
dimensions,
n_calls=n_calls,
random_state=42
)
best_params = dict(zip(param_names, result.x))
best_score = -result.fun if maximize else result.fun
return {
"best_params": best_params,
"best_score": best_score,
}
except ImportError:
logger.warning("scikit-optimize not available, using grid search fallback")
from .grid_search import GridSearchOptimizer
optimizer = GridSearchOptimizer()
# Convert param_space to grid format
param_grid = {k: [v[0], v[1], (v[0] + v[1]) / 2] for k, v in param_space.items()}
return optimizer.optimize(param_grid, objective_function, maximize)

111
src/optimization/genetic.py Normal file
View File

@@ -0,0 +1,111 @@
"""Genetic algorithm optimization."""
from typing import Dict, List, Any, Callable
import random
from src.core.logger import get_logger
logger = get_logger(__name__)
class GeneticOptimizer:
"""Genetic algorithm parameter optimization."""
def __init__(self):
"""Initialize genetic optimizer."""
self.logger = get_logger(__name__)
def optimize(
self,
param_ranges: Dict[str, tuple],
objective_function: Callable[[Dict[str, Any]], float],
population_size: int = 50,
generations: int = 100,
maximize: bool = True
) -> Dict[str, Any]:
"""Run genetic algorithm optimization.
Args:
param_ranges: Dictionary of parameter -> (min, max) range
objective_function: Objective function
population_size: Population size
generations: Number of generations
maximize: True to maximize, False to minimize
Returns:
Best parameters and score
"""
# Simplified genetic algorithm implementation
best_score = float('-inf') if maximize else float('inf')
best_params = None
# Initialize population
population = self._initialize_population(param_ranges, population_size)
for generation in range(generations):
# Evaluate fitness
fitness_scores = []
for individual in population:
try:
score = objective_function(individual)
fitness_scores.append((individual, score))
except Exception:
continue
# Sort by fitness
fitness_scores.sort(key=lambda x: x[1], reverse=maximize)
# Update best
if fitness_scores:
best_individual, best_gen_score = fitness_scores[0]
if (maximize and best_gen_score > best_score) or (not maximize and best_gen_score < best_score):
best_score = best_gen_score
best_params = best_individual.copy()
# Create next generation (simplified)
population = self._evolve_population(fitness_scores, param_ranges, population_size)
return {
"best_params": best_params,
"best_score": best_score,
}
def _initialize_population(self, param_ranges: Dict, size: int) -> List[Dict]:
"""Initialize random population."""
population = []
for _ in range(size):
individual = {}
for param, (min_val, max_val) in param_ranges.items():
if isinstance(min_val, int):
individual[param] = random.randint(min_val, max_val)
else:
individual[param] = random.uniform(min_val, max_val)
population.append(individual)
return population
def _evolve_population(self, fitness_scores: List, param_ranges: Dict, size: int) -> List[Dict]:
"""Evolve population (crossover and mutation)."""
# Simplified evolution
new_population = []
elite_size = size // 10
# Keep elite
for i in range(min(elite_size, len(fitness_scores))):
new_population.append(fitness_scores[i][0].copy())
# Generate rest through mutation
while len(new_population) < size:
parent = random.choice(fitness_scores[:size//2])[0]
child = parent.copy()
# Mutate
param = random.choice(list(param_ranges.keys()))
min_val, max_val = param_ranges[param]
if isinstance(min_val, int):
child[param] = random.randint(min_val, max_val)
else:
child[param] = random.uniform(min_val, max_val)
new_population.append(child)
return new_population[:size]

View File

@@ -0,0 +1,57 @@
"""Grid search optimization."""
from typing import Dict, List, Any, Callable
from itertools import product
from src.core.logger import get_logger
logger = get_logger(__name__)
class GridSearchOptimizer:
"""Grid search parameter optimization."""
def __init__(self):
"""Initialize grid search optimizer."""
self.logger = get_logger(__name__)
def optimize(
self,
param_grid: Dict[str, List[Any]],
objective_function: Callable[[Dict[str, Any]], float],
maximize: bool = True
) -> Dict[str, Any]:
"""Run grid search optimization.
Args:
param_grid: Dictionary of parameter -> list of values
objective_function: Function that takes parameters and returns score
maximize: True to maximize, False to minimize
Returns:
Best parameters and score
"""
best_score = float('-inf') if maximize else float('inf')
best_params = None
# Generate all parameter combinations
param_names = list(param_grid.keys())
param_values = list(param_grid.values())
for combination in product(*param_values):
params = dict(zip(param_names, combination))
try:
score = objective_function(params)
if (maximize and score > best_score) or (not maximize and score < best_score):
best_score = score
best_params = params
except Exception as e:
logger.warning(f"Failed to evaluate parameters {params}: {e}")
continue
return {
"best_params": best_params,
"best_score": best_score,
}

View File

265
src/portfolio/analytics.py Normal file
View File

@@ -0,0 +1,265 @@
"""Advanced portfolio analytics (Sharpe ratio, Sortino, drawdown analysis, performance charts)."""
import numpy as np
import pandas as pd
from decimal import Decimal
from typing import Dict, List, Any, Optional
from datetime import datetime, timedelta
from sqlalchemy import select
from src.core.database import get_database, PortfolioSnapshot, Trade
from src.core.logger import get_logger
from .tracker import get_portfolio_tracker
logger = get_logger(__name__)
class PortfolioAnalytics:
"""Advanced portfolio analytics."""
def __init__(self):
"""Initialize portfolio analytics."""
self.db = get_database()
self.tracker = get_portfolio_tracker()
self.logger = get_logger(__name__)
def calculate_sharpe_ratio(
self,
returns: pd.Series,
risk_free_rate: float = 0.0,
periods_per_year: int = 252
) -> float:
"""Calculate Sharpe ratio.
Args:
returns: Series of returns
risk_free_rate: Risk-free rate (annual)
periods_per_year: Trading periods per year
Returns:
Sharpe ratio
"""
if len(returns) == 0 or returns.std() == 0:
return 0.0
excess_returns = returns - (risk_free_rate / periods_per_year)
return float(np.sqrt(periods_per_year) * excess_returns.mean() / returns.std())
def calculate_sortino_ratio(
self,
returns: pd.Series,
risk_free_rate: float = 0.0,
periods_per_year: int = 252
) -> float:
"""Calculate Sortino ratio.
Args:
returns: Series of returns
risk_free_rate: Risk-free rate (annual)
periods_per_year: Trading periods per year
Returns:
Sortino ratio
"""
if len(returns) == 0:
return 0.0
excess_returns = returns - (risk_free_rate / periods_per_year)
downside_returns = returns[returns < 0]
if len(downside_returns) == 0 or downside_returns.std() == 0:
return 0.0
downside_std = downside_returns.std()
return float(np.sqrt(periods_per_year) * excess_returns.mean() / downside_std)
def calculate_drawdown(self, values: pd.Series) -> Dict[str, Any]:
"""Calculate drawdown metrics.
Args:
values: Series of portfolio values
Returns:
Dictionary with drawdown metrics
"""
if len(values) == 0:
return {
"max_drawdown": 0.0,
"current_drawdown": 0.0,
"drawdown_duration": 0,
}
# Calculate running maximum
running_max = values.expanding().max()
drawdown = (values - running_max) / running_max
max_drawdown = float(drawdown.min())
current_drawdown = float(drawdown.iloc[-1])
# Calculate drawdown duration
in_drawdown = drawdown < 0
drawdown_duration = 0
if in_drawdown.iloc[-1]:
# Count consecutive days in drawdown
for i in range(len(in_drawdown) - 1, -1, -1):
if in_drawdown.iloc[i]:
drawdown_duration += 1
else:
break
return {
"max_drawdown": max_drawdown,
"current_drawdown": current_drawdown,
"drawdown_duration": drawdown_duration,
"drawdown_series": drawdown.tolist(),
}
async def get_performance_metrics(
self,
days: int = 30,
paper_trading: bool = True
) -> Dict[str, Any]:
"""Get comprehensive performance metrics.
Args:
days: Number of days to analyze
paper_trading: Paper trading flag
Returns:
Dictionary of performance metrics
"""
history = await self.tracker.get_portfolio_history(days, paper_trading)
if len(history) < 2:
return {
"total_return": 0.0,
"sharpe_ratio": 0.0,
"sortino_ratio": 0.0,
"max_drawdown": 0.0,
"win_rate": 0.0,
}
# Convert to DataFrame
df = pd.DataFrame(history)
df['timestamp'] = pd.to_datetime(df['timestamp'])
df = df.set_index('timestamp').sort_index()
# Calculate returns
returns = df['total_value'].pct_change().dropna()
# Calculate metrics
initial_value = df['total_value'].iloc[0]
final_value = df['total_value'].iloc[-1]
total_return = (final_value - initial_value) / initial_value
sharpe = self.calculate_sharpe_ratio(returns)
sortino = self.calculate_sortino_ratio(returns)
drawdown = self.calculate_drawdown(df['total_value'])
# Calculate win rate from trades
win_rate = await self._calculate_win_rate(days, paper_trading)
# Calculate fee metrics
fee_metrics = await self._calculate_fee_metrics(days, paper_trading, initial_value)
return {
"total_return": float(total_return),
"total_return_percent": float(total_return * 100),
"sharpe_ratio": sharpe,
"sortino_ratio": sortino,
"max_drawdown": drawdown["max_drawdown"],
"current_drawdown": drawdown["current_drawdown"],
"win_rate": win_rate,
"initial_value": float(initial_value),
"final_value": float(final_value),
**fee_metrics, # Include fee metrics
}
async def _calculate_fee_metrics(
self,
days: int,
paper_trading: bool,
initial_value: float
) -> Dict[str, float]:
"""Calculate fee-related metrics.
Args:
days: Number of days
paper_trading: Paper trading flag
initial_value: Initial portfolio value
Returns:
Dictionary of fee metrics
"""
try:
async with self.db.get_session() as session:
since = datetime.utcnow() - timedelta(days=days)
stmt = select(Trade).where(
Trade.paper_trading == paper_trading,
Trade.timestamp >= since
)
result = await session.execute(stmt)
trades = result.scalars().all()
total_fees = sum(float(trade.fee or 0) for trade in trades)
total_trades = len(trades)
avg_fee_per_trade = total_fees / total_trades if total_trades > 0 else 0.0
# Calculate fee percentage of initial value
fee_percentage = (total_fees / initial_value * 100) if initial_value > 0 else 0.0
return {
"total_fees": total_fees,
"avg_fee_per_trade": avg_fee_per_trade,
"fee_percentage": fee_percentage,
"total_trades_with_fees": total_trades,
}
except Exception as e:
logger.warning(f"Error calculating fee metrics: {e}")
return {
"total_fees": 0.0,
"avg_fee_per_trade": 0.0,
"fee_percentage": 0.0,
"total_trades_with_fees": 0,
}
async def _calculate_win_rate(self, days: int, paper_trading: bool) -> float:
"""Calculate win rate from trades.
Args:
days: Number of days
paper_trading: Paper trading flag
Returns:
Win rate (0.0 to 1.0)
"""
try:
async with self.db.get_session() as session:
since = datetime.utcnow() - timedelta(days=days)
stmt = select(Trade).where(
Trade.paper_trading == paper_trading,
Trade.timestamp >= since
)
result = await session.execute(stmt)
trades = result.scalars().all()
if len(trades) == 0:
return 0.0
# Simplified win rate calculation
# In practice, would need to match buy/sell pairs
return 0.5 # Placeholder
except Exception as e:
logger.warning(f"Error calculating win rate: {e}")
return 0.0
# Global portfolio analytics
_portfolio_analytics: Optional[PortfolioAnalytics] = None
def get_portfolio_analytics() -> PortfolioAnalytics:
"""Get global portfolio analytics instance."""
global _portfolio_analytics
if _portfolio_analytics is None:
_portfolio_analytics = PortfolioAnalytics()
return _portfolio_analytics

144
src/portfolio/tracker.py Normal file
View File

@@ -0,0 +1,144 @@
"""Portfolio tracking with real-time P&L calculation."""
from decimal import Decimal
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from sqlalchemy import select
from src.core.database import get_database, Position, PortfolioSnapshot, Trade
from src.core.logger import get_logger
from src.trading.paper_trading import get_paper_trading
logger = get_logger(__name__)
class PortfolioTracker:
"""Tracks portfolio with real-time P&L calculation."""
def __init__(self):
"""Initialize portfolio tracker."""
self.db = get_database()
self.paper_trading = get_paper_trading()
self.logger = get_logger(__name__)
async def get_current_portfolio(self, paper_trading: bool = True) -> Dict[str, Any]:
"""Get current portfolio state.
Args:
paper_trading: Paper trading flag
Returns:
Portfolio dictionary
"""
if paper_trading:
performance = self.paper_trading.get_performance()
positions = self.paper_trading.get_positions()
else:
# Live trading - get from database
async with self.db.get_session() as session:
stmt = select(Position).where(Position.paper_trading == False)
result = await session.execute(stmt)
positions = result.scalars().all()
# Calculate performance from positions
total_value = Decimal(0)
unrealized_pnl = Decimal(0)
for pos in positions:
if pos.current_price:
pos_value = pos.quantity * pos.current_price
total_value += pos_value
unrealized_pnl += (pos.current_price - pos.entry_price) * pos.quantity
performance = {
"current_value": float(total_value),
"unrealized_pnl": float(unrealized_pnl),
"realized_pnl": float(sum(pos.realized_pnl for pos in positions)),
}
return {
"positions": [
{
"symbol": pos.symbol if hasattr(pos, 'symbol') else pos.symbol,
"quantity": float(pos.quantity),
"entry_price": float(pos.entry_price),
"current_price": float(pos.current_price) if pos.current_price else float(pos.entry_price),
"unrealized_pnl": float(pos.unrealized_pnl) if hasattr(pos, 'unrealized_pnl') else 0.0,
}
for pos in positions
],
"performance": performance,
"timestamp": datetime.utcnow().isoformat(),
}
async def update_positions_prices(self, prices: Dict[str, Decimal], paper_trading: bool = True):
"""Update current prices for positions.
Args:
prices: Dictionary of symbol -> current_price
paper_trading: Paper trading flag
"""
if paper_trading:
await self.paper_trading.update_positions_prices(prices)
else:
async with self.db.get_session() as session:
try:
stmt = select(Position).where(Position.paper_trading == False)
result = await session.execute(stmt)
positions = result.scalars().all()
for pos in positions:
if pos.symbol in prices:
pos.current_price = prices[pos.symbol]
pos.unrealized_pnl = (prices[pos.symbol] - pos.entry_price) * pos.quantity
pos.updated_at = datetime.utcnow()
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Failed to update positions prices: {e}")
async def get_portfolio_history(
self,
days: int = 30,
paper_trading: bool = True
) -> List[Dict[str, Any]]:
"""Get portfolio history.
Args:
days: Number of days
paper_trading: Paper trading flag
Returns:
List of portfolio snapshots
"""
async with self.db.get_session() as session:
since = datetime.utcnow() - timedelta(days=days)
stmt = select(PortfolioSnapshot).where(
PortfolioSnapshot.paper_trading == paper_trading,
PortfolioSnapshot.timestamp >= since
).order_by(PortfolioSnapshot.timestamp)
result = await session.execute(stmt)
snapshots = result.scalars().all()
return [
{
"timestamp": snapshot.timestamp.isoformat(),
"total_value": float(snapshot.total_value),
"cash": float(snapshot.cash),
"positions_value": float(snapshot.positions_value),
"unrealized_pnl": float(snapshot.unrealized_pnl),
"realized_pnl": float(snapshot.realized_pnl),
}
for snapshot in snapshots
]
# Global portfolio tracker
_portfolio_tracker: Optional[PortfolioTracker] = None
def get_portfolio_tracker() -> PortfolioTracker:
"""Get global portfolio tracker instance."""
global _portfolio_tracker
if _portfolio_tracker is None:
_portfolio_tracker = PortfolioTracker()
return _portfolio_tracker

View File

196
src/rebalancing/engine.py Normal file
View File

@@ -0,0 +1,196 @@
"""Portfolio rebalancing engine."""
from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from src.core.database import get_database, RebalancingEvent
from src.core.logger import get_logger
from src.portfolio.tracker import get_portfolio_tracker
from src.trading.engine import get_trading_engine
logger = get_logger(__name__)
class RebalancingEngine:
"""Portfolio rebalancing engine."""
def __init__(self):
"""Initialize rebalancing engine."""
self.db = get_database()
self.tracker = get_portfolio_tracker()
self.trading_engine = get_trading_engine()
self.logger = get_logger(__name__)
def rebalance(
self,
target_allocations: Dict[str, float],
exchange_id: int,
paper_trading: bool = True
) -> bool:
"""Rebalance portfolio to target allocations.
Args:
target_allocations: Dictionary of symbol -> target percentage
exchange_id: Exchange ID
paper_trading: Paper trading flag
Returns:
True if rebalancing successful
"""
try:
# Get current portfolio
portfolio = self.tracker.get_current_portfolio(paper_trading)
total_value = portfolio['performance']['current_value']
# Calculate current allocations
current_allocations = {}
for pos in portfolio['positions']:
pos_value = pos['quantity'] * pos['current_price']
current_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0
# Get exchange adapter for fee calculations
adapter = await self.trading_engine.get_exchange_adapter(exchange_id)
# Calculate required trades, factoring in fees
orders = []
from src.trading.fee_calculator import get_fee_calculator
fee_calculator = get_fee_calculator()
# Get fee threshold from config (default 0.5% to account for round-trip fees)
fee_threshold = Decimal(str(self.tracker.db.get_session().query(
# Get from config
))) if False else Decimal("0.005") # 0.5% default threshold
for symbol, target_pct in target_allocations.items():
current_pct = current_allocations.get(symbol, 0.0)
deviation = target_pct - current_pct
# Only rebalance if deviation exceeds fee threshold
# Default threshold is 1%, but we'll use a configurable fee-aware threshold
min_deviation = max(Decimal("0.01"), fee_threshold) # At least 1% or fee threshold
if abs(deviation) > min_deviation:
target_value = total_value * Decimal(str(target_pct))
current_value = Decimal(str(current_allocations.get(symbol, 0.0))) * total_value
trade_value = target_value - current_value
# Get current price
if adapter:
ticker = await adapter.get_ticker(symbol)
price = ticker.get('last', Decimal(0))
if price > 0:
# Estimate fee for this trade
estimated_quantity = abs(trade_value / price)
estimated_fee = fee_calculator.estimate_round_trip_fee(
quantity=estimated_quantity,
price=price,
exchange_adapter=adapter
)
# Adjust trade value to account for fees
# For buy: reduce quantity to account for fee
# For sell: fee comes from proceeds
if trade_value > 0: # Buy
# Reduce trade value by estimated fee
adjusted_trade_value = trade_value - estimated_fee
quantity = adjusted_trade_value / price if price > 0 else Decimal(0)
else: # Sell
# Fee comes from proceeds, so quantity stays the same
quantity = abs(trade_value / price)
if quantity > 0:
side = 'buy' if trade_value > 0 else 'sell'
orders.append({
'symbol': symbol,
'side': side,
'quantity': quantity,
'price': price,
})
# Execute rebalancing orders
executed_orders = []
for order in orders:
from src.core.database import OrderSide, OrderType
side = OrderSide.BUY if order['side'] == 'buy' else OrderSide.SELL
result = await self.trading_engine.execute_order(
exchange_id=exchange_id,
strategy_id=None,
symbol=order['symbol'],
side=side,
order_type=OrderType.MARKET,
quantity=order['quantity'],
paper_trading=paper_trading
)
if result:
executed_orders.append(result.id)
# Record rebalancing event
self._record_rebalancing_event(
'manual',
target_allocations,
current_allocations,
executed_orders
)
return True
except Exception as e:
logger.error(f"Failed to rebalance portfolio: {e}")
return False
def _record_rebalancing_event(
self,
trigger_type: str,
target_allocations: Dict[str, float],
before_allocations: Dict[str, float],
orders_placed: List[int]
):
"""Record rebalancing event in database.
Args:
trigger_type: Trigger type
target_allocations: Target allocations
before_allocations: Allocations before rebalancing
orders_placed: List of order IDs
"""
session = self.db.get_session()
try:
# Get after allocations
portfolio = self.tracker.get_current_portfolio()
total_value = portfolio['performance']['current_value']
after_allocations = {}
for pos in portfolio['positions']:
pos_value = pos['quantity'] * pos['current_price']
after_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0
event = RebalancingEvent(
trigger_type=trigger_type,
target_allocations=target_allocations,
before_allocations=before_allocations,
after_allocations=after_allocations,
orders_placed=orders_placed,
timestamp=datetime.utcnow()
)
session.add(event)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Failed to record rebalancing event: {e}")
finally:
session.close()
# Global rebalancing engine
_rebalancing_engine: Optional[RebalancingEngine] = None
def get_rebalancing_engine() -> RebalancingEngine:
"""Get global rebalancing engine instance."""
global _rebalancing_engine
if _rebalancing_engine is None:
_rebalancing_engine = RebalancingEngine()
return _rebalancing_engine

View File

@@ -0,0 +1,36 @@
"""Rebalancing strategies."""
from typing import Dict
from decimal import Decimal
class RebalancingStrategy:
"""Base rebalancing strategy."""
def calculate_trades(
self,
current_allocations: Dict[str, float],
target_allocations: Dict[str, float],
total_value: Decimal
) -> Dict[str, Decimal]:
"""Calculate required trades.
Args:
current_allocations: Current allocations
target_allocations: Target allocations
total_value: Total portfolio value
Returns:
Dictionary of symbol -> trade value (positive = buy, negative = sell)
"""
trades = {}
for symbol, target_pct in target_allocations.items():
current_pct = current_allocations.get(symbol, 0.0)
deviation = target_pct - current_pct
if abs(deviation) > 0.01: # 1% threshold
target_value = total_value * Decimal(str(target_pct))
current_value = Decimal(str(current_pct)) * total_value
trades[symbol] = target_value - current_value
return trades

View File

View File

@@ -0,0 +1,120 @@
"""CSV export functionality."""
import csv
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, Trade, Order
from src.core.logger import get_logger
logger = get_logger(__name__)
class CSVExporter:
"""CSV export functionality."""
def __init__(self):
"""Initialize CSV exporter."""
self.db = get_database()
self.logger = get_logger(__name__)
def export_trades(
self,
filepath: Path,
paper_trading: bool = True,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> bool:
"""Export trades to CSV.
Args:
filepath: Output file path
paper_trading: Filter by paper trading
start_date: Start date filter
end_date: End date filter
Returns:
True if export successful
"""
session = self.db.get_session()
try:
query = session.query(Trade).filter_by(paper_trading=paper_trading)
if start_date:
query = query.filter(Trade.timestamp >= start_date)
if end_date:
query = query.filter(Trade.timestamp <= end_date)
trades = query.order_by(Trade.timestamp).all()
with open(filepath, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow([
'Timestamp', 'Symbol', 'Side', 'Quantity', 'Price', 'Fee', 'Total'
])
for trade in trades:
writer.writerow([
trade.timestamp.isoformat(),
trade.symbol,
trade.side.value,
float(trade.quantity),
float(trade.price),
float(trade.fee),
float(trade.total),
])
logger.info(f"Exported {len(trades)} trades to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to export trades: {e}")
return False
finally:
session.close()
def export_portfolio(self, filepath: Path) -> bool:
"""Export portfolio snapshot to CSV.
Args:
filepath: Output file path
Returns:
True if export successful
"""
from src.portfolio.tracker import get_portfolio_tracker
tracker = get_portfolio_tracker()
portfolio = tracker.get_current_portfolio()
try:
with open(filepath, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['Symbol', 'Quantity', 'Entry Price', 'Current Price', 'Unrealized P&L'])
for pos in portfolio['positions']:
writer.writerow([
pos['symbol'],
pos['quantity'],
pos['entry_price'],
pos['current_price'],
pos['unrealized_pnl'],
])
logger.info(f"Exported portfolio to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to export portfolio: {e}")
return False
# Global CSV exporter
_csv_exporter: Optional[CSVExporter] = None
def get_csv_exporter() -> CSVExporter:
"""Get global CSV exporter instance."""
global _csv_exporter
if _csv_exporter is None:
_csv_exporter = CSVExporter()
return _csv_exporter

View File

@@ -0,0 +1,111 @@
"""PDF report generation."""
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
from src.core.logger import get_logger
logger = get_logger(__name__)
class PDFGenerator:
"""PDF report generation."""
def __init__(self):
"""Initialize PDF generator."""
self.logger = get_logger(__name__)
def generate_performance_report(
self,
filepath: Path,
metrics: Dict[str, Any],
title: str = "Portfolio Performance Report"
) -> bool:
"""Generate performance report PDF.
Args:
filepath: Output file path
metrics: Performance metrics dictionary
title: Report title
Returns:
True if generation successful
"""
try:
doc = SimpleDocTemplate(str(filepath), pagesize=letter)
story = []
styles = getSampleStyleSheet()
# Title
story.append(Paragraph(title, styles['Title']))
story.append(Spacer(1, 12))
# Metrics table
data = [
['Metric', 'Value'],
['Total Return', f"{metrics.get('total_return_percent', 0):.2f}%"],
['Sharpe Ratio', f"{metrics.get('sharpe_ratio', 0):.2f}"],
['Sortino Ratio', f"{metrics.get('sortino_ratio', 0):.2f}"],
['Max Drawdown', f"{metrics.get('max_drawdown', 0):.2%}"],
['Win Rate', f"{metrics.get('win_rate', 0):.2%}"],
]
table = Table(data)
table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 14),
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
('BACKGROUND', (0, 1), (-1, -1), colors.beige),
('GRID', (0, 0), (-1, -1), 1, colors.black),
]))
story.append(table)
story.append(Spacer(1, 12))
story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", styles['Normal']))
doc.build(story)
logger.info(f"Generated PDF report: {filepath}")
return True
except Exception as e:
logger.error(f"Failed to generate PDF report: {e}")
return False
def generate_backtest_report(
self,
results: Dict[str, Any],
filepath: str
) -> bool:
"""Generate backtest report PDF.
Args:
results: Backtest results dictionary
filepath: Output file path
Returns:
True if generation successful
"""
return self.generate_performance_report(
Path(filepath),
results,
"Backtest Report"
)
# Global PDF generator
_pdf_generator: Optional[PDFGenerator] = None
def get_pdf_generator() -> PDFGenerator:
"""Get global PDF generator instance."""
global _pdf_generator
if _pdf_generator is None:
_pdf_generator = PDFGenerator()
return _pdf_generator

View File

@@ -0,0 +1,127 @@
"""Tax reporting (FIFO, LIFO, specific identification)."""
from decimal import Decimal
from datetime import datetime
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, Trade
from src.core.logger import get_logger
logger = get_logger(__name__)
class TaxReporter:
"""Tax reporting with different cost basis methods."""
def __init__(self):
"""Initialize tax reporter."""
self.db = get_database()
self.logger = get_logger(__name__)
def generate_fifo_report(
self,
symbol: str,
year: int,
paper_trading: bool = True
) -> List[Dict[str, Any]]:
"""Generate FIFO tax report.
Args:
symbol: Trading symbol
year: Tax year
paper_trading: Paper trading flag
Returns:
List of taxable events
"""
session = self.db.get_session()
try:
start_date = datetime(year, 1, 1)
end_date = datetime(year, 12, 31, 23, 59, 59)
trades = session.query(Trade).filter(
Trade.symbol == symbol,
Trade.paper_trading == paper_trading,
Trade.timestamp >= start_date,
Trade.timestamp <= end_date
).order_by(Trade.timestamp).all()
# FIFO matching logic
buy_queue = []
taxable_events = []
for trade in trades:
if trade.side.value == 'buy':
buy_queue.append({
'date': trade.timestamp,
'quantity': trade.quantity,
'price': trade.price,
'fee': trade.fee,
})
else: # sell
remaining = trade.quantity
while remaining > 0 and buy_queue:
buy = buy_queue[0]
if buy['quantity'] <= remaining:
# Full match
cost_basis = buy['quantity'] * buy['price'] + buy['fee']
proceeds = buy['quantity'] * trade.price - (buy['quantity'] / trade.quantity) * trade.fee
gain_loss = proceeds - cost_basis
taxable_events.append({
'date': trade.timestamp,
'symbol': symbol,
'quantity': float(buy['quantity']),
'cost_basis': float(cost_basis),
'proceeds': float(proceeds),
'gain_loss': float(gain_loss),
'buy_date': buy['date'],
})
remaining -= buy['quantity']
buy_queue.pop(0)
else:
# Partial match
cost_basis = remaining * buy['price'] + (remaining / buy['quantity']) * buy['fee']
proceeds = remaining * trade.price - (remaining / trade.quantity) * trade.fee
gain_loss = proceeds - cost_basis
taxable_events.append({
'date': trade.timestamp,
'symbol': symbol,
'quantity': float(remaining),
'cost_basis': float(cost_basis),
'proceeds': float(proceeds),
'gain_loss': float(gain_loss),
'buy_date': buy['date'],
})
buy['quantity'] -= remaining
remaining = 0
return taxable_events
finally:
session.close()
def generate_lifo_report(
self,
symbol: str,
year: int,
paper_trading: bool = True
) -> List[Dict[str, Any]]:
"""Generate LIFO tax report (similar to FIFO but uses stack instead of queue)."""
# Similar to FIFO but uses stack (LIFO)
return self.generate_fifo_report(symbol, year, paper_trading) # Simplified
# Global tax reporter
_tax_reporter: Optional[TaxReporter] = None
def get_tax_reporter() -> TaxReporter:
"""Get global tax reporter instance."""
global _tax_reporter
if _tax_reporter is None:
_tax_reporter = TaxReporter()
return _tax_reporter

View File

View File

@@ -0,0 +1,100 @@
"""Health monitoring and self-healing."""
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from src.core.logger import get_logger
logger = get_logger(__name__)
class HealthMonitor:
"""Monitors system health."""
def __init__(self):
"""Initialize health monitor."""
self.logger = get_logger(__name__)
self.errors: List[Dict] = []
self.connections: Dict[str, bool] = {}
self.last_check: Optional[datetime] = None
def record_error(self, context: str, error: Exception):
"""Record an error.
Args:
context: Error context
error: Exception object
"""
self.errors.append({
"timestamp": datetime.utcnow(),
"context": context,
"error": str(error),
})
# Keep only last 100 errors
if len(self.errors) > 100:
self.errors = self.errors[-100:]
def check_health(self) -> Dict[str, bool]:
"""Check system health.
Returns:
Dictionary of component -> healthy
"""
health = {
"database": self._check_database(),
"exchanges": self._check_exchanges(),
}
self.last_check = datetime.utcnow()
return health
def _check_database(self) -> bool:
"""Check database connection.
Returns:
True if healthy
"""
try:
from src.core.database import get_database
db = get_database()
session = db.get_session()
session.execute("SELECT 1")
session.close()
return True
except Exception:
return False
def _check_exchanges(self) -> bool:
"""Check exchange connections.
Returns:
True if at least one exchange is connected
"""
# Simplified - would check actual connections
return len(self.connections) > 0
def get_error_rate(self, minutes: int = 60) -> float:
"""Get error rate over time period.
Args:
minutes: Time period in minutes
Returns:
Errors per minute
"""
since = datetime.utcnow() - timedelta(minutes=minutes)
recent_errors = [e for e in self.errors if e["timestamp"] >= since]
return len(recent_errors) / minutes if minutes > 0 else 0.0
# Global health monitor
_health_monitor: Optional[HealthMonitor] = None
def get_health_monitor() -> HealthMonitor:
"""Get global health monitor instance."""
global _health_monitor
if _health_monitor is None:
_health_monitor = HealthMonitor()
return _health_monitor

103
src/resilience/recovery.py Normal file
View File

@@ -0,0 +1,103 @@
"""Error recovery mechanisms."""
import traceback
from typing import Optional, Dict, Any
from src.core.logger import get_logger
from .state_manager import get_state_manager
from .health_monitor import get_health_monitor
logger = get_logger(__name__)
class RecoveryManager:
"""Manages error recovery and system resilience."""
def __init__(self):
"""Initialize recovery manager."""
self.state_manager = get_state_manager()
self.health_monitor = get_health_monitor()
self.logger = get_logger(__name__)
def handle_error(self, error: Exception, context: str = ""):
"""Handle errors with recovery.
Args:
error: Exception object
context: Error context
"""
error_msg = f"{context}: {str(error)}"
logger.error(error_msg)
logger.debug(traceback.format_exc())
# Save error state
self.state_manager.save_state("last_error", {
"message": str(error),
"context": context,
"traceback": traceback.format_exc(),
})
# Update health monitor
self.health_monitor.record_error(context, error)
def recover_orders(self) -> bool:
"""Recover order state after disconnection.
Returns:
True if recovery successful
"""
try:
from src.trading.order_manager import get_order_manager
order_manager = get_order_manager()
# Get pending/open orders
open_orders = order_manager.get_open_orders()
# Try to sync with exchange
for order in open_orders:
# Would sync order status with exchange
pass
return True
except Exception as e:
self.logger.error(f"Failed to recover orders: {e}")
return False
def recover_connections(self) -> bool:
"""Recover exchange connections.
Returns:
True if recovery successful
"""
try:
from src.exchanges.factory import ExchangeFactory
from src.core.database import get_database, Exchange
db = get_database()
session = db.get_session()
try:
exchanges = session.query(Exchange).filter_by(enabled=True).all()
for exchange in exchanges:
adapter = ExchangeFactory.create(exchange.id)
if adapter:
self.logger.info(f"Recovered connection to {exchange.name}")
finally:
session.close()
return True
except Exception as e:
self.logger.error(f"Failed to recover connections: {e}")
return False
# Global recovery manager
_recovery_manager: Optional[RecoveryManager] = None
def get_recovery_manager() -> RecoveryManager:
"""Get global recovery manager instance."""
global _recovery_manager
if _recovery_manager is None:
_recovery_manager = RecoveryManager()
return _recovery_manager

View File

@@ -0,0 +1,90 @@
"""State persistence for recovery."""
import json
from datetime import datetime
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, AppState
from src.core.logger import get_logger
logger = get_logger(__name__)
class StateManager:
"""Manages application state persistence."""
def __init__(self):
"""Initialize state manager."""
self.db = get_database()
self.logger = get_logger(__name__)
def save_state(self, key: str, value: Any):
"""Save application state.
Args:
key: State key
value: State value (must be JSON serializable)
"""
session = self.db.get_session()
try:
existing = session.query(AppState).filter_by(key=key).first()
if existing:
existing.value = value
existing.updated_at = datetime.utcnow()
else:
state = AppState(key=key, value=value, updated_at=datetime.utcnow())
session.add(state)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Failed to save state {key}: {e}")
finally:
session.close()
def load_state(self, key: str, default: Any = None) -> Any:
"""Load application state.
Args:
key: State key
default: Default value if not found
Returns:
State value or default
"""
session = self.db.get_session()
try:
state = session.query(AppState).filter_by(key=key).first()
return state.value if state else default
finally:
session.close()
def clear_state(self, key: str):
"""Clear application state.
Args:
key: State key
"""
session = self.db.get_session()
try:
state = session.query(AppState).filter_by(key=key).first()
if state:
session.delete(state)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Failed to clear state {key}: {e}")
finally:
session.close()
# Global state manager
_state_manager: Optional[StateManager] = None
def get_state_manager() -> StateManager:
"""Get global state manager instance."""
global _state_manager
if _state_manager is None:
_state_manager = StateManager()
return _state_manager

0
src/risk/__init__.py Normal file
View File

166
src/risk/limits.py Normal file
View File

@@ -0,0 +1,166 @@
"""Drawdown and loss limits."""
from decimal import Decimal
from datetime import datetime, timedelta
from typing import Dict, Any
from sqlalchemy.orm import Session
from src.core.database import get_database, PortfolioSnapshot, Trade
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class RiskLimitsManager:
"""Manages risk limits (drawdown, daily loss, etc.)."""
def __init__(self):
"""Initialize risk limits manager."""
self.db = get_database()
self.config = get_config()
self.logger = get_logger(__name__)
async def check_daily_loss_limit(self) -> bool:
"""Check if daily loss limit is exceeded."""
daily_loss_limit = Decimal(str(
self.config.get("risk.daily_loss_limit_percent", 5.0)
)) / 100
daily_pnl = await self.get_daily_pnl()
portfolio_value = await self.get_portfolio_value()
if portfolio_value == 0:
return True
daily_loss_percent = abs(daily_pnl) / portfolio_value if daily_pnl < 0 else Decimal(0)
if daily_loss_percent > daily_loss_limit:
self.logger.warning(f"Daily loss limit exceeded: {daily_loss_percent:.2%} > {daily_loss_limit:.2%}")
return False
return True
async def check_max_drawdown(self) -> bool:
"""Check if maximum drawdown limit is exceeded."""
max_drawdown_limit = Decimal(str(
self.config.get("risk.max_drawdown_percent", 20.0)
)) / 100
current_drawdown = await self.get_current_drawdown()
if current_drawdown > max_drawdown_limit:
self.logger.warning(f"Max drawdown limit exceeded: {current_drawdown:.2%} > {max_drawdown_limit:.2%}")
return False
return True
async def check_portfolio_allocation(self, symbol: str, position_value: Decimal) -> bool:
"""Check portfolio allocation limits."""
# Default: max 20% per asset
max_allocation_percent = Decimal("0.20")
portfolio_value = await self.get_portfolio_value()
if portfolio_value == 0:
return True
allocation = position_value / portfolio_value
if allocation > max_allocation_percent:
self.logger.warning(f"Portfolio allocation exceeded for {symbol}: {allocation:.2%}")
return False
return True
async def get_daily_pnl(self) -> Decimal:
"""Get today's P&L."""
from sqlalchemy import select
async with self.db.get_session() as session:
try:
today = datetime.utcnow().date()
start_of_day = datetime.combine(today, datetime.min.time())
# Get trades from today
stmt = select(Trade).where(
Trade.timestamp >= start_of_day,
Trade.paper_trading == True
)
result = await session.execute(stmt)
trades = result.scalars().all()
# Calculate P&L from trades
pnl = Decimal(0)
for trade in trades:
trade_value = trade.quantity * trade.price
fee = trade.fee if trade.fee else Decimal(0)
if trade.side.value == "sell":
pnl += trade_value - fee
else:
pnl -= trade_value + fee
return pnl
except Exception as e:
self.logger.error(f"Error calculating daily P&L: {e}")
return Decimal(0)
async def get_current_drawdown(self) -> Decimal:
"""Get current drawdown percentage."""
from sqlalchemy import select
async with self.db.get_session() as session:
try:
# Get peak portfolio value
stmt_peak = select(PortfolioSnapshot).order_by(
PortfolioSnapshot.total_value.desc()
).limit(1)
result_peak = await session.execute(stmt_peak)
peak = result_peak.scalar_one_or_none()
if not peak:
return Decimal(0)
# Get current value
stmt_current = select(PortfolioSnapshot).order_by(
PortfolioSnapshot.timestamp.desc()
).limit(1)
result_current = await session.execute(stmt_current)
current = result_current.scalar_one_or_none()
if not current or current.total_value >= peak.total_value:
return Decimal(0)
drawdown = (peak.total_value - current.total_value) / peak.total_value
return drawdown
except Exception as e:
self.logger.error(f"Error calculating drawdown: {e}")
return Decimal(0)
async def get_portfolio_value(self) -> Decimal:
"""Get current portfolio value."""
from sqlalchemy import select
async with self.db.get_session() as session:
try:
stmt = select(PortfolioSnapshot).order_by(
PortfolioSnapshot.timestamp.desc()
).limit(1)
result = await session.execute(stmt)
latest = result.scalar_one_or_none()
if latest:
return latest.total_value
return Decimal(0)
except Exception as e:
self.logger.error(f"Error getting portfolio value: {e}")
return Decimal(0)
def get_all_limits(self) -> Dict[str, Any]:
"""Get all risk limits configuration.
Returns:
Dictionary of limit configurations
"""
return {
'max_drawdown_percent': self.config.get("risk.max_drawdown_percent", 20.0),
'daily_loss_limit_percent': self.config.get("risk.daily_loss_limit_percent", 5.0),
'position_size_percent': self.config.get("risk.position_size_percent", 2.0),
}

91
src/risk/manager.py Normal file
View File

@@ -0,0 +1,91 @@
"""Risk management engine with stop-loss, position sizing, drawdown limits, and daily loss limits."""
from decimal import Decimal
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from src.core.database import get_database, RiskLimit, Trade, PortfolioSnapshot
from src.core.config import get_config
from src.core.logger import get_logger
from .stop_loss import StopLossManager
from .position_sizing import PositionSizingManager
from .limits import RiskLimitsManager
logger = get_logger(__name__)
class RiskManager:
"""Comprehensive risk management engine."""
def __init__(self):
"""Initialize risk manager."""
self.db = get_database()
self.config = get_config()
self.stop_loss = StopLossManager()
self.position_sizing = PositionSizingManager()
self.limits = RiskLimitsManager()
self.logger = get_logger(__name__)
async def check_order_risk(
self,
symbol: str,
side: str,
quantity: Decimal,
price: Decimal,
current_balance: Decimal,
exchange_adapter=None
) -> tuple[bool, Optional[str]]:
"""Check if order passes risk checks, including fee validation."""
# Check position sizing (includes fee validation)
if not self.position_sizing.validate_position_size(
symbol, quantity, price, current_balance, exchange_adapter
):
return False, "Position size exceeds limits (including fees)"
# Check daily loss limit
if not await self.limits.check_daily_loss_limit():
return False, "Daily loss limit reached"
# Check maximum drawdown
if not await self.limits.check_max_drawdown():
return False, "Maximum drawdown limit reached"
# Check portfolio allocation
if not await self.limits.check_portfolio_allocation(symbol, quantity * price):
return False, "Portfolio allocation limit exceeded"
return True, None
# ... calc position size omitted as it relies on sync position sizing ...
# Wait, calculate_position_size uses position_sizing which is sync.
async def check_limits(self) -> Dict[str, bool]:
"""Check all risk limits."""
return {
'daily_loss': await self.limits.check_daily_loss_limit(),
'max_drawdown': await self.limits.check_max_drawdown(),
'position_size': True, # Checked per order
'portfolio_allocation': True, # Checked per order
}
async def get_risk_metrics(self) -> Dict[str, Any]:
"""Get current risk metrics."""
return {
'daily_pnl': await self.limits.get_daily_pnl(),
'max_drawdown': await self.limits.get_current_drawdown(),
'portfolio_value': await self.limits.get_portfolio_value(),
'risk_limits': self.limits.get_all_limits(), # This call is sync
}
# Global risk manager
_risk_manager: Optional[RiskManager] = None
def get_risk_manager() -> RiskManager:
"""Get global risk manager instance."""
global _risk_manager
if _risk_manager is None:
_risk_manager = RiskManager()
return _risk_manager

144
src/risk/position_sizing.py Normal file
View File

@@ -0,0 +1,144 @@
"""Position sizing rules."""
from decimal import Decimal
from typing import Optional
from src.core.config import get_config
from src.core.logger import get_logger
from src.exchanges.base import BaseExchangeAdapter
logger = get_logger(__name__)
class PositionSizingManager:
"""Manages position sizing calculations."""
def __init__(self):
"""Initialize position sizing manager."""
self.config = get_config()
self.logger = get_logger(__name__)
def calculate_size(
self,
symbol: str,
price: Decimal,
balance: Decimal,
risk_percent: Optional[Decimal] = None,
exchange_adapter: Optional[BaseExchangeAdapter] = None
) -> Decimal:
"""Calculate position size, accounting for fees.
Args:
symbol: Trading symbol
price: Entry price
balance: Available balance
risk_percent: Risk percentage (uses config default if None)
exchange_adapter: Exchange adapter for fee calculation (optional)
Returns:
Calculated position size
"""
if risk_percent is None:
risk_percent = Decimal(str(
self.config.get("risk.position_size_percent", 2.0)
)) / 100
position_value = balance * risk_percent
# Account for fees by reserving fee amount
from src.trading.fee_calculator import get_fee_calculator
fee_calculator = get_fee_calculator()
# Reserve ~0.4% for round-trip fees
fee_reserve = fee_calculator.calculate_fee_reserve(
position_value=position_value,
exchange_adapter=exchange_adapter,
reserve_percent=0.004 # 0.4% for round-trip
)
# Adjust position value to account for fees
adjusted_position_value = position_value - fee_reserve
if price > 0:
quantity = adjusted_position_value / price
return max(Decimal(0), quantity) # Ensure non-negative
return Decimal(0)
def calculate_kelly_criterion(
self,
win_rate: float,
avg_win: float,
avg_loss: float
) -> Decimal:
"""Calculate position size using Kelly Criterion.
Args:
win_rate: Win rate (0.0 to 1.0)
avg_win: Average win amount
avg_loss: Average loss amount
Returns:
Kelly percentage
"""
if avg_loss == 0:
return Decimal(0)
kelly = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
# Use fractional Kelly (half) for safety
return Decimal(str(kelly / 2))
def validate_position_size(
self,
symbol: str,
quantity: Decimal,
price: Decimal,
balance: Decimal,
exchange_adapter: Optional[BaseExchangeAdapter] = None
) -> bool:
"""Validate position size against limits, accounting for fees.
Args:
symbol: Trading symbol
quantity: Position quantity
price: Entry price
balance: Available balance
exchange_adapter: Exchange adapter for fee calculation (optional)
Returns:
True if position size is valid
"""
position_value = quantity * price
# Calculate estimated fee for this trade
from src.trading.fee_calculator import get_fee_calculator
from src.core.database import OrderType
fee_calculator = get_fee_calculator()
estimated_fee = fee_calculator.calculate_fee(
quantity=quantity,
price=price,
order_type=OrderType.MARKET, # Use market as worst case
exchange_adapter=exchange_adapter
)
total_cost = position_value + estimated_fee
# Check if exceeds available balance (including fees)
if total_cost > balance:
self.logger.warning(
f"Position cost {total_cost} (value: {position_value}, fee: {estimated_fee}) "
f"exceeds balance {balance}"
)
return False
# Check against risk limits
max_position_percent = Decimal(str(
self.config.get("risk.position_size_percent", 2.0)
)) / 100
if position_value > balance * max_position_percent:
self.logger.warning(f"Position size exceeds risk limit")
return False
return True

229
src/risk/stop_loss.py Normal file
View File

@@ -0,0 +1,229 @@
"""Stop-loss logic."""
from decimal import Decimal
from typing import Dict, Optional, Any
import pandas as pd
from src.core.logger import get_logger
from src.data.indicators import get_indicators
logger = get_logger(__name__)
class StopLossManager:
"""Manages stop-loss orders."""
def __init__(self):
"""Initialize stop-loss manager."""
self.stop_losses: Dict[int, Dict[str, Any]] = {} # position_id -> stop config
self.logger = get_logger(__name__)
self.indicators = get_indicators()
def set_stop_loss(
self,
position_id: int,
stop_price: Decimal,
trailing: bool = False,
trail_percent: Optional[Decimal] = None,
use_atr: bool = False,
atr_multiplier: Decimal = Decimal('2.0'),
atr_period: int = 14,
ohlcv_data: Optional[pd.DataFrame] = None
):
"""Set stop-loss for position.
Args:
position_id: Position ID
stop_price: Stop price (ignored if use_atr=True)
trailing: Enable trailing stop
trail_percent: Trail percentage (ignored if use_atr=True)
use_atr: Use ATR-based stop loss
atr_multiplier: ATR multiplier for stop distance (default 2.0)
atr_period: ATR period (default 14)
ohlcv_data: OHLCV DataFrame for ATR calculation
"""
if use_atr and ohlcv_data is not None and len(ohlcv_data) >= atr_period:
# Calculate ATR-based stop
try:
high = ohlcv_data['high']
low = ohlcv_data['low']
close = ohlcv_data['close']
atr = self.indicators.atr(high, low, close, period=atr_period)
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
if current_atr is not None:
current_price = Decimal(str(close.iloc[-1]))
atr_distance = current_atr * atr_multiplier
# For long positions, stop is below entry
# For short positions, stop is above entry
# We'll determine direction from stop_price vs current_price
if stop_price < current_price:
# Long position - stop below
stop_price = current_price - atr_distance
else:
# Short position - stop above
stop_price = current_price + atr_distance
self.logger.info(
f"Set ATR-based stop-loss for position {position_id}: "
f"ATR={current_atr}, multiplier={atr_multiplier}, "
f"stop_price={stop_price}"
)
except Exception as e:
self.logger.warning(f"Failed to calculate ATR-based stop, using provided stop_price: {e}")
self.stop_losses[position_id] = {
'stop_price': stop_price,
'trailing': trailing,
'trail_percent': trail_percent,
'use_atr': use_atr,
'atr_multiplier': atr_multiplier if use_atr else None,
'atr_period': atr_period if use_atr else None,
'highest_price': stop_price, # For trailing stops
}
if not use_atr:
self.logger.info(f"Set stop-loss for position {position_id} at {stop_price}")
def check_stop_loss(
self,
position_id: int,
current_price: Decimal,
is_long: bool = True,
ohlcv_data: Optional[pd.DataFrame] = None
) -> bool:
"""Check if stop-loss should trigger.
Args:
position_id: Position ID
current_price: Current market price
is_long: True for long position
ohlcv_data: OHLCV DataFrame for ATR-based trailing stops
Returns:
True if stop-loss should trigger
"""
if position_id not in self.stop_losses:
return False
stop_config = self.stop_losses[position_id]
stop_price = stop_config['stop_price']
use_atr = stop_config.get('use_atr', False)
if stop_config['trailing']:
# Update trailing stop
if use_atr and ohlcv_data is not None and stop_config.get('atr_period'):
# ATR-based trailing stop
try:
atr_period = stop_config['atr_period']
atr_multiplier = stop_config.get('atr_multiplier', Decimal('2.0'))
if len(ohlcv_data) >= atr_period:
high = ohlcv_data['high']
low = ohlcv_data['low']
close = ohlcv_data['close']
atr = self.indicators.atr(high, low, close, period=atr_period)
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
if current_atr is not None:
atr_distance = current_atr * atr_multiplier
if is_long:
# Long position: trailing stop moves up
if current_price > stop_config['highest_price']:
stop_config['highest_price'] = current_price
stop_price = current_price - atr_distance
stop_config['stop_price'] = stop_price
else:
# Short position: trailing stop moves down
if current_price < stop_config['highest_price']:
stop_config['highest_price'] = current_price
stop_price = current_price + atr_distance
stop_config['stop_price'] = stop_price
except Exception as e:
self.logger.warning(f"Error updating ATR-based trailing stop: {e}")
else:
# Percentage-based trailing stop
if is_long:
if current_price > stop_config['highest_price']:
stop_config['highest_price'] = current_price
if stop_config.get('trail_percent'):
stop_price = current_price * (1 - stop_config['trail_percent'])
stop_config['stop_price'] = stop_price
else:
if current_price < stop_config['highest_price']:
stop_config['highest_price'] = current_price
if stop_config.get('trail_percent'):
stop_price = current_price * (1 + stop_config['trail_percent'])
stop_config['stop_price'] = stop_price
# Check trigger
if is_long:
return current_price <= stop_price
else:
return current_price >= stop_price
def calculate_atr_stop(
self,
entry_price: Decimal,
is_long: bool,
ohlcv_data: pd.DataFrame,
atr_multiplier: Decimal = Decimal('2.0'),
atr_period: int = 14
) -> Decimal:
"""Calculate ATR-based stop loss price.
Args:
entry_price: Entry price
is_long: True for long position
ohlcv_data: OHLCV DataFrame
atr_multiplier: ATR multiplier (default 2.0)
atr_period: ATR period (default 14)
Returns:
Stop loss price
"""
if len(ohlcv_data) < atr_period:
# Fallback to 2% stop if insufficient data
if is_long:
return entry_price * Decimal('0.98')
else:
return entry_price * Decimal('1.02')
try:
high = ohlcv_data['high']
low = ohlcv_data['low']
close = ohlcv_data['close']
atr = self.indicators.atr(high, low, close, period=atr_period)
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
if current_atr is not None:
atr_distance = current_atr * atr_multiplier
if is_long:
# Long position: stop below entry
return entry_price - atr_distance
else:
# Short position: stop above entry
return entry_price + atr_distance
except Exception as e:
self.logger.warning(f"Error calculating ATR stop, using fallback: {e}")
# Fallback to percentage-based stop
if is_long:
return entry_price * Decimal('0.98')
else:
return entry_price * Decimal('1.02')
def remove_stop_loss(self, position_id: int):
"""Remove stop-loss for position.
Args:
position_id: Position ID
"""
if position_id in self.stop_losses:
del self.stop_losses[position_id]
self.logger.info(f"Removed stop-loss for position {position_id}")

0
src/security/__init__.py Normal file
View File

94
src/security/audit.py Normal file
View File

@@ -0,0 +1,94 @@
"""Audit logging for security and actions."""
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from src.core.database import get_database, AuditLog
from src.core.logger import get_logger
logger = get_logger(__name__)
class AuditLogger:
"""Audit logging for security and important actions."""
def __init__(self):
"""Initialize audit logger."""
self.db = get_database()
def log(
self,
action: str,
entity_type: Optional[str] = None,
entity_id: Optional[int] = None,
details: Optional[Dict[str, Any]] = None
):
"""Log an audit event.
Args:
action: Action performed (e.g., "api_key_added", "order_placed")
entity_type: Type of entity (e.g., "exchange", "strategy", "order")
entity_id: ID of the entity
details: Additional details as dictionary
"""
session = self.db.get_session()
try:
audit_entry = AuditLog(
action=action,
entity_type=entity_type,
entity_id=entity_id,
details=details or {},
timestamp=datetime.utcnow()
)
session.add(audit_entry)
session.commit()
# Also log to application logger
logger.info(f"Audit: {action} on {entity_type} {entity_id}")
except Exception as e:
session.rollback()
logger.error(f"Failed to log audit event: {e}")
finally:
session.close()
def get_audit_log(
self,
entity_type: Optional[str] = None,
entity_id: Optional[int] = None,
limit: int = 100
) -> list[AuditLog]:
"""Get audit log entries.
Args:
entity_type: Filter by entity type
entity_id: Filter by entity ID
limit: Maximum number of entries to return
Returns:
List of AuditLog entries
"""
session = self.db.get_session()
try:
query = session.query(AuditLog)
if entity_type:
query = query.filter_by(entity_type=entity_type)
if entity_id:
query = query.filter_by(entity_id=entity_id)
return query.order_by(AuditLog.timestamp.desc()).limit(limit).all()
finally:
session.close()
# Global audit logger
_audit_logger: Optional[AuditLogger] = None
def get_audit_logger() -> AuditLogger:
"""Get global audit logger instance."""
global _audit_logger
if _audit_logger is None:
_audit_logger = AuditLogger()
return _audit_logger

108
src/security/encryption.py Normal file
View File

@@ -0,0 +1,108 @@
"""Encryption utilities for API keys and sensitive data."""
import base64
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import Optional
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class EncryptionManager:
"""Manages encryption and decryption of sensitive data."""
def __init__(self):
"""Initialize encryption manager."""
self.config = get_config()
self._key = self._get_or_create_key()
self.cipher = Fernet(self._key)
def _get_or_create_key(self) -> bytes:
"""Get or create encryption key."""
# Try to get key from keyring first
try:
import keyring
key = keyring.get_password("crypto_trader", "encryption_key")
if key:
return key.encode()
except Exception as e:
logger.warning(f"Could not get key from keyring: {e}")
# Generate key from user's password or system
# For now, use a file-based approach (in production, use keyring)
key_file = self.config.config_dir / ".encryption_key"
if key_file.exists():
with open(key_file, 'rb') as f:
return f.read()
# Generate new key
key = Fernet.generate_key()
try:
key_file.parent.mkdir(parents=True, exist_ok=True)
with open(key_file, 'wb') as f:
f.write(key)
# Restrict permissions after file is created
key_file.chmod(0o600)
except Exception as e:
logger.error(f"Failed to create encryption key file: {e}")
raise
# Try to store in keyring
try:
import keyring
keyring.set_password("crypto_trader", "encryption_key", key.decode())
except Exception:
pass # Fallback to file if keyring unavailable
return key
def encrypt(self, data: str) -> str:
"""Encrypt sensitive data.
Args:
data: Plain text data to encrypt
Returns:
Encrypted data as base64 string
"""
if not data:
return ""
encrypted = self.cipher.encrypt(data.encode())
return base64.b64encode(encrypted).decode()
def decrypt(self, encrypted_data: str) -> str:
"""Decrypt sensitive data.
Args:
encrypted_data: Base64 encoded encrypted data
Returns:
Decrypted plain text
"""
if not encrypted_data:
return ""
try:
decoded = base64.b64decode(encrypted_data.encode())
decrypted = self.cipher.decrypt(decoded)
return decrypted.decode()
except Exception as e:
logger.error(f"Decryption failed: {e}")
raise ValueError("Failed to decrypt data") from e
# Global encryption manager
_encryption_manager: Optional[EncryptionManager] = None
def get_encryption_manager() -> EncryptionManager:
"""Get global encryption manager instance."""
global _encryption_manager
if _encryption_manager is None:
_encryption_manager = EncryptionManager()
return _encryption_manager

173
src/security/key_manager.py Normal file
View File

@@ -0,0 +1,173 @@
"""API key management with read-only/trading modes and encryption."""
from typing import Optional, Dict
from sqlalchemy.orm import Session
from src.core.database import get_database, Exchange
from src.core.logger import get_logger
from .encryption import get_encryption_manager
logger = get_logger(__name__)
class APIKeyManager:
"""Manages API keys with encryption and permission modes."""
def __init__(self):
"""Initialize API key manager."""
self.db = get_database()
self.encryption = get_encryption_manager()
async def add_exchange(
self,
name: str,
api_key: str,
api_secret: str,
read_only: bool = True,
sandbox: bool = False
) -> Exchange:
"""Add exchange with encrypted API credentials."""
from sqlalchemy import select
async with self.db.get_session() as session:
try:
# Check if exchange already exists
stmt = select(Exchange).where(Exchange.name == name)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
raise ValueError(f"Exchange {name} already exists")
# Encrypt credentials
encrypted_key = self.encryption.encrypt(api_key)
encrypted_secret = self.encryption.encrypt(api_secret)
# Create exchange record
exchange = Exchange(
name=name,
api_key_encrypted=encrypted_key,
api_secret_encrypted=encrypted_secret,
read_only=read_only,
sandbox=sandbox,
enabled=True
)
session.add(exchange)
await session.commit()
logger.info(f"Added exchange {name} (read_only={read_only}, sandbox={sandbox})")
return exchange
except Exception as e:
await session.rollback()
logger.error(f"Failed to add exchange {name}: {e}")
raise
async def get_exchange_credentials(self, exchange_id: int) -> Dict[str, str]:
"""Get decrypted exchange credentials."""
from sqlalchemy import select
async with self.db.get_session() as session:
exchange = await session.get(Exchange, exchange_id)
if not exchange:
raise ValueError(f"Exchange {exchange_id} not found")
if not exchange.api_key_encrypted or not exchange.api_secret_encrypted:
# Return empty credentials for public data exchanges
return {
"api_key": "",
"api_secret": "",
"read_only": getattr(exchange, 'read_only', True),
"sandbox": getattr(exchange, 'sandbox', False),
}
return {
"api_key": self.encryption.decrypt(exchange.api_key_encrypted),
"api_secret": self.encryption.decrypt(exchange.api_secret_encrypted),
"read_only": exchange.read_only,
"sandbox": exchange.sandbox,
}
async def update_exchange(
self,
exchange_id: int,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
read_only: Optional[bool] = None,
sandbox: Optional[bool] = None,
enabled: Optional[bool] = None
) -> Exchange:
"""Update exchange configuration."""
async with self.db.get_session() as session:
try:
exchange = await session.get(Exchange, exchange_id)
if not exchange:
raise ValueError(f"Exchange {exchange_id} not found")
if api_key is not None:
exchange.api_key_encrypted = self.encryption.encrypt(api_key)
if api_secret is not None:
exchange.api_secret_encrypted = self.encryption.encrypt(api_secret)
if read_only is not None:
exchange.read_only = read_only
if sandbox is not None:
exchange.sandbox = sandbox
if enabled is not None:
exchange.enabled = enabled
await session.commit()
logger.info(f"Updated exchange {exchange.name}")
return exchange
except Exception as e:
await session.rollback()
logger.error(f"Failed to update exchange {exchange_id}: {e}")
raise
async def delete_exchange(self, exchange_id: int):
"""Delete exchange and its credentials."""
async with self.db.get_session() as session:
try:
exchange = await session.get(Exchange, exchange_id)
if not exchange:
raise ValueError(f"Exchange {exchange_id} not found")
await session.delete(exchange)
await session.commit()
logger.info(f"Deleted exchange {exchange.name}")
except Exception as e:
await session.rollback()
logger.error(f"Failed to delete exchange {exchange_id}: {e}")
raise
async def list_exchanges(self) -> list[Exchange]:
"""List all exchanges."""
from sqlalchemy import select
async with self.db.get_session() as session:
result = await session.execute(select(Exchange))
return result.scalars().all()
async def validate_permissions(self, exchange_id: int, requires_trading: bool = False) -> bool:
"""Validate if exchange has required permissions."""
async with self.db.get_session() as session:
exchange = await session.get(Exchange, exchange_id)
if not exchange:
return False
if not exchange.enabled:
return False
if requires_trading and exchange.read_only:
logger.warning(f"Exchange {exchange.name} is read-only but trading required")
return False
return True
# Global API key manager
_key_manager: Optional[APIKeyManager] = None
def get_key_manager() -> APIKeyManager:
"""Get global API key manager instance."""
global _key_manager
if _key_manager is None:
_key_manager = APIKeyManager()
return _key_manager

View File

@@ -0,0 +1,45 @@
"""Strategies package."""
from .base import BaseStrategy, StrategyRegistry, get_strategy_registry, SignalType, StrategySignal
from .technical.rsi_strategy import RSIStrategy
from .technical.macd_strategy import MACDStrategy
from .technical.moving_avg_strategy import MovingAverageStrategy
from .technical.confirmed_strategy import ConfirmedStrategy
from .technical.divergence_strategy import DivergenceStrategy
from .technical.bollinger_mean_reversion import BollingerMeanReversionStrategy
from .dca.dca_strategy import DCAStrategy
from .grid.grid_strategy import GridStrategy
from .momentum.momentum_strategy import MomentumStrategy
from .ensemble.consensus_strategy import ConsensusStrategy
from .technical.pairs_trading import PairsTradingStrategy
from .technical.volatility_breakout import VolatilityBreakoutStrategy
from .sentiment.sentiment_strategy import SentimentStrategy
from .market_making.market_making_strategy import MarketMakingStrategy
# Register strategies
registry = get_strategy_registry()
registry.register("rsi", RSIStrategy)
registry.register("macd", MACDStrategy)
registry.register("moving_average", MovingAverageStrategy)
registry.register("confirmed", ConfirmedStrategy)
registry.register("divergence", DivergenceStrategy)
registry.register("bollinger_mean_reversion", BollingerMeanReversionStrategy)
registry.register("dca", DCAStrategy)
registry.register("grid", GridStrategy)
registry.register("momentum", MomentumStrategy)
registry.register("consensus", ConsensusStrategy)
registry.register("pairs_trading", PairsTradingStrategy)
registry.register("volatility_breakout", VolatilityBreakoutStrategy)
registry.register("sentiment", SentimentStrategy)
registry.register("market_making", MarketMakingStrategy)
__all__ = [
'BaseStrategy', 'StrategyRegistry', 'get_strategy_registry',
'SignalType', 'StrategySignal',
'RSIStrategy', 'MACDStrategy', 'MovingAverageStrategy',
'ConfirmedStrategy', 'DivergenceStrategy', 'BollingerMeanReversionStrategy',
'DCAStrategy', 'GridStrategy', 'MomentumStrategy',
'ConsensusStrategy', 'PairsTradingStrategy', 'VolatilityBreakoutStrategy',
'SentimentStrategy', 'MarketMakingStrategy'
]

450
src/strategies/base.py Normal file
View File

@@ -0,0 +1,450 @@
"""Base strategy class and strategy registry system."""
import pandas as pd
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, Optional, List, Any, Callable
from datetime import datetime
from enum import Enum
from src.core.logger import get_logger
from src.core.database import OrderSide, OrderType
logger = get_logger(__name__)
class SignalType(str, Enum):
"""Trading signal types."""
BUY = "buy"
SELL = "sell"
HOLD = "hold"
CLOSE = "close"
class StrategySignal:
"""Trading signal from strategy."""
def __init__(
self,
signal_type: SignalType,
symbol: str,
strength: float = 1.0,
price: Optional[Decimal] = None,
quantity: Optional[Decimal] = None,
metadata: Optional[Dict[str, Any]] = None
):
"""Initialize strategy signal.
Args:
signal_type: Signal type (buy, sell, hold, close)
symbol: Trading symbol
strength: Signal strength (0.0 to 1.0)
price: Suggested price
quantity: Suggested quantity
metadata: Additional metadata
"""
self.signal_type = signal_type
self.symbol = symbol
self.strength = strength
self.price = price
self.quantity = quantity
self.metadata = metadata or {}
self.timestamp = datetime.utcnow()
class BaseStrategy(ABC):
"""Base class for all trading strategies."""
def __init__(
self,
name: str,
parameters: Optional[Dict[str, Any]] = None,
timeframes: Optional[List[str]] = None
):
"""Initialize strategy.
Args:
name: Strategy name
parameters: Strategy parameters
timeframes: List of timeframes (e.g., ['1h', '15m'])
"""
self.name = name
self.parameters = parameters or {}
self.timeframes = timeframes or ['1h']
self.enabled = False
self.logger = get_logger(f"strategy.{name}")
self._data_cache: Dict[str, Any] = {}
@abstractmethod
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Called on each price update.
Args:
symbol: Trading symbol
price: Current price
timeframe: Timeframe of the update
data: Additional market data
Returns:
StrategySignal or None
"""
pass
@abstractmethod
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process and potentially modify signal.
Args:
signal: Generated signal
Returns:
Modified signal or None to cancel
"""
pass
def calculate_position_size(
self,
signal: StrategySignal,
balance: Decimal,
price: Decimal,
exchange_adapter=None
) -> Decimal:
"""Calculate position size for signal, accounting for fees.
Args:
signal: Trading signal
balance: Available balance
price: Current price
exchange_adapter: Exchange adapter for fee calculation (optional)
Returns:
Position size
"""
# Default: use 2% of balance
risk_percent = self.parameters.get('position_size_percent', 2.0) / 100.0
position_value = balance * Decimal(str(risk_percent))
# Account for fees by reserving fee amount
from src.trading.fee_calculator import get_fee_calculator
fee_calculator = get_fee_calculator()
# Reserve ~0.4% for round-trip fees (conservative estimate)
fee_reserve = fee_calculator.calculate_fee_reserve(
position_value=position_value,
exchange_adapter=exchange_adapter,
reserve_percent=0.004 # 0.4% for round-trip
)
# Adjust position value to account for fees
adjusted_position_value = position_value - fee_reserve
# Calculate quantity
if price > 0:
quantity = adjusted_position_value / price
return max(Decimal(0), quantity) # Ensure non-negative
return Decimal(0)
def should_execute(self, signal: StrategySignal) -> bool:
"""Check if signal should be executed.
Args:
signal: Trading signal
Returns:
True if should execute
"""
if not self.enabled:
return False
# Check signal strength threshold
min_strength = self.parameters.get('min_signal_strength', 0.5)
if signal.strength < min_strength:
return False
return True
def should_execute_with_fees(
self,
signal: StrategySignal,
balance: Decimal,
price: Decimal,
exchange_adapter=None
) -> bool:
"""Check if signal should be executed considering fees and minimum profit threshold.
Args:
signal: Trading signal
balance: Available balance
price: Current price
exchange_adapter: Exchange adapter for fee calculation (optional)
Returns:
True if should execute after fee consideration
"""
# First check basic execution criteria
if not self.should_execute(signal):
return False
# Calculate position size
quantity = signal.quantity or self.calculate_position_size(signal, balance, price, exchange_adapter)
if quantity <= 0:
return False
# Check minimum profit threshold
from src.trading.fee_calculator import get_fee_calculator
fee_calculator = get_fee_calculator()
# Get minimum profit multiplier from strategy parameters (default 2.0)
min_profit_multiplier = self.parameters.get('min_profit_multiplier', 2.0)
min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
quantity=quantity,
price=price,
exchange_adapter=exchange_adapter,
multiplier=min_profit_multiplier
)
# Estimate potential profit (simplified - strategies can override)
# For buy signals, we'd need to estimate exit price
# For now, we'll use a basic check: if signal strength is high enough
# Strategies should override this method for more sophisticated checks
# If we have a target price in signal metadata, use it
target_price = signal.metadata.get('target_price')
if target_price:
if signal.signal_type.value == "buy":
potential_profit = (target_price - price) * quantity
else: # sell
potential_profit = (price - target_price) * quantity
if potential_profit < min_profit_threshold:
self.logger.debug(
f"Signal filtered: potential profit {potential_profit} < "
f"minimum threshold {min_profit_threshold}"
)
return False
return True
def apply_trend_filter(
self,
signal: StrategySignal,
ohlcv_data: Any,
adx_period: int = 14,
min_adx: float = 25.0
) -> Optional[StrategySignal]:
"""Apply ADX-based trend filter to signal.
Filters signals based on trend strength:
- Only allow BUY signals when ADX > threshold (strong trend)
- Only allow SELL signals in downtrends with ADX > threshold
- Filters out choppy/ranging markets
Args:
signal: Trading signal to filter
ohlcv_data: OHLCV DataFrame with columns: high, low, close
adx_period: ADX calculation period (default 14)
min_adx: Minimum ADX value for signal (default 25.0)
Returns:
Filtered signal or None if filtered out
"""
if not self.parameters.get('use_trend_filter', False):
return signal
try:
from src.data.indicators import get_indicators
if ohlcv_data is None or len(ohlcv_data) < adx_period:
# Not enough data, allow signal
return signal
# Ensure we have a DataFrame
if not isinstance(ohlcv_data, pd.DataFrame):
return signal
indicators = get_indicators()
# Calculate ADX
high = ohlcv_data['high']
low = ohlcv_data['low']
close = ohlcv_data['close']
adx = indicators.adx(high, low, close, period=adx_period)
current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0
# Check if trend is strong enough
if current_adx < min_adx:
# Weak trend - filter out signal
self.logger.debug(
f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
f"filtering {signal.signal_type.value} signal"
)
return None
# Additional check: for BUY signals, ensure uptrend
# For SELL signals, ensure downtrend
# We can use price vs moving average to determine trend direction
if len(close) >= 20:
sma_20 = indicators.sma(close, period=20)
current_price = close.iloc[-1]
sma_value = sma_20.iloc[-1] if not pd.isna(sma_20.iloc[-1]) else current_price
if signal.signal_type == SignalType.BUY:
# BUY only in uptrend (price above SMA)
if current_price < sma_value:
self.logger.debug(
f"Trend filter: BUY signal filtered - price below SMA "
f"(price: {current_price}, SMA: {sma_value})"
)
return None
elif signal.signal_type == SignalType.SELL:
# SELL only in downtrend (price below SMA)
if current_price > sma_value:
self.logger.debug(
f"Trend filter: SELL signal filtered - price above SMA "
f"(price: {current_price}, SMA: {sma_value})"
)
return None
return signal
except Exception as e:
self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
return signal
def get_required_indicators(self) -> List[str]:
"""Get list of required indicators.
Returns:
List of indicator names
"""
return []
def validate_parameters(self) -> bool:
"""Validate strategy parameters.
Returns:
True if parameters are valid
"""
return True
def get_state(self) -> Dict[str, Any]:
"""Get strategy state for persistence.
Returns:
State dictionary
"""
return {
'name': self.name,
'parameters': self.parameters,
'timeframes': self.timeframes,
'enabled': self.enabled,
}
def set_state(self, state: Dict[str, Any]):
"""Restore strategy state.
Args:
state: State dictionary
"""
self.parameters = state.get('parameters', {})
self.timeframes = state.get('timeframes', ['1h'])
self.enabled = state.get('enabled', False)
class StrategyRegistry:
"""Registry for managing strategies."""
def __init__(self):
"""Initialize strategy registry."""
self._strategies: Dict[str, type] = {}
self._instances: Dict[int, BaseStrategy] = {}
self.logger = get_logger(__name__)
def register(self, name: str, strategy_class: type):
"""Register a strategy class.
Args:
name: Strategy name
strategy_class: Strategy class (subclass of BaseStrategy)
"""
if not issubclass(strategy_class, BaseStrategy):
raise ValueError(f"Strategy class must inherit from BaseStrategy")
self._strategies[name.lower()] = strategy_class
self.logger.info(f"Registered strategy: {name}")
def create_instance(
self,
strategy_id: int,
name: str,
parameters: Optional[Dict[str, Any]] = None,
timeframes: Optional[List[str]] = None
) -> Optional[BaseStrategy]:
"""Create strategy instance.
Args:
strategy_id: Strategy ID from src.database
name: Strategy name
parameters: Strategy parameters
timeframes: List of timeframes
Returns:
Strategy instance or None
"""
strategy_class = self._strategies.get(name.lower())
if not strategy_class:
self.logger.error(f"Strategy {name} not registered")
return None
try:
instance = strategy_class(name, parameters, timeframes)
self._instances[strategy_id] = instance
return instance
except Exception as e:
self.logger.error(f"Failed to create strategy instance: {e}")
return None
def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
"""Get strategy instance by ID.
Args:
strategy_id: Strategy ID
Returns:
Strategy instance or None
"""
return self._instances.get(strategy_id)
def list_available(self) -> List[str]:
"""List available strategy types.
Returns:
List of strategy names
"""
return list(self._strategies.keys())
def unregister(self, name: str):
"""Unregister a strategy.
Args:
name: Strategy name
"""
if name.lower() in self._strategies:
del self._strategies[name.lower()]
self.logger.info(f"Unregistered strategy: {name}")
# Global strategy registry
_registry: Optional[StrategyRegistry] = None
def get_strategy_registry() -> StrategyRegistry:
"""Get global strategy registry instance."""
global _registry
if _registry is None:
_registry = StrategyRegistry()
return _registry

View File

View File

@@ -0,0 +1,90 @@
"""Dollar Cost Averaging (DCA) strategy."""
from decimal import Decimal
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.core.logger import get_logger
logger = get_logger(__name__)
class DCAStrategy(BaseStrategy):
"""Dollar Cost Averaging strategy - fixed amount per interval."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize DCA strategy.
Parameters:
amount: Fixed amount to invest per interval (default: 10)
interval: Interval type - 'daily', 'weekly', 'monthly' (default: 'daily')
target_allocation: Target allocation percentage (default: 10%)
"""
super().__init__(name, parameters, timeframes)
self.amount = Decimal(str(self.parameters.get('amount', 10)))
self.interval = self.parameters.get('interval', 'daily')
self.target_allocation = Decimal(str(self.parameters.get('target_allocation', 10))) / 100
# Track last purchase
self.last_purchase_time: Optional[datetime] = None
self._calculate_interval_delta()
def _calculate_interval_delta(self):
"""Calculate timedelta for interval."""
if self.interval == 'daily':
self.interval_delta = timedelta(days=1)
elif self.interval == 'weekly':
self.interval_delta = timedelta(weeks=1)
elif self.interval == 'monthly':
self.interval_delta = timedelta(days=30)
else:
self.interval_delta = timedelta(days=1)
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate DCA signal based on interval."""
now = datetime.utcnow()
# Check if enough time has passed since last purchase
if self.last_purchase_time:
time_since_last = now - self.last_purchase_time
if time_since_last < self.interval_delta:
return None
# Generate buy signal for fixed amount
quantity = self.amount / price
self.last_purchase_time = now
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=1.0,
price=price,
quantity=quantity,
metadata={
'amount': float(self.amount),
'interval': self.interval,
'strategy': 'dca'
}
)
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None
def should_rebalance(self, current_allocation: Decimal, portfolio_value: Decimal) -> bool:
"""Check if rebalancing is needed.
Args:
current_allocation: Current allocation percentage
portfolio_value: Total portfolio value
Returns:
True if rebalancing needed
"""
target_value = portfolio_value * self.target_allocation
current_value = portfolio_value * current_allocation
# Rebalance if allocation deviates by more than 5%
deviation = abs(current_allocation - self.target_allocation)
return deviation > Decimal('0.05')

View File

@@ -0,0 +1,6 @@
"""Ensemble strategy package."""
from .consensus_strategy import ConsensusStrategy
__all__ = ['ConsensusStrategy']

View File

@@ -0,0 +1,244 @@
"""Ensemble/consensus strategy.
Combines signals from multiple strategies with voting mechanism.
Only executes when multiple strategies agree, improving signal quality.
"""
from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType, get_strategy_registry
from src.core.logger import get_logger
from src.autopilot.performance_tracker import get_performance_tracker
logger = get_logger(__name__)
class ConsensusStrategy(BaseStrategy):
"""Ensemble strategy that combines signals from multiple strategies.
This strategy aggregates signals from multiple registered strategies
and only generates signals when a minimum consensus is reached.
Signals are weighted by strategy performance metrics.
"""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize consensus strategy.
Parameters:
strategy_names: List of strategy names to include (None = all)
min_consensus: Minimum number of strategies that must agree (default 2)
use_weights: Weight signals by strategy performance (default True)
min_weight: Minimum weight for a strategy to participate (default 0.3)
exclude_strategies: List of strategy names to exclude
"""
super().__init__(name, parameters, timeframes)
self.strategy_names = self.parameters.get('strategy_names', None)
self.min_consensus = self.parameters.get('min_consensus', 2)
self.use_weights = self.parameters.get('use_weights', True)
self.min_weight = self.parameters.get('min_weight', 0.3)
self.exclude_strategies = self.parameters.get('exclude_strategies', [])
self.registry = get_strategy_registry()
self.performance_tracker = get_performance_tracker()
self._strategy_instances: Dict[str, BaseStrategy] = {}
self._strategy_weights: Dict[str, float] = {}
# Initialize strategy instances
self._initialize_strategies()
# Calculate weights if using weighted voting
if self.use_weights:
self._calculate_weights()
def _initialize_strategies(self):
"""Initialize instances of strategies to monitor."""
if self.strategy_names is None:
available = self.registry.list_available()
# Aggressive filtering to prevent recursion
self.strategy_names = [
s for s in available
if s not in self.exclude_strategies
and "consensus" not in s.lower()
and s.lower() != self.name.lower()
]
self.logger.info(f"ConsensusStrategy({self.name}) automatically selected strategies: {self.strategy_names}")
else:
self.logger.info(f"ConsensusStrategy({self.name}) manually selected strategies: {self.strategy_names}")
for strategy_name in self.strategy_names:
try:
strategy_class = self.registry._strategies.get(strategy_name.lower())
if strategy_class:
instance = strategy_class(
name=f"{strategy_name}_consensus",
parameters={},
timeframes=self.timeframes
)
instance.enabled = True
self._strategy_instances[strategy_name] = instance
self.logger.debug(f"Initialized strategy for consensus: {strategy_name}")
except Exception as e:
self.logger.warning(f"Failed to initialize strategy {strategy_name} for consensus: {e}")
async def _calculate_weights(self):
"""Calculate weights for strategies based on performance."""
for strategy_name in self._strategy_instances.keys():
try:
metrics = await self.performance_tracker.calculate_metrics(strategy_name, period_days=30)
# Weight based on win rate and Sharpe ratio
win_rate = metrics.get('win_rate', 0.5)
sharpe_ratio = max(0.0, metrics.get('sharpe_ratio', 0.0))
# Normalize to 0-1 range
weight = (win_rate * 0.6) + (min(sharpe_ratio / 2.0, 1.0) * 0.4)
# Apply minimum weight threshold
if weight < self.min_weight:
weight = self.min_weight
self._strategy_weights[strategy_name] = weight
self.logger.debug(f"Strategy {strategy_name} weight: {weight:.2f}")
except Exception as e:
# Default weight if metrics unavailable
self._strategy_weights[strategy_name] = 0.5
self.logger.warning(f"Could not calculate weight for {strategy_name}: {e}")
async def _collect_signals(
self,
symbol: str,
price: Decimal,
timeframe: str,
data: Dict[str, Any]
) -> List[tuple[str, StrategySignal, float]]:
"""Collect signals from all monitored strategies.
Args:
symbol: Trading symbol
price: Current price
timeframe: Timeframe
data: Market data
Returns:
List of (strategy_name, signal, weight) tuples
"""
signals: List[tuple[str, StrategySignal, float]] = []
for strategy_name, strategy_instance in self._strategy_instances.items():
try:
signal = await strategy_instance.on_tick(symbol, price, timeframe, data)
if signal and signal.signal_type != SignalType.HOLD:
# Process signal through strategy's on_signal
signal = strategy_instance.on_signal(signal)
if signal:
weight = self._strategy_weights.get(strategy_name, 0.5) if self.use_weights else 1.0
signals.append((strategy_name, signal, weight))
except Exception as e:
import traceback
self.logger.warning(f"Error getting signal from {strategy_name}: {e}\n{traceback.format_exc()}")
return signals
def _aggregate_signals(
self,
signals: List[tuple[str, StrategySignal, float]]
) -> Optional[StrategySignal]:
"""Aggregate signals and determine consensus.
Args:
signals: List of (strategy_name, signal, weight) tuples
Returns:
Consensus signal or None
"""
if not signals:
return None
# Group signals by type
buy_signals: List[tuple[str, StrategySignal, float]] = []
sell_signals: List[tuple[str, StrategySignal, float]] = []
for strategy_name, signal, weight in signals:
if signal.signal_type == SignalType.BUY:
buy_signals.append((strategy_name, signal, weight))
elif signal.signal_type == SignalType.SELL:
sell_signals.append((strategy_name, signal, weight))
# Calculate weighted consensus scores
buy_score = sum(weight * signal.strength for _, signal, weight in buy_signals)
sell_score = sum(weight * signal.strength for _, signal, weight in sell_signals)
buy_count = len(buy_signals)
sell_count = len(sell_signals)
# Determine final signal
final_signal_type = None
consensus_count = 0
consensus_score = 0.0
participating_strategies = []
if buy_count >= self.min_consensus and buy_score > sell_score:
final_signal_type = SignalType.BUY
consensus_count = buy_count
consensus_score = buy_score
participating_strategies = [name for name, _, _ in buy_signals]
elif sell_count >= self.min_consensus and sell_score > buy_score:
final_signal_type = SignalType.SELL
consensus_count = sell_count
consensus_score = sell_score
participating_strategies = [name for name, _, _ in sell_signals]
if final_signal_type is None:
return None
# Calculate final signal strength (normalized)
max_possible_score = len(self._strategy_instances) * 1.0 # Max weight * max strength
strength = min(1.0, consensus_score / max_possible_score) if max_possible_score > 0 else 0.5
# Use price from strongest signal
strongest_signal = max(
buy_signals if final_signal_type == SignalType.BUY else sell_signals,
key=lambda x: x[2] * x[1].strength # weight * strength
)[1]
return StrategySignal(
signal_type=final_signal_type,
symbol=strongest_signal.symbol,
strength=strength,
price=strongest_signal.price,
metadata={
'consensus_count': consensus_count,
'consensus_score': consensus_score,
'participating_strategies': participating_strategies,
'buy_count': buy_count,
'sell_count': sell_count,
'buy_score': buy_score,
'sell_score': sell_score,
'strategy': 'consensus'
}
)
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate consensus signal from multiple strategies."""
# Update weights periodically (every 100 ticks)
if not hasattr(self, '_tick_count'):
self._tick_count = 0
self._tick_count += 1
if self.use_weights and self._tick_count % 100 == 0:
await self._calculate_weights()
# Collect signals from all strategies
signals = await self._collect_signals(symbol, price, timeframe, data)
# Aggregate signals
consensus_signal = self._aggregate_signals(signals)
return consensus_signal
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

View File

@@ -0,0 +1,109 @@
"""Grid trading strategy."""
from decimal import Decimal
from typing import Optional, Dict, Any, List
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.core.logger import get_logger
logger = get_logger(__name__)
class GridStrategy(BaseStrategy):
"""Grid trading strategy - buy at lower levels, sell at higher levels."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize Grid strategy.
Parameters:
grid_spacing: Percentage spacing between grid levels (default: 1%)
num_levels: Number of grid levels above and below center (default: 10)
center_price: Center price for grid (default: current price)
profit_target: Profit target percentage (default: 2%)
"""
super().__init__(name, parameters, timeframes)
self.grid_spacing = Decimal(str(self.parameters.get('grid_spacing', 1))) / 100
self.num_levels = self.parameters.get('num_levels', 10)
self.center_price = self.parameters.get('center_price')
self.profit_target = Decimal(str(self.parameters.get('profit_target', 2))) / 100
# Grid levels
self.buy_levels: List[Decimal] = []
self.sell_levels: List[Decimal] = []
self.positions: Dict[Decimal, Decimal] = {} # entry_price -> quantity
def _update_grid_levels(self, current_price: Decimal):
"""Update grid levels based on current price."""
if not self.center_price:
self.center_price = current_price
# Calculate grid levels
self.buy_levels = []
self.sell_levels = []
for i in range(1, self.num_levels + 1):
# Buy levels below center
buy_price = self.center_price * (1 - self.grid_spacing * Decimal(i))
self.buy_levels.append(buy_price)
# Sell levels above center
sell_price = self.center_price * (1 + self.grid_spacing * Decimal(i))
self.sell_levels.append(sell_price)
# Sort levels
self.buy_levels.sort(reverse=True) # Highest to lowest
self.sell_levels.sort() # Lowest to highest
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate grid trading signals."""
self._update_grid_levels(price)
# Check buy levels
for buy_level in self.buy_levels:
if price <= buy_level and buy_level not in self.positions:
# Buy signal
quantity = self.parameters.get('position_size', Decimal('0.1'))
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=0.8,
price=buy_level,
quantity=quantity,
metadata={
'grid_level': float(buy_level),
'strategy': 'grid',
'type': 'buy'
}
)
# Check sell levels (profit taking)
for entry_price, quantity in list(self.positions.items()):
profit_pct = (price - entry_price) / entry_price
if profit_pct >= self.profit_target:
# Sell signal for profit taking
del self.positions[entry_price]
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=1.0,
price=price,
quantity=quantity,
metadata={
'entry_price': float(entry_price),
'profit_pct': float(profit_pct * 100),
'strategy': 'grid',
'type': 'profit_take'
}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal and track positions."""
if signal.signal_type == SignalType.BUY:
# Track position
self.positions[signal.price] = signal.quantity
elif signal.signal_type == SignalType.SELL:
# Position already removed in on_tick
pass
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,5 @@
"""Market making strategy package."""
from .market_making_strategy import MarketMakingStrategy
__all__ = ['MarketMakingStrategy']

View File

@@ -0,0 +1,206 @@
"""Market Making Strategy - Profits from the bid-ask spread in sideways markets."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.pricing_service import get_pricing_service
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class MarketMakingStrategy(BaseStrategy):
"""
Market Making Strategy that profits from the bid-ask spread.
Logic:
1. Calculate mid-price from order book or last trade.
2. Place limit BUY order at (mid - spread%).
3. Place limit SELL order at (mid + spread%).
4. Re-quote orders when price moves beyond threshold.
5. Use inventory skew to manage risk (if holding too much, lower sell price).
Best suited for:
- Sideways/ranging markets with low volatility
- High-liquidity pairs with tight spreads
Risks:
- Adverse selection (getting filled on wrong side during trends)
- Inventory accumulation
"""
def __init__(self, name: str, parameters: Dict[str, Any], timeframes: List[str] = None):
super().__init__(name, parameters, timeframes)
# Strategy parameters
self.spread_percent = float(parameters.get('spread_percent', 0.2)) / 100 # Convert to decimal
self.requote_threshold = float(parameters.get('requote_threshold', 0.5)) / 100
self.max_inventory = Decimal(str(parameters.get('max_inventory', 1.0)))
self.inventory_skew_factor = float(parameters.get('inventory_skew_factor', 0.5))
self.min_adx = float(parameters.get('min_adx', 20)) # Only make markets when ADX < this (low trend)
self.order_size_percent = float(parameters.get('order_size_percent', 5)) / 100
self.pricing_service = get_pricing_service()
self.indicators = get_indicators()
# Track active orders and inventory
self._current_inventory = Decimal('0')
self._last_mid_price: Optional[Decimal] = None
self._last_quote_time: Optional[datetime] = None
self._active_orders: Dict[str, Any] = {} # side -> order info
def _calculate_skewed_spread(self, inventory: Decimal) -> tuple[float, float]:
"""Calculate bid/ask spread with inventory skew.
When inventory is positive (long), we want to sell more aggressively.
When inventory is negative (short), we want to buy more aggressively.
Returns:
Tuple of (bid_spread, ask_spread) as percentages
"""
base_spread = self.spread_percent
# Calculate skew based on inventory relative to max
if self.max_inventory > 0:
inventory_ratio = float(inventory / self.max_inventory)
else:
inventory_ratio = 0.0
# Clamp between -1 and 1
inventory_ratio = max(-1.0, min(1.0, inventory_ratio))
# Apply skew
skew = inventory_ratio * self.inventory_skew_factor * base_spread
# If positive inventory, tighten ask (lower sell price), widen bid
# If negative inventory, tighten bid (higher buy price), widen ask
bid_spread = base_spread + skew # Positive inventory -> wider bid
ask_spread = base_spread - skew # Positive inventory -> tighter ask
# Ensure spreads are positive
bid_spread = max(0.0001, bid_spread)
ask_spread = max(0.0001, ask_spread)
return bid_spread, ask_spread
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Check for market making opportunities."""
if not self.enabled:
return None
try:
current_price = price
# Fetch OHLCV for ADX calculation
ohlcv = self.pricing_service.get_ohlcv(
symbol=symbol,
timeframe=timeframe,
limit=30
)
if ohlcv and len(ohlcv) >= 14:
df_data = {
'high': [float(c[2]) for c in ohlcv],
'low': [float(c[3]) for c in ohlcv],
'close': [float(c[4]) for c in ohlcv]
}
import pandas as pd
df = pd.DataFrame(df_data)
adx = self.indicators.adx(df['high'], df['low'], df['close'], period=14)
current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0
# Only make markets in low-trend environments
if current_adx > self.min_adx:
self.logger.debug(
f"Market Making skipped: ADX {current_adx:.1f} > {self.min_adx} (trending market)"
)
return None
else:
current_adx = 0.0
# Check if we need to requote
should_requote = False
if self._last_mid_price is None:
should_requote = True
else:
price_change = abs(float(current_price - self._last_mid_price) / float(self._last_mid_price))
if price_change > self.requote_threshold:
should_requote = True
self.logger.debug(f"Requote triggered: price moved {price_change:.2%}")
if not should_requote:
return None
# Calculate skewed spreads based on inventory
bid_spread, ask_spread = self._calculate_skewed_spread(self._current_inventory)
# Calculate quote prices
bid_price = current_price * Decimal(str(1 - bid_spread))
ask_price = current_price * Decimal(str(1 + ask_spread))
# Round to 2 decimal places
bid_price = bid_price.quantize(Decimal('0.01'))
ask_price = ask_price.quantize(Decimal('0.01'))
self._last_mid_price = current_price
self._last_quote_time = datetime.now()
self.logger.info(
f"Market Making {symbol}: Mid={current_price}, "
f"Bid={bid_price} ({bid_spread:.3%}), Ask={ask_price} ({ask_spread:.3%}), "
f"Inventory={self._current_inventory}, ADX={current_adx:.1f}"
)
# Generate signal for placing both limit orders
# The execution engine will need special handling for this
signal = StrategySignal(
signal_type=SignalType.HOLD, # Special: not BUY or SELL, but both
symbol=symbol,
strength=1.0 - (current_adx / 100), # Higher strength when less trending
price=current_price,
metadata={
"strategy": "market_making",
"order_type": "limit_pair", # Indicates we want both bid and ask
"bid_price": float(bid_price),
"ask_price": float(ask_price),
"bid_spread": bid_spread,
"ask_spread": ask_spread,
"inventory": float(self._current_inventory),
"adx": float(current_adx),
"description": f"Market Making: Bid={bid_price}, Ask={ask_price}"
}
)
return signal
except Exception as e:
self.logger.error(f"Error in MarketMakingStrategy: {e}")
return None
return None
def update_inventory(self, side: str, quantity: Decimal):
"""Update inventory after a fill.
Args:
side: 'buy' or 'sell'
quantity: Filled quantity
"""
if side == 'buy':
self._current_inventory += quantity
elif side == 'sell':
self._current_inventory -= quantity
self.logger.info(f"Inventory updated: {side} {quantity}, new inventory: {self._current_inventory}")
def get_inventory(self) -> Decimal:
"""Get current inventory position."""
return self._current_inventory
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal (pass-through)."""
return signal

View File

View File

@@ -0,0 +1,138 @@
"""Momentum trading strategy."""
from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class MomentumStrategy(BaseStrategy):
"""Momentum-based trading strategy."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize Momentum strategy.
Parameters:
lookback_period: Lookback period for momentum calculation (default: 20)
momentum_threshold: Minimum momentum strength to enter (default: 0.05 = 5%)
volume_threshold: Minimum volume increase for confirmation (default: 1.5x)
exit_threshold: Momentum reversal threshold for exit (default: -0.02 = -2%)
"""
super().__init__(name, parameters, timeframes)
self.lookback_period = self.parameters.get('lookback_period', 20)
self.momentum_threshold = Decimal(str(self.parameters.get('momentum_threshold', 0.05)))
self.volume_threshold = self.parameters.get('volume_threshold', 1.5)
self.exit_threshold = Decimal(str(self.parameters.get('exit_threshold', -0.02)))
self.indicators = get_indicators()
self._price_history: List[float] = []
self._volume_history: List[float] = []
self._in_position = False
self._entry_price: Optional[Decimal] = None
def _calculate_momentum(self, prices: pd.Series) -> float:
"""Calculate price momentum.
Args:
prices: Series of prices
Returns:
Momentum value (percentage change)
"""
if len(prices) < self.lookback_period:
return 0.0
recent_prices = prices[-self.lookback_period:]
old_price = recent_prices.iloc[0]
new_price = recent_prices.iloc[-1]
if old_price == 0:
return 0.0
return float((new_price - old_price) / old_price)
def _check_volume_confirmation(self, volumes: pd.Series) -> bool:
"""Check if volume confirms momentum.
Args:
volumes: Series of volumes
Returns:
True if volume confirms
"""
if len(volumes) < self.lookback_period:
return False
recent_volumes = volumes[-self.lookback_period:]
avg_volume = recent_volumes.mean()
current_volume = recent_volumes.iloc[-1]
return current_volume >= avg_volume * self.volume_threshold
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate momentum-based signals."""
# Add to history
self._price_history.append(float(price))
self._volume_history.append(float(data.get('volume', 0)))
if len(self._price_history) < self.lookback_period + 1:
return None
prices = pd.Series(self._price_history)
volumes = pd.Series(self._volume_history)
# Calculate momentum
momentum = self._calculate_momentum(prices)
if not self._in_position:
# Entry logic
if momentum >= float(self.momentum_threshold):
# Check volume confirmation
if self._check_volume_confirmation(volumes):
self._in_position = True
self._entry_price = price
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=min(momentum / float(self.momentum_threshold), 1.0),
price=price,
metadata={
'momentum': momentum,
'volume_confirmed': True,
'strategy': 'momentum',
'type': 'entry'
}
)
else:
# Exit logic
if momentum <= float(self.exit_threshold):
self._in_position = False
entry_price = self._entry_price or price
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=1.0,
price=price,
metadata={
'momentum': momentum,
'entry_price': float(entry_price),
'strategy': 'momentum',
'type': 'exit'
}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
if signal.signal_type == SignalType.SELL:
self._in_position = False
self._entry_price = None
return signal if self.should_execute(signal) else None

370
src/strategies/scheduler.py Normal file
View File

@@ -0,0 +1,370 @@
"""Strategy scheduling system with time and condition-based triggers."""
from datetime import datetime, time
from typing import Optional, Callable, Dict, Any, List
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.interval import IntervalTrigger
from src.core.logger import get_logger
from src.core.database import get_database, Strategy
from src.strategies.base import get_strategy_registry, SignalType
from src.trading.engine import get_trading_engine
from src.data.pricing_service import get_pricing_service
from src.core.database import OrderSide, OrderType
logger = get_logger(__name__)
class StrategyScheduler:
"""Schedules strategy execution based on time and conditions."""
def __init__(self):
"""Initialize strategy scheduler."""
self.scheduler = BackgroundScheduler()
self.scheduler.start()
self.jobs: Dict[int, str] = {} # strategy_id -> job_id
self._active_strategies: Dict[int, Dict[str, Any]] = {} # strategy_id -> status info
self.logger = get_logger(__name__)
def schedule_time_based(
self,
strategy_id: int,
schedule_config: Dict[str, Any],
callback: Callable
) -> bool:
"""Schedule strategy based on time.
Args:
strategy_id: Strategy ID
schedule_config: Schedule configuration
- type: 'daily', 'weekly', 'cron'
- time: Time string (HH:MM) for daily
- days: List of days for weekly
- cron: Cron expression
callback: Function to call
Returns:
True if scheduling successful
"""
try:
# Remove existing job if any
if strategy_id in self.jobs:
self.unschedule(strategy_id)
schedule_type = schedule_config.get('type', 'daily')
if schedule_type == 'daily':
time_str = schedule_config.get('time', '09:00')
hour, minute = map(int, time_str.split(':'))
trigger = CronTrigger(hour=hour, minute=minute)
elif schedule_type == 'weekly':
days = schedule_config.get('days', [0, 1, 2, 3, 4]) # Monday-Friday
time_str = schedule_config.get('time', '09:00')
hour, minute = map(int, time_str.split(':'))
trigger = CronTrigger(day_of_week=','.join(map(str, days)), hour=hour, minute=minute)
elif schedule_type == 'cron':
cron_expr = schedule_config.get('cron')
trigger = CronTrigger.from_crontab(cron_expr)
elif schedule_type == 'interval':
interval = schedule_config.get('interval', 60) # seconds
trigger = IntervalTrigger(seconds=interval)
else:
logger.error(f"Unknown schedule type: {schedule_type}")
return False
job_id = f"strategy_{strategy_id}"
self.scheduler.add_job(
callback,
trigger=trigger,
id=job_id,
replace_existing=True
)
self.jobs[strategy_id] = job_id
logger.info(f"Scheduled strategy {strategy_id} with {schedule_type} trigger")
return True
except Exception as e:
logger.error(f"Failed to schedule strategy {strategy_id}: {e}")
return False
def schedule_condition_based(
self,
strategy_id: int,
condition: Callable[[], bool],
callback: Callable,
check_interval: int = 60
) -> bool:
"""Schedule strategy based on condition.
Args:
strategy_id: Strategy ID
condition: Function that returns True when condition is met
callback: Function to call when condition is met
check_interval: Interval to check condition (seconds)
Returns:
True if scheduling successful
"""
def check_and_execute():
if condition():
callback()
return self.schedule_time_based(
strategy_id,
{'type': 'interval', 'interval': check_interval},
check_and_execute
)
def unschedule(self, strategy_id: int):
"""Unschedule a strategy.
Args:
strategy_id: Strategy ID
"""
if strategy_id in self.jobs:
job_id = self.jobs[strategy_id]
try:
self.scheduler.remove_job(job_id)
del self.jobs[strategy_id]
logger.info(f"Unscheduled strategy {strategy_id}")
except Exception as e:
logger.error(f"Failed to unschedule strategy {strategy_id}: {e}")
def is_scheduled(self, strategy_id: int) -> bool:
"""Check if strategy is scheduled.
Args:
strategy_id: Strategy ID
Returns:
True if scheduled
"""
return strategy_id in self.jobs
def start_strategy(self, strategy_id: int):
"""Start a strategy execution loop."""
from sqlalchemy.orm import Session
from sqlalchemy import create_engine
from src.core.config import get_config
config = get_config()
db_url = config.get("database.url", "postgresql://localhost/crypto_trader")
# Use sync engine for scheduler context
engine = create_engine(db_url)
with Session(engine) as session:
try:
strategy_model = session.query(Strategy).filter_by(id=strategy_id).first()
if not strategy_model:
logger.error(f"Cannot start strategy {strategy_id}: Not found")
return
# Instantiate strategy
registry = get_strategy_registry()
strategy_instance = registry.create_instance(
strategy_id=strategy_id,
name=strategy_model.strategy_type,
parameters=strategy_model.parameters,
timeframes=strategy_model.timeframes
)
if not strategy_instance:
logger.error(f"Failed to create instance for strategy {strategy_id}")
return
strategy_instance.enabled = True
# Store strategy info for status tracking
self._active_strategies[strategy_id] = {
'instance': strategy_instance,
'name': strategy_model.name,
'type': strategy_model.strategy_type,
'symbol': strategy_model.parameters.get('symbol'),
'started_at': datetime.now(),
'last_tick': None,
'last_signal': None,
'signal_count': 0,
'error_count': 0,
}
# Use 'interval' from parameters, default 60s
interval = strategy_model.parameters.get('interval', 60)
def execute_wrapper():
self._execute_strategy_sync(strategy_id)
self.schedule_time_based(
strategy_id,
{'type': 'interval', 'interval': interval},
execute_wrapper
)
logger.info(f"Started strategy {strategy_id} ({strategy_model.name})")
except Exception as e:
logger.error(f"Error initiating strategy {strategy_id}: {e}")
def stop_strategy(self, strategy_id: int):
"""Stop a strategy."""
self.unschedule(strategy_id)
# Remove from active strategies
if strategy_id in self._active_strategies:
del self._active_strategies[strategy_id]
logger.info(f"Stopped strategy {strategy_id}")
def get_strategy_status(self, strategy_id: int) -> Optional[Dict[str, Any]]:
"""Get status of a running strategy."""
if strategy_id not in self._active_strategies:
return None
info = self._active_strategies[strategy_id]
return {
'strategy_id': strategy_id,
'name': info['name'],
'type': info['type'],
'symbol': info['symbol'],
'running': True,
'started_at': info['started_at'].isoformat() if info['started_at'] else None,
'last_tick': info['last_tick'].isoformat() if info['last_tick'] else None,
'last_signal': info['last_signal'],
'signal_count': info['signal_count'],
'error_count': info['error_count'],
}
def get_all_active_strategies(self) -> List[Dict[str, Any]]:
"""Get status of all active strategies."""
return [self.get_strategy_status(sid) for sid in self._active_strategies.keys()]
def _execute_strategy_sync(self, strategy_id: int):
"""Execute a single strategy cycle (synchronous wrapper)."""
if strategy_id not in self._active_strategies:
logger.warning(f"Strategy {strategy_id} not in active list, skipping execution")
return
info = self._active_strategies[strategy_id]
strategy_instance = info['instance']
try:
# 1. Fetch Data
symbol = strategy_instance.parameters.get('symbol')
timeframe = strategy_instance.timeframes[0] if strategy_instance.timeframes else '1h'
pricing_service = get_pricing_service()
ticker = pricing_service.get_ticker(symbol) # Sync call
current_price = ticker.get('last')
if not current_price:
logger.debug(f"No price for {symbol}, skipping")
return
# Update last tick
info['last_tick'] = datetime.now()
# 2. Run async on_tick in sync context
import asyncio
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
signal = loop.run_until_complete(
strategy_instance.on_tick(symbol, current_price, timeframe, ticker)
)
if not signal:
logger.debug(f"Strategy {strategy_id}: No signal at price {current_price}")
return
# Update signal tracking
info['last_signal'] = {
'type': signal.signal_type.value,
'strength': signal.strength,
'price': float(signal.price),
'timestamp': datetime.now().isoformat(),
'metadata': signal.metadata,
}
info['signal_count'] += 1
logger.info(f"Strategy {strategy_id} generated signal: {signal.signal_type.value} @ {current_price}")
# 3. Handle Signal (Execution)
trading_engine = get_trading_engine()
# Check for Pairs Trading Metadata
secondary_symbol = signal.metadata.get('secondary_symbol')
secondary_action = signal.metadata.get('secondary_action')
if secondary_symbol and secondary_action:
# Multi-Leg Execution for Pairs Trading
logger.info(f"Executing Pairs Trade: {symbol} ({signal.signal_type.value}) & {secondary_symbol} ({secondary_action})")
# Execute Primary Leg
loop.run_until_complete(trading_engine.execute_order(
exchange_id=1,
strategy_id=strategy_id,
symbol=symbol,
side=OrderSide(signal.signal_type.value),
order_type=OrderType.MARKET,
quantity=strategy_instance.calculate_position_size(
signal, trading_engine.paper_trading.get_balance(), current_price
),
paper_trading=True
))
# Execute Secondary Leg
sec_ticker = pricing_service.get_ticker(secondary_symbol)
sec_price = sec_ticker.get('last')
if sec_price:
loop.run_until_complete(trading_engine.execute_order(
exchange_id=1,
strategy_id=strategy_id,
symbol=secondary_symbol,
side=OrderSide(secondary_action),
order_type=OrderType.MARKET,
quantity=strategy_instance.calculate_position_size(
signal, trading_engine.paper_trading.get_balance(), sec_price
),
paper_trading=True,
))
else:
# Standard Single Leg Execution
loop.run_until_complete(trading_engine.execute_order(
exchange_id=1,
strategy_id=strategy_id,
symbol=symbol,
side=OrderSide(signal.signal_type.value),
order_type=OrderType.MARKET,
quantity=signal.quantity or strategy_instance.calculate_position_size(
signal, trading_engine.paper_trading.get_balance(), current_price
),
paper_trading=True
))
except Exception as e:
logger.error(f"Strategy {strategy_id} execution error: {e}")
info['error_count'] += 1
def shutdown(self):
"""Shutdown scheduler."""
self.scheduler.shutdown()
# Global scheduler
_scheduler: Optional[StrategyScheduler] = None
def get_scheduler() -> StrategyScheduler:
"""Get global strategy scheduler instance."""
global _scheduler
if _scheduler is None:
_scheduler = StrategyScheduler()
return _scheduler

View File

@@ -0,0 +1,5 @@
"""Sentiment strategy package."""
from .sentiment_strategy import SentimentStrategy
__all__ = ['SentimentStrategy']

View File

@@ -0,0 +1,208 @@
"""Sentiment-Driven Strategy - Trades based on news sentiment and market fear/greed."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timedelta
from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.news_collector import get_news_collector
from src.core.logger import get_logger
from src.core.config import get_config
logger = get_logger(__name__)
class SentimentStrategy(BaseStrategy):
"""
Sentiment-Driven Strategy that trades based on news sentiment and market fear.
Modes:
1. Contrarian: Buy during extreme fear, sell during extreme greed
2. Momentum: Follow positive news cycles
3. Combo: Buy on positive news + fear (best opportunity)
Signal Logic:
- Aggregates sentiment from recent news headlines
- Combines with Fear & Greed Index when available
- Generates signals based on configured mode
"""
def __init__(self, name: str, parameters: Dict[str, Any], timeframes: List[str] = None):
super().__init__(name, parameters, timeframes)
# Strategy parameters
self.mode = parameters.get('mode', 'contrarian') # 'contrarian', 'momentum', 'combo'
self.min_sentiment_score = float(parameters.get('min_sentiment_score', 0.5))
self.fear_threshold = int(parameters.get('fear_threshold', 25)) # 0-100
self.greed_threshold = int(parameters.get('greed_threshold', 75)) # 0-100
self.news_lookback_hours = int(parameters.get('news_lookback_hours', 24))
self.min_headlines = int(parameters.get('min_headlines', 5))
self.news_collector = get_news_collector()
self.config = get_config()
# Simple sentiment keywords
self.positive_keywords = [
'surge', 'rally', 'soar', 'jump', 'gain', 'bullish', 'adoption',
'partnership', 'etf', 'approval', 'institutional', 'buy', 'long',
'breakout', 'record', 'high', 'upgrade', 'positive', 'growth'
]
self.negative_keywords = [
'crash', 'drop', 'plunge', 'fall', 'dump', 'bearish', 'hack',
'scandal', 'lawsuit', 'ban', 'regulation', 'sell', 'short',
'breakdown', 'low', 'downgrade', 'negative', 'fear', 'panic'
]
# Cache for fear & greed
self._fear_greed_cache: Optional[Dict[str, Any]] = None
self._fear_greed_timestamp: Optional[datetime] = None
def _analyze_sentiment(self, headlines: List[str]) -> float:
"""Calculate aggregate sentiment score from headlines.
Returns:
Score from -1.0 (very negative) to +1.0 (very positive)
"""
if not headlines:
return 0.0
total_score = 0.0
for headline in headlines:
headline_lower = headline.lower()
pos_count = sum(1 for kw in self.positive_keywords if kw in headline_lower)
neg_count = sum(1 for kw in self.negative_keywords if kw in headline_lower)
if pos_count + neg_count > 0:
# Normalize to -1 to 1
total_score += (pos_count - neg_count) / (pos_count + neg_count)
# Average across all headlines
return total_score / len(headlines)
async def _fetch_fear_greed_index(self) -> Optional[int]:
"""Fetch the Fear & Greed Index from alternative.me API.
Returns:
Fear & Greed value (0-100) or None if unavailable
"""
# Check cache (valid for 1 hour)
if (self._fear_greed_cache and self._fear_greed_timestamp and
(datetime.now() - self._fear_greed_timestamp).total_seconds() < 3600):
return self._fear_greed_cache.get('value')
try:
import aiohttp
url = "https://api.alternative.me/fng/"
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=10) as response:
if response.status == 200:
data = await response.json()
if data.get('data'):
value = int(data['data'][0]['value'])
self._fear_greed_cache = {'value': value}
self._fear_greed_timestamp = datetime.now()
return value
except Exception as e:
self.logger.debug(f"Failed to fetch Fear & Greed Index: {e}")
return None
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Check for sentiment-based signals."""
if not self.enabled:
return None
try:
# Extract base symbol (e.g., "BTC" from "BTC/USD")
base_symbol = symbol.split('/')[0] if '/' in symbol else symbol
# Fetch news headlines
headlines = await self.news_collector.fetch_headlines(
symbols=[base_symbol],
max_age_hours=self.news_lookback_hours
)
if len(headlines) < self.min_headlines:
self.logger.debug(f"Not enough headlines: {len(headlines)} < {self.min_headlines}")
return None
# Calculate sentiment score
sentiment = self._analyze_sentiment(headlines)
# Fetch Fear & Greed Index
fear_greed = await self._fetch_fear_greed_index()
self.logger.info(
f"Sentiment Analysis {symbol}: Score={sentiment:.2f}, "
f"Fear&Greed={fear_greed}, Headlines={len(headlines)}"
)
signal_type = None
description = ""
if self.mode == 'contrarian':
# Contrarian: Buy on extreme fear, sell on extreme greed
if fear_greed is not None:
if fear_greed < self.fear_threshold:
signal_type = SignalType.BUY
description = f"Contrarian BUY: Extreme Fear ({fear_greed})"
elif fear_greed > self.greed_threshold:
signal_type = SignalType.SELL
description = f"Contrarian SELL: Extreme Greed ({fear_greed})"
elif self.mode == 'momentum':
# Momentum: Follow positive/negative sentiment
if sentiment > self.min_sentiment_score:
signal_type = SignalType.BUY
description = f"Momentum BUY: Positive Sentiment ({sentiment:.2f})"
elif sentiment < -self.min_sentiment_score:
signal_type = SignalType.SELL
description = f"Momentum SELL: Negative Sentiment ({sentiment:.2f})"
elif self.mode == 'combo':
# Combo: Positive news + fearful market = best buying opportunity
if fear_greed is not None:
if sentiment > self.min_sentiment_score and fear_greed < 40:
signal_type = SignalType.BUY
description = f"Combo BUY: Positive News + Fear ({sentiment:.2f}, FG={fear_greed})"
elif sentiment < -self.min_sentiment_score and fear_greed > 60:
signal_type = SignalType.SELL
description = f"Combo SELL: Negative News + Greed ({sentiment:.2f}, FG={fear_greed})"
if signal_type:
# Calculate strength based on sentiment and fear/greed extremity
strength = abs(sentiment)
if fear_greed is not None:
fg_extremity = abs(50 - fear_greed) / 50 # How far from neutral
strength = (strength + fg_extremity) / 2
signal = StrategySignal(
signal_type=signal_type,
symbol=symbol,
strength=min(1.0, strength),
price=price,
metadata={
"strategy": "sentiment",
"mode": self.mode,
"sentiment_score": float(sentiment),
"fear_greed": fear_greed,
"headline_count": len(headlines),
"sample_headlines": headlines[:3], # Include first 3 for context
"description": description
}
)
self.logger.info(f"Sentiment Signal: {description}")
return signal
except Exception as e:
self.logger.error(f"Error in SentimentStrategy: {e}")
return None
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal (pass-through)."""
return signal

View File

View File

@@ -0,0 +1,227 @@
"""Bollinger Bands mean reversion strategy.
Buys when price touches lower band in uptrend, sells when price
touches upper band in downtrend. Works well in ranging markets.
"""
from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class BollingerMeanReversionStrategy(BaseStrategy):
"""Bollinger Bands mean reversion strategy.
This strategy trades mean reversion using Bollinger Bands:
- Buy when price touches lower band in uptrend
- Sell when price touches upper band in downtrend
- Uses trend filter to avoid counter-trend trades
"""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize Bollinger Bands mean reversion strategy.
Parameters:
period: Moving average period for Bollinger Bands (default 20)
std_dev: Standard deviation multiplier (default 2.0)
trend_filter: Enable trend filter (default True)
trend_ma_period: Moving average period for trend detection (default 50)
entry_threshold: How close price must be to band (0.0-1.0, default 0.95)
exit_threshold: Exit when price reaches middle band (default 0.5)
"""
super().__init__(name, parameters, timeframes)
self.period = self.parameters.get('period', 20)
self.std_dev = self.parameters.get('std_dev', 2.0)
self.trend_filter = self.parameters.get('trend_filter', True)
self.trend_ma_period = self.parameters.get('trend_ma_period', 50)
self.entry_threshold = self.parameters.get('entry_threshold', 0.95)
self.exit_threshold = self.parameters.get('exit_threshold', 0.5)
self.indicators = get_indicators()
self._price_history: List[float] = []
self._in_position = False
self._entry_price: Optional[Decimal] = None
def _get_trend_direction(self, prices: pd.Series) -> Optional[str]:
"""Determine trend direction.
Args:
prices: Price series
Returns:
'up', 'down', or None
"""
if len(prices) < self.trend_ma_period:
return None
sma = self.indicators.sma(prices, self.trend_ma_period)
current_price = prices.iloc[-1]
sma_value = sma.iloc[-1]
# Use percentage difference to determine trend
price_diff = (current_price - sma_value) / sma_value if sma_value > 0 else 0.0
if price_diff > 0.01: # 1% above SMA = uptrend
return 'up'
elif price_diff < -0.01: # 1% below SMA = downtrend
return 'down'
return None
def _check_band_touch(self, price: float, upper: float, middle: float, lower: float) -> Optional[str]:
"""Check if price is touching a band.
Args:
price: Current price
upper: Upper Bollinger Band
middle: Middle Bollinger Band
lower: Lower Bollinger Band
Returns:
'upper' if touching upper band, 'lower' if touching lower band, None otherwise
"""
band_width = upper - lower
if band_width == 0:
return None
# Calculate position within bands (0 = lower, 1 = upper)
position = (price - lower) / band_width
# Check if touching lower band
if position <= (1.0 - self.entry_threshold):
return 'lower'
# Check if touching upper band
if position >= self.entry_threshold:
return 'upper'
return None
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on Bollinger Bands mean reversion."""
# Add price to history
self._price_history.append(float(price))
# Need enough data for Bollinger Bands and trend detection
max_period = max(self.period, self.trend_ma_period)
if len(self._price_history) < max_period + 1:
return None
prices = pd.Series(self._price_history[-max_period-1:])
current_price = float(price)
# Calculate Bollinger Bands
bb_data = self.indicators.bollinger_bands(prices, period=self.period, std_dev=self.std_dev)
if len(bb_data['upper']) == 0 or pd.isna(bb_data['upper'].iloc[-1]):
return None
upper = bb_data['upper'].iloc[-1]
middle = bb_data['middle'].iloc[-1]
lower = bb_data['lower'].iloc[-1]
# Determine trend direction if using trend filter
trend = None
if self.trend_filter:
trend = self._get_trend_direction(prices)
# Check for entry signals
if not self._in_position:
band_touch = self._check_band_touch(current_price, upper, middle, lower)
if band_touch == 'lower':
# Price touching lower band - potential buy signal
if not self.trend_filter or trend == 'up':
# Only buy in uptrend (mean reversion back up)
self._in_position = True
self._entry_price = price
# Calculate signal strength based on how far below lower band
distance_from_lower = (lower - current_price) / lower if lower > 0 else 0.0
strength = min(1.0, distance_from_lower * 10) # Scale for strength
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=max(0.5, strength),
price=price,
metadata={
'band_touch': 'lower',
'upper': float(upper),
'middle': float(middle),
'lower': float(lower),
'price_position': (current_price - lower) / (upper - lower) if (upper - lower) > 0 else 0.5,
'trend': trend,
'strategy': 'bollinger_mean_reversion',
'type': 'entry'
}
)
elif band_touch == 'upper':
# Price touching upper band - potential sell signal (for short)
if not self.trend_filter or trend == 'down':
# Only sell in downtrend (mean reversion back down)
# Note: This is a SELL signal (could be used for shorting or exiting long positions)
# For mean reversion, we typically only go long, so this might be an exit signal
# For simplicity, we'll make it a SELL signal
pass # Could implement short entry here if desired
else:
# In position - check for exit
entry_price = self._entry_price or price
entry_float = float(entry_price)
# Exit when price reaches middle band (mean reversion complete)
band_width = upper - lower
position_in_bands = (current_price - lower) / band_width if band_width > 0 else 0.5
# Exit conditions
should_exit = False
exit_reason = None
if position_in_bands >= self.exit_threshold:
# Price has moved back toward middle - take profit
should_exit = True
exit_reason = 'target_reached'
elif band_touch := self._check_band_touch(current_price, upper, middle, lower):
# Price touched opposite band - stop loss
if band_touch == 'upper' and entry_float < middle:
should_exit = True
exit_reason = 'stop_loss'
if should_exit:
self._in_position = False
profit_pct = (current_price - entry_float) / entry_float if entry_float > 0 else 0.0
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=1.0,
price=price,
metadata={
'entry_price': float(entry_price),
'exit_price': current_price,
'profit_pct': profit_pct * 100,
'exit_reason': exit_reason,
'strategy': 'bollinger_mean_reversion',
'type': 'exit'
}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
if signal.signal_type == SignalType.SELL:
self._in_position = False
self._entry_price = None
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,246 @@
"""Multi-indicator confirmation strategy.
This strategy requires multiple indicators (RSI, MACD, Moving Average) to align
before generating a signal, reducing false signals significantly.
"""
from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class ConfirmedStrategy(BaseStrategy):
"""Multi-indicator confirmation strategy.
Combines RSI, MACD, and Moving Average signals and only generates
signals when a configurable number of indicators agree.
"""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize confirmed strategy.
Parameters:
rsi_period: RSI period (default 14)
rsi_oversold: RSI oversold threshold (default 30)
rsi_overbought: RSI overbought threshold (default 70)
macd_fast: MACD fast period (default 12)
macd_slow: MACD slow period (default 26)
macd_signal: MACD signal period (default 9)
ma_fast: Fast MA period (default 10)
ma_slow: Slow MA period (default 30)
ma_type: MA type - 'sma' or 'ema' (default 'ema')
min_confirmations: Minimum number of indicators that must agree (default 2)
require_rsi: Require RSI confirmation (default True)
require_macd: Require MACD confirmation (default True)
require_ma: Require MA confirmation (default True)
"""
super().__init__(name, parameters, timeframes)
# RSI parameters
self.rsi_period = self.parameters.get('rsi_period', 14)
self.rsi_oversold = self.parameters.get('rsi_oversold', 30)
self.rsi_overbought = self.parameters.get('rsi_overbought', 70)
# MACD parameters
self.macd_fast = self.parameters.get('macd_fast', 12)
self.macd_slow = self.parameters.get('macd_slow', 26)
self.macd_signal = self.parameters.get('macd_signal', 9)
# MA parameters
self.ma_fast = self.parameters.get('ma_fast', 10)
self.ma_slow = self.parameters.get('ma_slow', 30)
self.ma_type = self.parameters.get('ma_type', 'ema')
# Confirmation parameters
self.min_confirmations = self.parameters.get('min_confirmations', 2)
self.require_rsi = self.parameters.get('require_rsi', True)
self.require_macd = self.parameters.get('require_macd', True)
self.require_ma = self.parameters.get('require_ma', True)
self.indicators = get_indicators()
self._price_history: List[float] = []
def _check_rsi_signal(self, prices: pd.Series) -> Optional[SignalType]:
"""Check RSI for signal.
Args:
prices: Price series
Returns:
SignalType if RSI indicates signal, None otherwise
"""
if len(prices) < self.rsi_period + 1:
return None
rsi = self.indicators.rsi(prices, self.rsi_period)
if len(rsi) == 0:
return None
current_rsi = rsi.iloc[-1]
if current_rsi < self.rsi_oversold:
return SignalType.BUY
elif current_rsi > self.rsi_overbought:
return SignalType.SELL
return None
def _check_macd_signal(self, prices: pd.Series) -> Optional[SignalType]:
"""Check MACD for signal.
Args:
prices: Price series
Returns:
SignalType if MACD indicates signal, None otherwise
"""
if len(prices) < self.macd_slow + self.macd_signal:
return None
macd_data = self.indicators.macd(prices, self.macd_fast, self.macd_slow, self.macd_signal)
if len(macd_data['macd']) < 2:
return None
macd = macd_data['macd'].iloc[-1]
signal_line = macd_data['signal'].iloc[-1]
prev_macd = macd_data['macd'].iloc[-2]
prev_signal = macd_data['signal'].iloc[-2]
# Bullish crossover
if prev_macd <= prev_signal and macd > signal_line:
return SignalType.BUY
# Bearish crossover
elif prev_macd >= prev_signal and macd < signal_line:
return SignalType.SELL
return None
def _check_ma_signal(self, prices: pd.Series) -> Optional[SignalType]:
"""Check Moving Average for signal.
Args:
prices: Price series
Returns:
SignalType if MA indicates signal, None otherwise
"""
if len(prices) < self.ma_slow + 1:
return None
if self.ma_type == 'sma':
fast_ma = self.indicators.sma(prices, self.ma_fast)
slow_ma = self.indicators.sma(prices, self.ma_slow)
else:
fast_ma = self.indicators.ema(prices, self.ma_fast)
slow_ma = self.indicators.ema(prices, self.ma_slow)
if len(fast_ma) < 2 or len(slow_ma) < 2:
return None
fast_current = fast_ma.iloc[-1]
fast_prev = fast_ma.iloc[-2]
slow_current = slow_ma.iloc[-1]
slow_prev = slow_ma.iloc[-2]
# Bullish crossover
if fast_prev <= slow_prev and fast_current > slow_current:
return SignalType.BUY
# Bearish crossover
elif fast_prev >= slow_prev and fast_current < slow_current:
return SignalType.SELL
return None
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on multi-indicator confirmation."""
# Add price to history
self._price_history.append(float(price))
# Determine minimum required data points
max_period = max(
self.rsi_period + 1,
self.macd_slow + self.macd_signal,
self.ma_slow + 1
)
if len(self._price_history) < max_period:
return None
prices = pd.Series(self._price_history[-max_period:])
# Collect signals from each indicator
signals: List[SignalType] = []
signal_metadata: Dict[str, Any] = {}
# Check RSI
if self.require_rsi:
rsi_signal = self._check_rsi_signal(prices)
if rsi_signal:
signals.append(rsi_signal)
signal_metadata['rsi'] = rsi_signal.value
# Check MACD
if self.require_macd:
macd_signal = self._check_macd_signal(prices)
if macd_signal:
signals.append(macd_signal)
signal_metadata['macd'] = macd_signal.value
# Check MA
if self.require_ma:
ma_signal = self._check_ma_signal(prices)
if ma_signal:
signals.append(ma_signal)
signal_metadata['ma'] = ma_signal.value
# Count confirmations for each signal type
buy_count = signals.count(SignalType.BUY)
sell_count = signals.count(SignalType.SELL)
# Determine final signal based on confirmations
final_signal = None
confirmation_count = 0
if buy_count >= self.min_confirmations:
final_signal = SignalType.BUY
confirmation_count = buy_count
elif sell_count >= self.min_confirmations:
final_signal = SignalType.SELL
confirmation_count = sell_count
if final_signal is None:
return None
# Calculate signal strength based on number of confirmations
max_possible = sum([self.require_rsi, self.require_macd, self.require_ma])
strength = min(1.0, confirmation_count / max_possible) if max_possible > 0 else 0.5
return StrategySignal(
signal_type=final_signal,
symbol=symbol,
strength=strength,
price=price,
metadata={
'confirmation_count': confirmation_count,
'max_confirmations': max_possible,
'signals': signal_metadata,
'strategy': 'confirmed',
'rsi_period': self.rsi_period,
'macd_fast': self.macd_fast,
'macd_slow': self.macd_slow,
'ma_type': self.ma_type,
'ma_fast': self.ma_fast,
'ma_slow': self.ma_slow
}
)
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,154 @@
"""Divergence detection strategy.
Detects price vs. indicator divergences (RSI/MACD) which are powerful
reversal signals with high success rates in ranging markets.
"""
from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class DivergenceStrategy(BaseStrategy):
"""Divergence detection strategy.
Detects bullish and bearish divergences between price and indicators
(RSI or MACD) to identify potential reversals.
- Bullish divergence: Price makes lower low, indicator makes higher low → BUY
- Bearish divergence: Price makes higher high, indicator makes lower high → SELL
"""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize divergence strategy.
Parameters:
indicator_type: Type of indicator - 'rsi' or 'macd' (default 'rsi')
rsi_period: RSI period if using RSI (default 14)
macd_fast: MACD fast period if using MACD (default 12)
macd_slow: MACD slow period if using MACD (default 26)
macd_signal: MACD signal period if using MACD (default 9)
lookback: Lookback period for finding swings (default 20)
min_swings: Minimum number of swings to detect divergence (default 2)
min_confidence: Minimum confidence threshold for signal (default 0.5)
"""
super().__init__(name, parameters, timeframes)
self.indicator_type = self.parameters.get('indicator_type', 'rsi').lower()
self.rsi_period = self.parameters.get('rsi_period', 14)
self.macd_fast = self.parameters.get('macd_fast', 12)
self.macd_slow = self.parameters.get('macd_slow', 26)
self.macd_signal = self.parameters.get('macd_signal', 9)
self.lookback = self.parameters.get('lookback', 20)
self.min_swings = self.parameters.get('min_swings', 2)
self.min_confidence = self.parameters.get('min_confidence', 0.5)
self.indicators = get_indicators()
self._price_history: List[float] = []
self._last_divergence_type: Optional[str] = None
def _calculate_indicator(self, prices: pd.Series) -> Optional[pd.Series]:
"""Calculate the selected indicator.
Args:
prices: Price series
Returns:
Indicator series or None
"""
if self.indicator_type == 'rsi':
if len(prices) < self.rsi_period + 1:
return None
return self.indicators.rsi(prices, self.rsi_period)
elif self.indicator_type == 'macd':
if len(prices) < self.macd_slow + self.macd_signal:
return None
macd_data = self.indicators.macd(prices, self.macd_fast, self.macd_slow, self.macd_signal)
return macd_data['macd']
return None
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on divergence detection."""
# Add price to history
self._price_history.append(float(price))
# Determine minimum required data points
if self.indicator_type == 'rsi':
min_period = self.rsi_period + self.lookback * 3
else:
min_period = self.macd_slow + self.macd_signal + self.lookback * 3
if len(self._price_history) < min_period:
return None
prices = pd.Series(self._price_history[-min_period:])
# Calculate indicator
indicator = self._calculate_indicator(prices)
if indicator is None or len(indicator) < self.lookback * 2:
return None
# Align prices and indicator (indicator may be shorter due to calculation)
if len(prices) != len(indicator):
# Take the last len(indicator) prices to match indicator length
prices = prices.iloc[-len(indicator):]
indicator = indicator.reset_index(drop=True) if hasattr(indicator, 'reset_index') else indicator
# Detect divergence
divergence_result = self.indicators.detect_divergence(
prices=prices,
indicator=indicator,
lookback=self.lookback,
min_swings=self.min_swings
)
divergence_type = divergence_result.get('type')
confidence = divergence_result.get('confidence', 0.0)
# Check if we have a valid divergence with sufficient confidence
if divergence_type is None or confidence < self.min_confidence:
return None
# Only generate signal if divergence type changed (avoid repeated signals)
if divergence_type == self._last_divergence_type:
return None
self._last_divergence_type = divergence_type
# Map divergence type to signal type
if divergence_type == 'bullish':
signal_type = SignalType.BUY
elif divergence_type == 'bearish':
signal_type = SignalType.SELL
else:
return None
return StrategySignal(
signal_type=signal_type,
symbol=symbol,
strength=confidence,
price=price,
metadata={
'divergence_type': divergence_type,
'confidence': confidence,
'indicator_type': self.indicator_type,
'lookback': self.lookback,
'price_swing_high': divergence_result.get('price_swing_high'),
'price_swing_low': divergence_result.get('price_swing_low'),
'indicator_swing_high': divergence_result.get('indicator_swing_high'),
'indicator_swing_low': divergence_result.get('indicator_swing_low'),
'strategy': 'divergence'
}
)
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,72 @@
"""MACD (Moving Average Convergence Divergence) strategy."""
import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
class MACDStrategy(BaseStrategy):
"""MACD-based trading strategy."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize MACD strategy.
Parameters:
fast: Fast EMA period (default 12)
slow: Slow EMA period (default 26)
signal: Signal line period (default 9)
"""
super().__init__(name, parameters, timeframes)
self.fast = self.parameters.get('fast', 12)
self.slow = self.parameters.get('slow', 26)
self.signal = self.parameters.get('signal', 9)
self.indicators = get_indicators()
self._price_history = []
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on MACD."""
# Add price to history
self._price_history.append(float(price))
if len(self._price_history) < self.slow + self.signal:
return None
# Calculate MACD
prices = pd.Series(self._price_history[-self.slow-self.signal:])
macd_data = self.indicators.macd(prices, self.fast, self.slow, self.signal)
if len(macd_data['macd']) < 2:
return None
macd = macd_data['macd'].iloc[-1]
signal_line = macd_data['signal'].iloc[-1]
prev_macd = macd_data['macd'].iloc[-2]
prev_signal = macd_data['signal'].iloc[-2]
# Bullish crossover
if prev_macd <= prev_signal and macd > signal_line:
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=min(1.0, abs(macd - signal_line) / abs(prev_macd - prev_signal) if prev_macd != prev_signal else 1.0),
price=price,
metadata={'macd': float(macd), 'signal': float(signal_line)}
)
# Bearish crossover
elif prev_macd >= prev_signal and macd < signal_line:
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=min(1.0, abs(macd - signal_line) / abs(prev_macd - prev_signal) if prev_macd != prev_signal else 1.0),
price=price,
metadata={'macd': float(macd), 'signal': float(signal_line)}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,79 @@
"""Moving Average Crossover strategy."""
import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
class MovingAverageStrategy(BaseStrategy):
"""Moving average crossover strategy."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize moving average strategy.
Parameters:
fast_period: Fast MA period (default 10)
slow_period: Slow MA period (default 30)
ma_type: MA type - 'sma' or 'ema' (default 'ema')
"""
super().__init__(name, parameters, timeframes)
self.fast_period = self.parameters.get('fast_period', 10)
self.slow_period = self.parameters.get('slow_period', 30)
self.ma_type = self.parameters.get('ma_type', 'ema')
self.indicators = get_indicators()
self._price_history = []
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on MA crossover."""
# Add price to history
self._price_history.append(float(price))
if len(self._price_history) < self.slow_period + 1:
return None
# Calculate MAs
prices = pd.Series(self._price_history[-self.slow_period-1:])
if self.ma_type == 'sma':
fast_ma = self.indicators.sma(prices, self.fast_period)
slow_ma = self.indicators.sma(prices, self.slow_period)
else:
fast_ma = self.indicators.ema(prices, self.fast_period)
slow_ma = self.indicators.ema(prices, self.slow_period)
if len(fast_ma) < 2 or len(slow_ma) < 2:
return None
# Check for crossover
fast_current = fast_ma.iloc[-1]
fast_prev = fast_ma.iloc[-2]
slow_current = slow_ma.iloc[-1]
slow_prev = slow_ma.iloc[-2]
# Bullish crossover
if fast_prev <= slow_prev and fast_current > slow_current:
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=min(1.0, (fast_current - slow_current) / slow_current),
price=price,
metadata={'fast_ma': float(fast_current), 'slow_ma': float(slow_current)}
)
# Bearish crossover
elif fast_prev >= slow_prev and fast_current < slow_current:
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=min(1.0, (slow_current - fast_current) / slow_current),
price=price,
metadata={'fast_ma': float(fast_current), 'slow_ma': float(slow_current)}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,146 @@
"""Statistical Arbitrage (Pairs Trading) Strategy."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
import pandas as pd
import numpy as np
from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.pricing_service import get_pricing_service
from src.core.logger import get_logger
logger = get_logger(__name__)
class PairsTradingStrategy(BaseStrategy):
"""
Statistical Arbitrage Strategy that trades the spread between two correlated assets.
Logic:
1. Calculate Spread = Price(A) / Price(B)
2. Calculate Z-Score of the spread over a rolling window.
3. Mean Reversion signals:
- Z-Score > Threshold: Short Spread (Sell A, Buy B)
- Z-Score < -Threshold: Long Spread (Buy A, Sell B)
- Z-Score approx 0: Close/Neutralize (Exit both)
"""
def __init__(self, name: str, parameters: Dict[str, Any], timeframes: List[str] = None):
super().__init__(name, parameters, timeframes)
self.second_symbol = parameters.get('second_symbol')
self.lookback_window = int(parameters.get('lookback_period', 20))
self.z_threshold = float(parameters.get('z_score_threshold', 2.0))
self.pricing_service = get_pricing_service()
if not self.second_symbol:
logger.warning(f"PairsTradingStrategy {name} initialized without 'second_symbol'")
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Check for pairs trading signals."""
if not self.second_symbol or not self.enabled:
return None
# We only process on the "Primary" symbol's tick to avoid double-processing
# (Assuming the strategy is assigned to the primary symbol in the DB)
# 1. Fetch data for BOTH symbols
# We need historical data to calculate Z-Score
try:
# Fetch for Primary (Symbol A)
ohlcv_a = self.pricing_service.get_ohlcv(
symbol=symbol,
timeframe=timeframe,
limit=self.lookback_window + 5
)
# Fetch for Secondary (Symbol B)
ohlcv_b = self.pricing_service.get_ohlcv(
symbol=self.second_symbol,
timeframe=timeframe,
limit=self.lookback_window + 5
)
if not ohlcv_a or not ohlcv_b:
return None
# Convert to Series for pandas calc
# OHLCV format: [timestamp, open, high, low, close, volume]
closes_a = pd.Series([float(c[4]) for c in ohlcv_a])
closes_b = pd.Series([float(c[4]) for c in ohlcv_b])
# Ensure equal length if fetched data differs slightly
min_len = min(len(closes_a), len(closes_b))
closes_a = closes_a.iloc[-min_len:]
closes_b = closes_b.iloc[-min_len:]
# 2. Calculate Spread and Z-Score
# Spread = A / B
spread = closes_a / closes_b
# Rolling statistics
rolling_mean = spread.rolling(window=self.lookback_window).mean()
rolling_std = spread.rolling(window=self.lookback_window).std()
current_spread = spread.iloc[-1]
current_mean = rolling_mean.iloc[-1]
current_std = rolling_std.iloc[-1]
if pd.isna(current_std) or current_std == 0:
return None
z_score = (current_spread - current_mean) / current_std
self.logger.info(
f"Pairs {symbol}/{self.second_symbol}: Spread={current_spread:.4f}, Z-Score={z_score:.2f}"
)
# 3. Generate Signals
# Strategy:
# If Z > Threshold -> Spread is too high (A is expensive, B is cheap) -> Sell A, Buy B
# If Z < -Threshold -> Spread is too low (A is cheap, B is expensive) -> Buy A, Sell B
# If abs(Z) < 0.5 -> Mean reverted -> Close Positions (optional, or just go to neutral)
signal_type = None
primary_side = "hold"
secondary_side = "hold"
if z_score > self.z_threshold:
# Sell A, Buy B
signal_type = SignalType.SELL
primary_side = "sell"
secondary_side = "buy"
elif z_score < -self.z_threshold:
# Buy A, Sell B
signal_type = SignalType.BUY
primary_side = "buy"
secondary_side = "sell"
# TODO: Logic for closing when Z ~ 0 can be added here
if signal_type:
# Create Signal for Primary
signal = StrategySignal(
signal_type=signal_type,
symbol=symbol,
strength=min(abs(z_score) / self.z_threshold, 1.0), # Strength capped at 1.0
price=price,
metadata={
"strategy": "pairs_trading",
"z_score": float(z_score),
"spread": float(current_spread),
"secondary_symbol": self.second_symbol,
"secondary_action": secondary_side,
"description": f"Pairs Arbitrage: Z-Score {z_score:.2f} triggered {primary_side} {symbol} / {secondary_side} {self.second_symbol}"
}
)
return signal
except Exception as e:
self.logger.error(f"Error in PairsTradingStrategy: {e}")
return None
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal (pass-through)."""
return signal

View File

@@ -0,0 +1,67 @@
"""RSI (Relative Strength Index) strategy."""
import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
class RSIStrategy(BaseStrategy):
"""RSI-based trading strategy."""
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
"""Initialize RSI strategy.
Parameters:
rsi_period: RSI period (default 14)
oversold: Oversold threshold (default 30)
overbought: Overbought threshold (default 70)
"""
super().__init__(name, parameters, timeframes)
self.rsi_period = self.parameters.get('rsi_period', 14)
self.oversold = self.parameters.get('oversold', 30)
self.overbought = self.parameters.get('overbought', 70)
self.indicators = get_indicators()
self._price_history = []
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Generate signal based on RSI."""
# Add price to history
self._price_history.append(float(price))
if len(self._price_history) < self.rsi_period + 1:
return None
# Calculate RSI
prices = pd.Series(self._price_history[-self.rsi_period-1:])
rsi = self.indicators.rsi(prices, self.rsi_period)
if len(rsi) == 0:
return None
current_rsi = rsi.iloc[-1]
# Generate signals
if current_rsi < self.oversold:
return StrategySignal(
signal_type=SignalType.BUY,
symbol=symbol,
strength=1.0 - (current_rsi / self.oversold),
price=price,
metadata={'rsi': float(current_rsi)}
)
elif current_rsi > self.overbought:
return StrategySignal(
signal_type=SignalType.SELL,
symbol=symbol,
strength=(current_rsi - self.overbought) / (100 - self.overbought),
price=price,
metadata={'rsi': float(current_rsi)}
)
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal."""
return signal if self.should_execute(signal) else None

View File

@@ -0,0 +1,177 @@
"""Volatility Breakout Strategy - Captures explosive moves after consolidation."""
from decimal import Decimal
from typing import Dict, Any, Optional, List
import pandas as pd
import numpy as np
from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.pricing_service import get_pricing_service
from src.data.indicators import get_indicators
from src.core.logger import get_logger
logger = get_logger(__name__)
class VolatilityBreakoutStrategy(BaseStrategy):
"""
Volatility Breakout Strategy that identifies and trades breakouts from consolidation.
Logic:
1. Detect consolidation using Bollinger Band Width (squeeze).
2. Confirm breakout when price exits the bands with volume confirmation.
3. Use ADX to ensure the trend is strong enough to follow.
Entry Conditions (BUY):
- Bollinger Band Width < Squeeze Threshold (consolidation detected)
- Price breaks above Upper Bollinger Band
- Volume > 20-day average volume (confirmation)
- ADX > 25 (strong trend)
Entry Conditions (SELL):
- Price breaks below Lower Bollinger Band
- Volume > 20-day average volume
- ADX > 25 (strong trend)
"""
def __init__(self, name: str, parameters: Dict[str, Any], timeframes: List[str] = None):
super().__init__(name, parameters, timeframes)
# Strategy parameters
self.bb_period = int(parameters.get('bb_period', 20))
self.bb_std_dev = float(parameters.get('bb_std_dev', 2.0))
self.squeeze_threshold = float(parameters.get('squeeze_threshold', 0.1))
self.volume_multiplier = float(parameters.get('volume_multiplier', 1.5))
self.adx_period = int(parameters.get('adx_period', 14))
self.min_adx = float(parameters.get('min_adx', 25.0))
self.use_volume_filter = parameters.get('use_volume_filter', True)
self.pricing_service = get_pricing_service()
self.indicators = get_indicators()
# Track squeeze state
self._in_squeeze = False
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
"""Check for volatility breakout signals."""
if not self.enabled:
return None
try:
# Fetch OHLCV data
ohlcv = self.pricing_service.get_ohlcv(
symbol=symbol,
timeframe=timeframe,
limit=max(self.bb_period, self.adx_period) + 30
)
if not ohlcv or len(ohlcv) < self.bb_period + 10:
return None
# Convert to DataFrame
df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
# Calculate Bollinger Bands
close = df['close'].astype(float)
high = df['high'].astype(float)
low = df['low'].astype(float)
volume = df['volume'].astype(float)
bb_upper, bb_middle, bb_lower = self.indicators.bollinger_bands(
close, period=self.bb_period, std_dev=self.bb_std_dev
)
# Calculate Bollinger Band Width (squeeze indicator)
bb_width = (bb_upper - bb_lower) / bb_middle
current_width = bb_width.iloc[-1]
# Calculate ADX for trend strength
adx = self.indicators.adx(high, low, close, period=self.adx_period)
current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0
# Calculate volume metrics
avg_volume = volume.rolling(window=20).mean().iloc[-1]
current_volume = volume.iloc[-1]
volume_ratio = current_volume / avg_volume if avg_volume > 0 else 0
current_price = float(price)
upper_band = bb_upper.iloc[-1]
lower_band = bb_lower.iloc[-1]
# Check for squeeze (consolidation)
is_squeeze = current_width < self.squeeze_threshold
was_in_squeeze = self._in_squeeze
self._in_squeeze = is_squeeze
self.logger.debug(
f"Volatility Breakout {symbol}: Width={current_width:.4f}, "
f"ADX={current_adx:.1f}, Vol Ratio={volume_ratio:.2f}, "
f"Squeeze={is_squeeze}"
)
# No signal if not breaking out of a squeeze
if not was_in_squeeze:
return None
# Volume filter
if self.use_volume_filter and volume_ratio < self.volume_multiplier:
self.logger.debug(f"Volume filter: {volume_ratio:.2f} < {self.volume_multiplier}")
return None
# ADX filter - ensure trend is strong
if current_adx < self.min_adx:
self.logger.debug(f"ADX filter: {current_adx:.1f} < {self.min_adx}")
return None
# Check for breakout
signal_type = None
# Bullish breakout - price breaks above upper band
if current_price > upper_band:
signal_type = SignalType.BUY
self.logger.info(
f"BULLISH BREAKOUT: {symbol} @ {current_price:.2f} > Upper Band {upper_band:.2f}"
)
# Bearish breakout - price breaks below lower band
elif current_price < lower_band:
signal_type = SignalType.SELL
self.logger.info(
f"BEARISH BREAKOUT: {symbol} @ {current_price:.2f} < Lower Band {lower_band:.2f}"
)
if signal_type:
# Calculate signal strength based on multiple factors
strength = min(1.0, (
(current_adx / 50) * 0.4 + # ADX contribution
(volume_ratio / 3) * 0.4 + # Volume contribution
0.2 # Base strength for breakout
))
signal = StrategySignal(
signal_type=signal_type,
symbol=symbol,
strength=strength,
price=price,
metadata={
"strategy": "volatility_breakout",
"bb_width": float(current_width),
"adx": float(current_adx),
"volume_ratio": float(volume_ratio),
"upper_band": float(upper_band),
"lower_band": float(lower_band),
"was_squeeze": was_in_squeeze,
"description": f"Volatility Breakout: ADX={current_adx:.1f}, Vol={volume_ratio:.1f}x"
}
)
return signal
except Exception as e:
self.logger.error(f"Error in VolatilityBreakoutStrategy: {e}")
return None
return None
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
"""Process signal (pass-through)."""
return signal

View File

@@ -0,0 +1,103 @@
"""Multi-timeframe strategy framework with synchronization."""
from typing import Dict, List, Optional, Any
from datetime import datetime
from decimal import Decimal
from src.core.logger import get_logger
logger = get_logger(__name__)
class TimeframeManager:
"""Manages multiple timeframes for strategies."""
def __init__(self, timeframes: List[str]):
"""Initialize timeframe manager.
Args:
timeframes: List of timeframes (e.g., ['1h', '15m'])
"""
self.timeframes = sorted(timeframes, key=self._timeframe_to_seconds, reverse=True)
self.data: Dict[str, Dict[str, Any]] = {tf: {} for tf in self.timeframes}
self.last_update: Dict[str, datetime] = {tf: datetime.min for tf in self.timeframes}
def _timeframe_to_seconds(self, tf: str) -> int:
"""Convert timeframe to seconds for sorting.
Args:
tf: Timeframe string (e.g., '1h', '15m')
Returns:
Seconds
"""
if tf.endswith('m'):
return int(tf[:-1]) * 60
elif tf.endswith('h'):
return int(tf[:-1]) * 3600
elif tf.endswith('d'):
return int(tf[:-1]) * 86400
return 0
def update(self, timeframe: str, symbol: str, data: Dict[str, Any]):
"""Update data for a timeframe.
Args:
timeframe: Timeframe
symbol: Trading symbol
data: Market data
"""
if timeframe not in self.timeframes:
logger.warning(f"Unknown timeframe: {timeframe}")
return
if symbol not in self.data[timeframe]:
self.data[timeframe][symbol] = {}
self.data[timeframe][symbol].update(data)
self.last_update[timeframe] = datetime.utcnow()
def get_data(self, timeframe: str, symbol: str) -> Optional[Dict[str, Any]]:
"""Get data for a timeframe.
Args:
timeframe: Timeframe
symbol: Trading symbol
Returns:
Data dictionary or None
"""
return self.data.get(timeframe, {}).get(symbol)
def get_all_timeframes(self, symbol: str) -> Dict[str, Dict[str, Any]]:
"""Get data for all timeframes.
Args:
symbol: Trading symbol
Returns:
Dictionary of timeframe -> data
"""
return {
tf: self.data[tf].get(symbol, {})
for tf in self.timeframes
}
def is_synchronized(self, symbol: str, max_age_seconds: int = 300) -> bool:
"""Check if all timeframes are synchronized (recently updated).
Args:
symbol: Trading symbol
max_age_seconds: Maximum age in seconds
Returns:
True if all timeframes are synchronized
"""
now = datetime.utcnow()
for tf in self.timeframes:
if symbol not in self.data[tf]:
return False
age = (now - self.last_update[tf]).total_seconds()
if age > max_age_seconds:
return False
return True

0
src/trading/__init__.py Normal file
View File

View File

@@ -0,0 +1,409 @@
"""Advanced order types: take-profit, trailing stop, OCO, iceberg orders."""
from decimal import Decimal
from typing import Optional, Dict, Any, List
from datetime import datetime
from src.core.database import Order, OrderType, OrderSide, OrderStatus
from src.core.logger import get_logger
from .order_manager import get_order_manager
logger = get_logger(__name__)
class TakeProfitOrder:
"""Take-profit order - automatically sells when price reaches target."""
def __init__(
self,
base_order_id: int,
target_price: Decimal,
quantity: Optional[Decimal] = None
):
"""Initialize take-profit order.
Args:
base_order_id: ID of the base position/order
target_price: Target price to trigger take-profit
quantity: Quantity to sell (None = all)
"""
self.base_order_id = base_order_id
self.target_price = target_price
self.quantity = quantity
self.triggered = False
def check_trigger(self, current_price: Decimal) -> bool:
"""Check if take-profit should trigger.
Args:
current_price: Current market price
Returns:
True if should trigger
"""
if self.triggered:
return False
# Trigger if price reaches or exceeds target
if current_price >= self.target_price:
self.triggered = True
return True
return False
class TrailingStopOrder:
"""Trailing stop-loss order - adjusts stop price as price moves favorably."""
def __init__(
self,
base_order_id: int,
initial_stop_price: Decimal,
trail_percent: Decimal,
quantity: Optional[Decimal] = None
):
"""Initialize trailing stop order.
Args:
base_order_id: ID of the base position/order
initial_stop_price: Initial stop price
trail_percent: Percentage to trail (e.g., 0.02 for 2%)
quantity: Quantity to sell (None = all)
"""
self.base_order_id = base_order_id
self.current_stop_price = initial_stop_price
self.trail_percent = trail_percent
self.quantity = quantity
self.triggered = False
self.highest_price = initial_stop_price # For long positions
def update(self, current_price: Decimal, is_long: bool = True):
"""Update trailing stop based on current price.
Args:
current_price: Current market price
is_long: True for long position, False for short
"""
if self.triggered:
return
if is_long:
# For long positions, trail upward
if current_price > self.highest_price:
self.highest_price = current_price
# Adjust stop to trail below highest price
self.current_stop_price = current_price * (1 - self.trail_percent)
else:
# For short positions, trail downward
if current_price < self.highest_price or self.highest_price == self.current_stop_price:
self.highest_price = current_price
# Adjust stop to trail above lowest price
self.current_stop_price = current_price * (1 + self.trail_percent)
def check_trigger(self, current_price: Decimal, is_long: bool = True) -> bool:
"""Check if trailing stop should trigger.
Args:
current_price: Current market price
is_long: True for long position, False for short
Returns:
True if should trigger
"""
if self.triggered:
return False
if is_long:
# Trigger if price falls below stop
if current_price <= self.current_stop_price:
self.triggered = True
return True
else:
# Trigger if price rises above stop
if current_price >= self.current_stop_price:
self.triggered = True
return True
return False
class OCOOrder:
"""One-Cancels-Other order - two orders where one cancels the other."""
def __init__(
self,
order1_id: int,
order2_id: int
):
"""Initialize OCO order.
Args:
order1_id: First order ID
order2_id: Second order ID
"""
self.order1_id = order1_id
self.order2_id = order2_id
self.executed = False
async def on_order_filled(self, filled_order_id: int):
"""Handle when one order is filled.
Args:
filled_order_id: ID of the filled order
"""
if self.executed:
return
order_manager = get_order_manager()
# Cancel the other order
if filled_order_id == self.order1_id:
await order_manager.cancel_order(self.order2_id)
elif filled_order_id == self.order2_id:
await order_manager.cancel_order(self.order1_id)
self.executed = True
logger.info(f"OCO order executed: order {filled_order_id} filled, other cancelled")
class IcebergOrder:
"""Iceberg order - large order split into smaller visible parts."""
def __init__(
self,
total_quantity: Decimal,
visible_quantity: Decimal,
symbol: str,
side: OrderSide,
price: Optional[Decimal] = None
):
"""Initialize iceberg order.
Args:
total_quantity: Total quantity to execute
visible_quantity: Visible quantity per order
symbol: Trading symbol
side: Buy or sell
price: Limit price (None for market)
"""
self.total_quantity = total_quantity
self.visible_quantity = visible_quantity
self.symbol = symbol
self.side = side
self.price = price
self.remaining_quantity = total_quantity
self.orders: List[int] = []
self.completed = False
async def create_next_order(self, exchange_id: int) -> Optional[int]:
"""Create next visible order.
Args:
exchange_id: Exchange ID
Returns:
Order ID or None if completed
"""
if self.completed or self.remaining_quantity <= 0:
return None
order_manager = get_order_manager()
order_type = OrderType.LIMIT if self.price else OrderType.MARKET
quantity = min(self.visible_quantity, self.remaining_quantity)
order = await order_manager.create_order(
exchange_id=exchange_id,
strategy_id=None,
symbol=self.symbol,
order_type=order_type,
side=self.side,
quantity=quantity,
price=self.price
)
self.orders.append(order.id)
self.remaining_quantity -= quantity
if self.remaining_quantity <= 0:
self.completed = True
return order.id
async def on_order_filled(self, order_id: int):
"""Handle when an order is filled.
Args:
order_id: Filled order ID
"""
if order_id in self.orders:
# Create next order if more quantity remains
if not self.completed and self.remaining_quantity > 0:
# This would need exchange_id - would need to store it
logger.info(f"Iceberg order {order_id} filled, {self.remaining_quantity} remaining")
class AdvancedOrderManager:
"""Manages advanced order types."""
def __init__(self):
"""Initialize advanced order manager."""
self.take_profit_orders: Dict[int, TakeProfitOrder] = {}
self.trailing_stops: Dict[int, TrailingStopOrder] = {}
self.oco_orders: Dict[int, OCOOrder] = {}
self.iceberg_orders: Dict[int, IcebergOrder] = {}
def create_take_profit(
self,
base_order_id: int,
target_price: Decimal,
quantity: Optional[Decimal] = None
) -> TakeProfitOrder:
"""Create a take-profit order.
Args:
base_order_id: Base order/position ID
target_price: Target price
quantity: Quantity (None = all)
Returns:
TakeProfitOrder instance
"""
tp_order = TakeProfitOrder(base_order_id, target_price, quantity)
self.take_profit_orders[base_order_id] = tp_order
logger.info(f"Created take-profit for order {base_order_id} at {target_price}")
return tp_order
def create_trailing_stop(
self,
base_order_id: int,
initial_stop_price: Decimal,
trail_percent: Decimal,
quantity: Optional[Decimal] = None
) -> TrailingStopOrder:
"""Create a trailing stop order.
Args:
base_order_id: Base order/position ID
initial_stop_price: Initial stop price
trail_percent: Trail percentage
quantity: Quantity (None = all)
Returns:
TrailingStopOrder instance
"""
trailing = TrailingStopOrder(base_order_id, initial_stop_price, trail_percent, quantity)
self.trailing_stops[base_order_id] = trailing
logger.info(f"Created trailing stop for order {base_order_id}")
return trailing
def create_oco(
self,
order1_id: int,
order2_id: int
) -> OCOOrder:
"""Create an OCO order.
Args:
order1_id: First order ID
order2_id: Second order ID
Returns:
OCOOrder instance
"""
oco = OCOOrder(order1_id, order2_id)
self.oco_orders[order1_id] = oco
self.oco_orders[order2_id] = oco
logger.info(f"Created OCO order: {order1_id} <-> {order2_id}")
return oco
def create_iceberg(
self,
total_quantity: Decimal,
visible_quantity: Decimal,
symbol: str,
side: OrderSide,
price: Optional[Decimal] = None
) -> IcebergOrder:
"""Create an iceberg order.
Args:
total_quantity: Total quantity
visible_quantity: Visible quantity per order
symbol: Trading symbol
side: Buy or sell
price: Limit price (None for market)
Returns:
IcebergOrder instance
"""
iceberg = IcebergOrder(total_quantity, visible_quantity, symbol, side, price)
# Store by a unique ID (could use a counter or hash)
iceberg_id = id(iceberg)
self.iceberg_orders[iceberg_id] = iceberg
logger.info(f"Created iceberg order: {total_quantity} {symbol} ({visible_quantity} visible)")
return iceberg
def update_trailing_stops(self, prices: Dict[int, Decimal], is_long: Dict[int, bool]):
"""Update all trailing stops with current prices.
Args:
prices: Dictionary of order_id -> current_price
is_long: Dictionary of order_id -> is_long_position
"""
for order_id, trailing in self.trailing_stops.items():
if order_id in prices:
trailing.update(prices[order_id], is_long.get(order_id, True))
def check_triggers(self, prices: Dict[int, Decimal], is_long: Dict[int, bool]) -> List[int]:
"""Check all triggers and return triggered order IDs.
Args:
prices: Dictionary of order_id -> current_price
is_long: Dictionary of order_id -> is_long_position
Returns:
List of triggered order IDs
"""
triggered = []
# Check take-profit orders
for order_id, tp in self.take_profit_orders.items():
if order_id in prices and tp.check_trigger(prices[order_id]):
triggered.append(order_id)
# Check trailing stops
for order_id, trailing in self.trailing_stops.items():
if order_id in prices and trailing.check_trigger(
prices[order_id],
is_long.get(order_id, True)
):
triggered.append(order_id)
return triggered
async def on_order_filled(self, order_id: int):
"""Handle order fill for advanced orders.
Args:
order_id: Filled order ID
"""
# Check OCO orders
if order_id in self.oco_orders:
await self.oco_orders[order_id].on_order_filled(order_id)
# Check iceberg orders
for iceberg_id, iceberg in self.iceberg_orders.items():
if order_id in iceberg.orders:
await iceberg.on_order_filled(order_id)
# Global advanced order manager
_advanced_order_manager: Optional[AdvancedOrderManager] = None
def get_advanced_order_manager() -> AdvancedOrderManager:
"""Get global advanced order manager instance."""
global _advanced_order_manager
if _advanced_order_manager is None:
_advanced_order_manager = AdvancedOrderManager()
return _advanced_order_manager

245
src/trading/engine.py Normal file
View File

@@ -0,0 +1,245 @@
"""Main trading engine with order execution and position management."""
from decimal import Decimal
from typing import Optional, Dict, Any, List
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.database import Order, OrderStatus, OrderSide, OrderType, get_database
from src.core.logger import get_logger
from src.core.repositories import OrderRepository
from src.exchanges import get_exchange
from .order_manager import get_order_manager
from .paper_trading import get_paper_trading
from .fee_calculator import get_fee_calculator
from src.risk.manager import get_risk_manager
logger = get_logger(__name__)
class TradingEngine:
"""Main trading engine orchestrator."""
def __init__(self):
"""Initialize trading engine."""
self.db = get_database()
self.order_manager = get_order_manager()
self.paper_trading = get_paper_trading()
self.risk_manager = get_risk_manager()
self.logger = get_logger(__name__)
self._exchanges: Dict[int, Any] = {} # exchange_id -> adapter
async def get_exchange_adapter(self, exchange_id: int):
"""Get or create exchange adapter.
Args:
exchange_id: Exchange ID
Returns:
Exchange adapter or None
"""
if exchange_id not in self._exchanges:
# We assume get_exchange is a factory function that might NOT be async itself,
# but returns an adapter that IS async.
# If get_exchange does I/O, it should be awaited.
# For now, let's assume it just instantiates the class.
adapter = await get_exchange(exchange_id)
if adapter:
self._exchanges[exchange_id] = adapter
if not adapter._connected:
await adapter.connect()
return self._exchanges.get(exchange_id)
async def execute_order(
self,
exchange_id: int,
strategy_id: Optional[int],
symbol: str,
side: OrderSide,
order_type: OrderType,
quantity: Decimal,
price: Optional[Decimal] = None,
paper_trading: bool = True
) -> Optional[Order]:
"""Execute a trading order."""
try:
# Get exchange adapter
adapter = await self.get_exchange_adapter(exchange_id)
if not adapter:
self.logger.error(f"Exchange {exchange_id} not available")
return None
# Get current balance/price for risk checks
balance = self.paper_trading.get_balance() if paper_trading else Decimal(0)
ticker = await adapter.get_ticker(symbol)
current_price = ticker.get('last', price or Decimal(0))
# Risk checks (includes fee validation)
allowed, reason = await self.risk_manager.check_order_risk(
symbol, side.value, quantity, current_price, balance, adapter
)
if not allowed:
self.logger.warning(f"Order rejected: {reason}")
return None
# Create order in database using generic repository pattern would be better
# but for now we create the object and save it using our own session management
# or rely on the order_manager which we need to check if it's async-ready.
# For this refactor, let's use the OrderRepository directly for creation
# to be safe with sessions.
async with self.db.get_session() as session:
repo = OrderRepository(session)
order = Order(
exchange_id=exchange_id,
strategy_id=strategy_id,
symbol=symbol,
order_type=order_type,
side=side,
quantity=quantity,
price=price,
paper_trading=paper_trading
)
order = await repo.create(order)
if paper_trading:
# Execute in paper trading
fill_price = current_price # Use current price for market orders
if order_type == OrderType.LIMIT and price:
fill_price = price
is_maker = (order_type == OrderType.LIMIT)
# Use paper trading fee calculation (uses configured fee_exchange)
fee_calculator = get_fee_calculator()
fee = fee_calculator.calculate_fee_for_paper_trading(
quantity=quantity,
price=fill_price,
order_type=order_type,
is_maker=is_maker
)
if await self.paper_trading.execute_order(order, fill_price, fee):
self.logger.info(f"Paper trading order {order.id} executed with fee: {fee}")
return order
else:
self.logger.warning(f"Paper trading order {order.id} rejected (insufficient funds or no position)")
async with self.db.get_session() as session:
repo = OrderRepository(session)
await repo.update_status(order.id, OrderStatus.REJECTED)
return None
else:
# Execute live order
from src.exchanges.base import Order as ExchangeOrder
exchange_order = ExchangeOrder(
symbol=symbol,
side=side.value,
order_type=order_type.value,
quantity=quantity,
price=price
)
result = await adapter.place_order(exchange_order)
if result.get('id'):
actual_fee = adapter.extract_fee_from_order_response(result)
if actual_fee is None:
fee_calculator = get_fee_calculator()
is_maker = (order_type == OrderType.LIMIT)
estimated_price = price or current_price
actual_fee = fee_calculator.calculate_fee(
quantity=quantity,
price=estimated_price,
order_type=order_type,
exchange_adapter=adapter,
is_maker=is_maker
)
async with self.db.get_session() as session:
repo = OrderRepository(session)
await repo.update_status(
order.id,
OrderStatus.OPEN,
exchange_order_id=result['id'],
fee=actual_fee
)
return order
else:
async with self.db.get_session() as session:
repo = OrderRepository(session)
await repo.update_status(order.id, OrderStatus.REJECTED)
return None
except Exception as e:
self.logger.error(f"Failed to execute order: {e}")
return None
async def cancel_order(self, order_id: int) -> bool:
"""Cancel an order."""
async with self.db.get_session() as session:
repo = OrderRepository(session)
order = await repo.get_by_id(order_id)
if not order:
return False
# If already in final state, cannot cancel
if order.status in [OrderStatus.FILLED, OrderStatus.CANCELLED, OrderStatus.REJECTED, OrderStatus.EXPIRED]:
self.logger.warning(f"Order {order_id} is already in state {order.status}, cannot cancel")
return False
# If paper trading, just update status in DB
if order.paper_trading:
await repo.update_status(order_id, OrderStatus.CANCELLED)
self.logger.info(f"Paper trading order {order_id} cancelled")
return True
# If live trading, cancel on exchange
adapter = await self.get_exchange_adapter(order.exchange_id)
if not adapter:
self.logger.error(f"Exchange adapter not found for exchange_id {order.exchange_id}")
return False
try:
if await adapter.cancel_order(order.exchange_order_id or str(order.id), order.symbol):
# Update status in DB
await repo.update_status(order_id, OrderStatus.CANCELLED)
self.logger.info(f"Live order {order_id} cancelled on exchange")
return True
else:
self.logger.error(f"Failed to cancel order {order_id} on exchange")
except Exception as e:
self.logger.error(f"Error cancelling order {order_id} on exchange: {e}")
return False
async def get_positions(self, exchange_id: Optional[int] = None) -> List[Dict[str, Any]]:
"""Get open positions."""
if exchange_id:
adapter = await self.get_exchange_adapter(exchange_id)
if adapter:
return await adapter.get_positions()
# Get from paper trading
positions = self.paper_trading.get_positions()
return [
{
'symbol': pos.symbol,
'side': pos.side,
'quantity': pos.quantity,
'entry_price': pos.entry_price,
'current_price': pos.current_price,
'unrealized_pnl': pos.unrealized_pnl,
}
for pos in positions
]
# Global trading engine
_trading_engine: Optional[TradingEngine] = None
def get_trading_engine() -> TradingEngine:
"""Get global trading engine instance."""
global _trading_engine
if _trading_engine is None:
_trading_engine = TradingEngine()
return _trading_engine

View File

@@ -0,0 +1,278 @@
"""Centralized fee calculation service for trading operations."""
from decimal import Decimal
from typing import Dict, Optional, Any
from src.core.logger import get_logger
from src.core.config import get_config
from src.exchanges.base import BaseExchangeAdapter
from src.core.database import OrderType, OrderSide
logger = get_logger(__name__)
class FeeCalculator:
"""Centralized fee calculation service."""
def __init__(self):
"""Initialize fee calculator."""
self.config = get_config()
self.logger = get_logger(__name__)
def calculate_fee(
self,
quantity: Decimal,
price: Decimal,
order_type: OrderType,
exchange_adapter: Optional[BaseExchangeAdapter] = None,
is_maker: Optional[bool] = None
) -> Decimal:
"""Calculate trading fee for an order.
Args:
quantity: Trade quantity
price: Trade price
order_type: Order type (MARKET or LIMIT)
exchange_adapter: Exchange adapter for fee structure (optional)
is_maker: Explicit maker/taker flag (if None, determined from order_type)
Returns:
Trading fee amount
"""
if quantity <= 0 or price <= 0:
return Decimal(0)
# Determine maker/taker if not explicitly provided
if is_maker is None:
is_maker = self._is_maker_order(order_type)
# Get fee structure
fee_structure = self._get_fee_structure(exchange_adapter)
fee_rate = fee_structure['maker'] if is_maker else fee_structure['taker']
# Calculate fee
trade_value = quantity * price
fee = trade_value * Decimal(str(fee_rate))
# Apply minimum fee if configured
min_fee = fee_structure.get('minimum', Decimal(0))
if min_fee > 0 and fee < min_fee:
fee = min_fee
return fee
def estimate_round_trip_fee(
self,
quantity: Decimal,
price: Decimal,
exchange_adapter: Optional[BaseExchangeAdapter] = None
) -> Decimal:
"""Estimate total fees for a round-trip trade (buy + sell).
Args:
quantity: Trade quantity
price: Trade price
exchange_adapter: Exchange adapter for fee structure (optional)
Returns:
Total estimated round-trip fee
"""
# Estimate using taker fees (worst case)
buy_fee = self.calculate_fee(quantity, price, OrderType.MARKET, exchange_adapter, is_maker=False)
sell_fee = self.calculate_fee(quantity, price, OrderType.MARKET, exchange_adapter, is_maker=False)
return buy_fee + sell_fee
def get_minimum_profit_threshold(
self,
quantity: Decimal,
price: Decimal,
exchange_adapter: Optional[BaseExchangeAdapter] = None,
multiplier: float = 2.0
) -> Decimal:
"""Calculate minimum profit threshold needed to break even after fees.
Args:
quantity: Trade quantity
price: Trade price
exchange_adapter: Exchange adapter for fee structure (optional)
multiplier: Multiplier for minimum profit (default 2.0 = 2x fees)
Returns:
Minimum profit threshold
"""
round_trip_fee = self.estimate_round_trip_fee(quantity, price, exchange_adapter)
return round_trip_fee * Decimal(str(multiplier))
def calculate_fee_reserve(
self,
position_value: Decimal,
exchange_adapter: Optional[BaseExchangeAdapter] = None,
reserve_percent: Optional[float] = None
) -> Decimal:
"""Calculate fee reserve amount for position sizing.
Args:
position_value: Intended position value
exchange_adapter: Exchange adapter for fee structure (optional)
reserve_percent: Override reserve percentage (default: 0.4% for round-trip)
Returns:
Fee reserve amount
"""
if reserve_percent is None:
# Default: 0.4% for round-trip (conservative estimate)
reserve_percent = 0.004
return position_value * Decimal(str(reserve_percent))
def _is_maker_order(self, order_type: OrderType) -> bool:
"""Determine if order type is typically a maker order.
Args:
order_type: Order type
Returns:
True if maker order, False if taker order
"""
# Limit orders that add liquidity = maker
# Market orders that take liquidity = taker
return order_type == OrderType.LIMIT
def _get_fee_structure(
self,
exchange_adapter: Optional[BaseExchangeAdapter] = None
) -> Dict[str, Any]:
"""Get fee structure from exchange or config defaults.
Args:
exchange_adapter: Exchange adapter (optional)
Returns:
Fee structure dictionary with 'maker', 'taker', and optional 'minimum'
"""
# Try to get from exchange adapter first
if exchange_adapter:
try:
fee_structure = exchange_adapter.get_fee_structure()
if fee_structure:
return fee_structure
except Exception as e:
self.logger.warning(f"Failed to get fee structure from exchange: {e}")
# Get from config with exchange-specific overrides
default_fees = self.config.get("trading.default_fees", {
"maker": 0.001, # 0.1%
"taker": 0.001, # 0.1%
"minimum": 0.0
})
# Check for exchange-specific fees if adapter provided
if exchange_adapter:
exchange_name = exchange_adapter.name.lower()
exchange_fees = self.config.get(f"trading.exchanges.{exchange_name}.fees")
if exchange_fees:
default_fees.update(exchange_fees)
return default_fees
def get_fee_structure_by_exchange_name(
self,
exchange_name: str
) -> Dict[str, Any]:
"""Get fee structure for a specific exchange by name (for paper trading).
Args:
exchange_name: Exchange name (e.g., 'coinbase', 'kraken', 'binance')
Returns:
Fee structure dictionary with 'maker', 'taker', and optional 'minimum'
"""
# Get default fees
default_fees = self.config.get("trading.default_fees", {
"maker": 0.001, # 0.1%
"taker": 0.001, # 0.1%
"minimum": 0.0
})
# Check for exchange-specific fees
exchange_fees = self.config.get(f"trading.exchanges.{exchange_name.lower()}.fees")
if exchange_fees:
return {**default_fees, **exchange_fees}
return default_fees
def calculate_fee_for_paper_trading(
self,
quantity: Decimal,
price: Decimal,
order_type: OrderType,
is_maker: Optional[bool] = None
) -> Decimal:
"""Calculate trading fee for paper trading using configured exchange.
Args:
quantity: Trade quantity
price: Trade price
order_type: Order type (MARKET or LIMIT)
is_maker: Explicit maker/taker flag (if None, determined from order_type)
Returns:
Trading fee amount
"""
if quantity <= 0 or price <= 0:
return Decimal(0)
# Determine maker/taker if not explicitly provided
if is_maker is None:
is_maker = self._is_maker_order(order_type)
# Get fee exchange from config
fee_exchange = self.config.get("paper_trading.fee_exchange", "coinbase")
fee_structure = self.get_fee_structure_by_exchange_name(fee_exchange)
fee_rate = fee_structure['maker'] if is_maker else fee_structure['taker']
# Calculate fee
trade_value = quantity * price
fee = trade_value * Decimal(str(fee_rate))
# Apply minimum fee if configured
min_fee = fee_structure.get('minimum', Decimal(0))
if min_fee > 0 and fee < min_fee:
fee = min_fee
return fee
def get_fee_percentage(
self,
order_type: OrderType,
exchange_adapter: Optional[BaseExchangeAdapter] = None
) -> float:
"""Get fee percentage for order type.
Args:
order_type: Order type
exchange_adapter: Exchange adapter (optional)
Returns:
Fee percentage (e.g., 0.001 for 0.1%)
"""
is_maker = self._is_maker_order(order_type)
fee_structure = self._get_fee_structure(exchange_adapter)
return fee_structure['maker'] if is_maker else fee_structure['taker']
# Global fee calculator instance
_fee_calculator: Optional[FeeCalculator] = None
def get_fee_calculator() -> FeeCalculator:
"""Get global fee calculator instance.
Returns:
FeeCalculator instance
"""
global _fee_calculator
if _fee_calculator is None:
_fee_calculator = FeeCalculator()
return _fee_calculator

122
src/trading/futures.py Normal file
View File

@@ -0,0 +1,122 @@
"""Futures and leverage trading support with margin calculations."""
from decimal import Decimal
from typing import Dict, Optional
from src.core.logger import get_logger
logger = get_logger(__name__)
class FuturesManager:
"""Manages futures and leverage trading."""
def __init__(self):
"""Initialize futures manager."""
self.logger = get_logger(__name__)
def calculate_margin(
self,
quantity: Decimal,
price: Decimal,
leverage: int,
margin_type: str = "isolated"
) -> Decimal:
"""Calculate required margin.
Args:
quantity: Position quantity
price: Entry price
leverage: Leverage multiplier
margin_type: "isolated" or "cross"
Returns:
Required margin
"""
position_value = quantity * price
margin = position_value / Decimal(leverage)
return margin
def calculate_liquidation_price(
self,
entry_price: Decimal,
leverage: int,
side: str, # "long" or "short"
maintenance_margin: Decimal = Decimal("0.01") # 1%
) -> Decimal:
"""Calculate liquidation price.
Args:
entry_price: Entry price
leverage: Leverage multiplier
side: Position side
maintenance_margin: Maintenance margin rate
Returns:
Liquidation price
"""
if side == "long":
# For long: liquidation when price drops too much
liquidation = entry_price * (1 - (1 / leverage) + maintenance_margin)
else:
# For short: liquidation when price rises too much
liquidation = entry_price * (1 + (1 / leverage) - maintenance_margin)
return liquidation
def calculate_funding_rate(
self,
mark_price: Decimal,
index_price: Decimal
) -> Decimal:
"""Calculate funding rate for perpetual futures.
Args:
mark_price: Mark price
index_price: Index price
Returns:
Funding rate (8-hour rate)
"""
premium = (mark_price - index_price) / index_price
funding_rate = premium * Decimal("0.01") # Simplified
return funding_rate
def calculate_unrealized_pnl(
self,
entry_price: Decimal,
current_price: Decimal,
quantity: Decimal,
side: str,
leverage: int = 1
) -> Decimal:
"""Calculate unrealized P&L for futures position.
Args:
entry_price: Entry price
current_price: Current mark price
quantity: Position quantity
side: "long" or "short"
leverage: Leverage multiplier
Returns:
Unrealized P&L
"""
if side == "long":
pnl = (current_price - entry_price) * quantity * leverage
else:
pnl = (entry_price - current_price) * quantity * leverage
return pnl
# Global futures manager
_futures_manager: Optional[FuturesManager] = None
def get_futures_manager() -> FuturesManager:
"""Get global futures manager instance."""
global _futures_manager
if _futures_manager is None:
_futures_manager = FuturesManager()
return _futures_manager

Some files were not shown because too many files have changed in this diff Show More