Local changes: updated model training, removed debug instrumentation, and improved configuration handling
src/__init__.py (new file, 0 lines)
src/alerts/__init__.py (new file, 0 lines)
src/alerts/channels.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Alert delivery channels."""

from typing import Optional
from src.ui.utils.notifications import get_notification_manager
from src.core.logger import get_logger

logger = get_logger(__name__)


class AlertChannel:
    """Manages alert delivery channels."""

    def __init__(self):
        """Initialize alert channel."""
        self.notifications = get_notification_manager()
        self.logger = get_logger(__name__)

    def send(self, alert_type: str, message: str):
        """Send alert through all channels.

        Args:
            alert_type: Alert type
            message: Alert message
        """
        self.notifications.notify_alert(alert_type, message)
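
A minimal usage sketch for the channel above (not part of the diff); it assumes get_notification_manager() returns an object exposing notify_alert(alert_type, message), which is exactly the call send() delegates to:

    # Sketch only: assumes the src.* packages are importable and a
    # notification manager is configured in src.ui.utils.notifications.
    from src.alerts.channels import AlertChannel

    channel = AlertChannel()
    channel.send("price", "BTC/USD crossed 50,000")  # hypothetical type/message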
src/alerts/engine.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""Alert rule engine."""

from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy.orm import Session
from src.core.database import get_database, Alert
from src.core.logger import get_logger
from .channels import AlertChannel

logger = get_logger(__name__)


class AlertEngine:
    """Alert rule engine."""

    def __init__(self):
        """Initialize alert engine."""
        self.db = get_database()
        self.logger = get_logger(__name__)
        self.channel = AlertChannel()
        self._active_alerts: Dict[int, Dict] = {}

    def evaluate_price_alert(
        self,
        alert_id: int,
        symbol: str,
        current_price: Decimal
    ) -> bool:
        """Evaluate price alert.

        Args:
            alert_id: Alert ID
            symbol: Trading symbol
            current_price: Current price

        Returns:
            True if alert should trigger
        """
        session = self.db.get_session()
        try:
            alert = session.query(Alert).filter_by(id=alert_id).first()
            if not alert or not alert.enabled:
                return False

            condition = alert.condition
            alert_type = condition.get('type')
            threshold = Decimal(str(condition.get('threshold', 0)))
            operator = condition.get('operator', '>')  # >, <, >=, <=

            if alert_type == 'price_above':
                return current_price > threshold
            elif alert_type == 'price_below':
                return current_price < threshold
            elif alert_type == 'price_change':
                # Would need previous price
                return False

            return False
        finally:
            session.close()

    def evaluate_indicator_alert(
        self,
        alert_id: int,
        indicator_value: float
    ) -> bool:
        """Evaluate indicator alert.

        Args:
            alert_id: Alert ID
            indicator_value: Current indicator value

        Returns:
            True if alert should trigger
        """
        session = self.db.get_session()
        try:
            alert = session.query(Alert).filter_by(id=alert_id).first()
            if not alert or not alert.enabled:
                return False

            condition = alert.condition
            threshold = float(condition.get('threshold', 0))
            operator = condition.get('operator', '>')

            if operator == '>':
                return indicator_value > threshold
            elif operator == '<':
                return indicator_value < threshold
            elif operator == 'crosses_above':
                # Would need previous value
                return False
            elif operator == 'crosses_below':
                # Would need previous value
                return False

            return False
        finally:
            session.close()

    def trigger_alert(self, alert_id: int, message: str):
        """Trigger an alert.

        Args:
            alert_id: Alert ID
            message: Alert message
        """
        session = self.db.get_session()
        try:
            alert = session.query(Alert).filter_by(id=alert_id).first()
            if not alert:
                return

            alert.triggered = True
            alert.triggered_at = datetime.utcnow()
            session.commit()

            # Send notification
            self.channel.send(alert.alert_type, message)
            logger.info(f"Alert {alert_id} triggered: {message}")
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to trigger alert {alert_id}: {e}")
        finally:
            session.close()


# Global alert engine
_alert_engine: Optional[AlertEngine] = None


def get_alert_engine() -> AlertEngine:
    """Get global alert engine instance."""
    global _alert_engine
    if _alert_engine is None:
        _alert_engine = AlertEngine()
    return _alert_engine
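
A sketch of the evaluate-then-trigger flow above, assuming an Alert row with id=1 and a JSON condition {"type": "price_above", "threshold": 50000} already exists in the database (both the id and the condition are hypothetical):

    from decimal import Decimal
    from src.alerts.engine import get_alert_engine

    engine = get_alert_engine()
    # evaluate_price_alert() only reads the alert; trigger_alert() persists
    # triggered/triggered_at and fans out through AlertChannel.
    if engine.evaluate_price_alert(alert_id=1, symbol="BTC/USD",
                                   current_price=Decimal("50500")):
        engine.trigger_alert(1, "BTC/USD rose above 50000")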
src/alerts/manager.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""Alert management."""

from typing import Dict, Optional, List
from sqlalchemy import select
from src.core.database import get_database, Alert
from src.core.logger import get_logger

logger = get_logger(__name__)


class AlertManager:
    """Manages alerts."""

    def __init__(self):
        """Initialize alert manager."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    async def create_alert(
        self,
        name: str,
        alert_type: str,
        condition: Dict
    ) -> Alert:
        """Create a new alert.

        Args:
            name: Alert name
            alert_type: Alert type (price, indicator, risk, system)
            condition: Alert condition configuration

        Returns:
            Alert object
        """
        async with self.db.get_session() as session:
            try:
                alert = Alert(
                    name=name,
                    alert_type=alert_type,
                    condition=condition,
                    enabled=True
                )
                session.add(alert)
                await session.commit()
                await session.refresh(alert)
                return alert
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to create alert: {e}")
                raise

    async def list_alerts(self, enabled_only: bool = False) -> List[Alert]:
        """List all alerts.

        Args:
            enabled_only: Only return enabled alerts

        Returns:
            List of alerts
        """
        async with self.db.get_session() as session:
            stmt = select(Alert)
            if enabled_only:
                stmt = stmt.where(Alert.enabled == True)
            result = await session.execute(stmt)
            return result.scalars().all()


# Global alert manager
_alert_manager: Optional[AlertManager] = None


def get_alert_manager() -> AlertManager:
    """Get global alert manager instance."""
    global _alert_manager
    if _alert_manager is None:
        _alert_manager = AlertManager()
    return _alert_manager
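
Because AlertManager is async (note the `async with self.db.get_session()` pattern, unlike the synchronous sessions in engine.py), callers need an event loop. A minimal sketch, assuming the async session factory shown above; the alert name and condition are hypothetical:

    import asyncio
    from src.alerts.manager import get_alert_manager

    async def main():
        manager = get_alert_manager()
        await manager.create_alert(
            name="BTC breakout",
            alert_type="price",
            condition={"type": "price_above", "threshold": 50000},
        )
        for alert in await manager.list_alerts(enabled_only=True):
            print(alert.name)

    asyncio.run(main())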
src/autopilot/__init__.py (new file, 150 lines)
@@ -0,0 +1,150 @@
"""AutoPilot autonomous trading engine module.

This module provides an autonomous background service that combines
geometric pattern recognition with NLP-based sentiment analysis
to generate trade signals.

Supported Patterns:
- Head and Shoulders (bullish/bearish)
- Double Top/Bottom
- Triple Top/Bottom
- Triangle patterns (Ascending, Descending, Symmetrical)
- Wedge patterns (Rising, Falling)
- Flag and Pennant patterns
- Candlestick patterns (Engulfing, Hammer, Doji, Stars, etc.)
- MA Crossovers (Golden Cross, Death Cross)
- Support/Resistance levels
- Harmonic patterns (ABCD)
- Gap patterns
"""

from typing import Dict, Any

from .intelligent_autopilot import (
    IntelligentAutopilot,
    get_intelligent_autopilot,
    stop_all_autopilots,
)
from .market_analyzer import (
    MarketAnalyzer,
    MarketConditions,
    MarketRegime,
    get_market_analyzer,
)
from .strategy_selector import (
    StrategySelector,
    get_strategy_selector,
)
from .performance_tracker import (
    PerformanceTracker,
    get_performance_tracker,
)


def get_autopilot_mode_info() -> Dict[str, Any]:
    """Get information about available autopilot modes.

    Returns:
        Dictionary with mode information including descriptions, capabilities, and tradeoffs
    """
    return {
        "modes": {
            "pattern": {
                "name": "Pattern-Based Autopilot",
                "description": "Detects technical chart patterns (Head & Shoulders, triangles, etc.) and combines with sentiment analysis",
                "how_it_works": "Rule-based logic - pattern + sentiment alignment = signal",
                "best_for": [
                    "Users who want transparency",
                    "Users who understand technical analysis",
                    "Users who prefer explainable decisions"
                ],
                "tradeoffs": {
                    "pros": [
                        "Transparent and explainable",
                        "No training data required",
                        "Fast and lightweight"
                    ],
                    "cons": [
                        "Less adaptive to market changes",
                        "Fixed decision rules"
                    ]
                },
                "features": [
                    "Pattern recognition (40+ patterns)",
                    "Sentiment analysis (FinBERT)",
                    "Real-time signal generation",
                    "Transparent decision logic"
                ],
                "requirements": {
                    "training_data": False,
                    "setup_complexity": "Low",
                    "resource_usage": "Low"
                }
            },
            "intelligent": {
                "name": "ML-Based Autopilot",
                "description": "Uses machine learning to select the best strategy based on current market conditions",
                "how_it_works": "ML model analyzes market conditions and selects optimal strategy from available strategies",
                "best_for": [
                    "Users who want adaptive, data-driven decisions",
                    "Users who don't need to understand every decision",
                    "Advanced users seeking optimization"
                ],
                "tradeoffs": {
                    "pros": [
                        "Adapts to market conditions",
                        "Learns from historical performance",
                        "Can optimize strategy selection"
                    ],
                    "cons": [
                        "Requires training data",
                        "Less transparent (black box)",
                        "More complex setup"
                    ]
                },
                "features": [
                    "ML-based strategy selection",
                    "Market condition analysis",
                    "Performance tracking",
                    "Auto-execution support"
                ],
                "requirements": {
                    "training_data": True,
                    "setup_complexity": "Medium",
                    "resource_usage": "Medium"
                }
            }
        },
        "comparison": {
            "transparency": {
                "pattern": "High - All decisions are explainable",
                "intelligent": "Low - ML model decisions are less transparent"
            },
            "adaptability": {
                "pattern": "Low - Fixed rules",
                "intelligent": "High - Learns and adapts"
            },
            "setup_time": {
                "pattern": "Immediate - No setup required",
                "intelligent": "Requires training data collection"
            },
            "resource_usage": {
                "pattern": "Low - Lightweight",
                "intelligent": "Medium - ML model overhead"
            }
        }
    }


__all__ = [
    "stop_all_autopilots",
    "IntelligentAutopilot",
    "get_intelligent_autopilot",
    "MarketAnalyzer",
    "MarketConditions",
    "MarketRegime",
    "get_market_analyzer",
    "StrategySelector",
    "get_strategy_selector",
    "PerformanceTracker",
    "get_performance_tracker",
    "get_autopilot_mode_info",
]
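
A short sketch of how a UI or CLI might consume get_autopilot_mode_info() to present the pattern-vs-ML tradeoff described above:

    from src.autopilot import get_autopilot_mode_info

    info = get_autopilot_mode_info()
    for mode_id, mode in info["modes"].items():
        print(f"{mode_id}: {mode['how_it_works']}")
    print(info["comparison"]["transparency"])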
src/autopilot/intelligent_autopilot.py (new file, 717 lines)
@@ -0,0 +1,717 @@
"""Intelligent autopilot orchestrator with ML-based strategy selection."""

import asyncio
import json
from decimal import Decimal
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import pandas as pd

from src.core.logger import get_logger
from src.core.database import get_database, OrderSide, OrderType
from src.core.config import get_config
from src.core.redis import get_redis_client
from .market_analyzer import MarketConditions, get_market_analyzer
from .strategy_selector import get_strategy_selector
from .performance_tracker import get_performance_tracker
from src.strategies.base import get_strategy_registry, BaseStrategy, StrategySignal, SignalType
from src.trading.engine import get_trading_engine
from src.risk.manager import get_risk_manager
from src.data.collector import get_data_collector

logger = get_logger(__name__)

# Strategies that require user configuration (excluded from default autopilot)
# These are only loaded if the user has configured them in the Strategies page
STRATEGIES_REQUIRING_CONFIG = {
    'pairs_trading',   # Requires second_symbol
    'grid',            # Requires grid_spacing, num_levels
    'dca',             # Requires amount, interval
    'market_making',   # Requires spread_percent, max_inventory
}


class IntelligentAutopilot:
    """Intelligent autopilot with ML-based strategy selection and auto-execution."""

    def __init__(
        self,
        symbol: str,
        exchange_id: int = 1,
        timeframe: str = "1h",
        interval: float = 60.0,
        paper_trading: bool = True
    ):
        """Initialize intelligent autopilot.

        Args:
            symbol: Trading symbol (e.g., "BTC/USD")
            exchange_id: Exchange ID
            timeframe: Analysis timeframe
            interval: Analysis cycle interval in seconds
            paper_trading: Paper trading mode
        """
        self.symbol = symbol
        self.exchange_id = exchange_id
        self.timeframe = timeframe
        self.interval = interval
        self.paper_trading = paper_trading

        self.logger = get_logger(__name__)
        self.config = get_config()
        self.db = get_database()
        self.redis = get_redis_client()

        # Core components
        self.market_analyzer = get_market_analyzer()
        self.strategy_selector = get_strategy_selector()
        self.performance_tracker = get_performance_tracker()
        self.strategy_registry = get_strategy_registry()
        self.trading_engine = get_trading_engine()
        self.risk_manager = get_risk_manager()
        self.data_collector = get_data_collector()

        # Configuration
        self.min_confidence_threshold = self.config.get(
            "autopilot.intelligent.min_confidence_threshold",
            0.75
        )
        self.max_trades_per_day = self.config.get(
            "autopilot.intelligent.max_trades_per_day",
            10
        )
        self.min_profit_target = self.config.get(
            "autopilot.intelligent.min_profit_target",
            0.02  # 2% minimum expected profit
        )
        self.enable_auto_execution = self.config.get(
            "autopilot.intelligent.enable_auto_execution",
            True
        )

        # State
        self._local_running = False  # Local flag for the loop
        self._last_analysis: Optional[Dict[str, Any]] = None
        self._selected_strategy: Optional[str] = None
        self._strategy_instances: Dict[str, BaseStrategy] = {}
        self._ohlcv_data: Optional[pd.DataFrame] = None

        # Redis keys
        safe_symbol = symbol.replace("/", "-")
        self.key_running = f"autopilot:running:{safe_symbol}:{timeframe}"
        self.key_trades = f"autopilot:trades:{safe_symbol}"

        # Initialize strategy instances
        self._initialize_strategies()

    @property
    def _trades_today_key(self) -> str:
        """Get Redis key for today's trade count."""
        today = datetime.utcnow().strftime("%Y-%m-%d")
        return f"{self.key_trades}:{today}"

    async def _fetch_market_data(self):
        """Fetch OHLCV data from the database."""
        from sqlalchemy import select
        try:
            async with self.db.get_session() as session:
                from src.core.database import MarketData, Exchange

                # Get exchange name
                stmt = select(Exchange).filter_by(id=self.exchange_id).limit(1)
                result = await session.execute(stmt)
                exchange = result.scalar_one_or_none()
                exchange_name = exchange.name if exchange else "coinbase"

                # Fetch recent OHLCV data (last 200 candles for analysis)
                ohlcv_stmt = select(MarketData).filter_by(
                    exchange=exchange_name,
                    symbol=self.symbol,
                    timeframe=self.timeframe
                ).order_by(MarketData.timestamp.desc()).limit(200)

                ohlcv_result = await session.execute(ohlcv_stmt)
                market_data = ohlcv_result.scalars().all()

                if market_data:
                    # Convert to DataFrame (reversed to chronological order)
                    data = {
                        'timestamp': [md.timestamp for md in reversed(market_data)],
                        'open': [float(md.open) for md in reversed(market_data)],
                        'high': [float(md.high) for md in reversed(market_data)],
                        'low': [float(md.low) for md in reversed(market_data)],
                        'close': [float(md.close) for md in reversed(market_data)],
                        'volume': [float(md.volume) for md in reversed(market_data)],
                    }
                    self._ohlcv_data = pd.DataFrame(data)
                    self.logger.info(f"Loaded {len(self._ohlcv_data)} candles from the database")
                else:
                    self.logger.warning("No market data found in database")

        except Exception as e:
            self.logger.error(f"Error fetching market data: {e}")

    def _initialize_strategies(self):
        """Initialize strategy instances for strategies that work without configuration."""
        available_strategies = self.strategy_registry.list_available()

        for strategy_name in available_strategies:
            # Skip strategies that require user configuration
            if strategy_name.lower() in STRATEGIES_REQUIRING_CONFIG:
                self.logger.debug(f"Skipping {strategy_name} (requires configuration)")
                continue

            try:
                # Create instance with default parameters
                strategy_class = self.strategy_registry._strategies.get(strategy_name.lower())
                if strategy_class:
                    instance = strategy_class(
                        name=strategy_name,
                        parameters={},
                        timeframes=[self.timeframe]
                    )
                    instance.enabled = True
                    self._strategy_instances[strategy_name] = instance
                    self.logger.debug(f"Initialized strategy: {strategy_name}")
            except Exception as e:
                self.logger.warning(f"Failed to initialize strategy {strategy_name}: {e}")

    async def _load_user_configured_strategies(self):
        """Load user-configured strategies from the database.

        This enables strategies that require configuration (like pairs_trading)
        when the user has set them up in the Strategies page.
        Also overrides default parameters if the user has configured them.
        """
        from sqlalchemy import select
        from src.core.database import Strategy

        try:
            async with self.db.get_session() as session:
                # Get all enabled strategies from the database
                stmt = select(Strategy).where(Strategy.enabled == True)
                result = await session.execute(stmt)
                user_strategies = result.scalars().all()

                for db_strategy in user_strategies:
                    strategy_type = db_strategy.strategy_type
                    if not strategy_type:
                        continue

                    strategy_type_lower = strategy_type.lower()

                    # Check if this strategy type is registered
                    strategy_class = self.strategy_registry._strategies.get(strategy_type_lower)
                    if not strategy_class:
                        self.logger.debug(f"Strategy type {strategy_type} not registered")
                        continue

                    try:
                        # Create instance with the user's parameters
                        instance = strategy_class(
                            name=db_strategy.name,
                            parameters=db_strategy.parameters or {},
                            timeframes=db_strategy.timeframes or [self.timeframe]
                        )
                        instance.enabled = True

                        # Use strategy_type as key (overwrites default if it exists)
                        self._strategy_instances[strategy_type_lower] = instance

                        self.logger.info(
                            f"Loaded user-configured strategy: {db_strategy.name} "
                            f"(type: {strategy_type}, params: {db_strategy.parameters})"
                        )
                    except Exception as e:
                        self.logger.warning(
                            f"Failed to load user strategy {db_strategy.name}: {e}"
                        )

        except Exception as e:
            self.logger.error(f"Error loading user-configured strategies: {e}")

    def update_market_data(self, ohlcv_data: pd.DataFrame):
        """Update OHLCV market data.

        Args:
            ohlcv_data: DataFrame with columns: open, high, low, close, volume
        """
        self._ohlcv_data = ohlcv_data.copy()
        self.logger.debug(f"Updated market data: {len(ohlcv_data)} candles")

    async def start(self):
        """Start the intelligent autopilot loop."""
        redis_client = self.redis.get_client()

        # Check if already running (distributed-lock style)
        is_running = await redis_client.get(self.key_running)
        if is_running:
            self.logger.warning(f"Autopilot for {self.symbol} is already running elsewhere")
            return

        # Set running state in Redis
        await redis_client.set(self.key_running, "1")
        self._local_running = True

        self.logger.info(
            f"Intelligent autopilot started for {self.symbol} "
            f"(timeframe: {self.timeframe}, interval: {self.interval}s)"
        )

        # Initial data fetch
        await self._fetch_market_data()

        # Try to load or train model
        try:
            # Check if this task should be offloaded to Celery in future
            await self.strategy_selector.train_model(force_retrain=False)
        except Exception as e:
            self.logger.warning(f"Model training failed, will use fallback: {e}")

        # Load user-configured strategies from the database
        await self._load_user_configured_strategies()

        self.logger.info(
            f"Autopilot has {len(self._strategy_instances)} strategies available: "
            f"{list(self._strategy_instances.keys())}"
        )

        try:
            while self._local_running:
                # Distributed check: verify the key still exists
                if not await redis_client.exists(self.key_running):
                    self.logger.info("Remote stop signal received")
                    break

                try:
                    await self._analysis_cycle()
                except Exception as e:
                    self.logger.error(f"Analysis cycle error: {e}")

                await asyncio.sleep(self.interval)
        finally:
            self._local_running = False
            # Ensure the key is deleted (cleanup)
            await redis_client.delete(self.key_running)
            self.logger.info("Intelligent autopilot stopped")

    async def _analysis_cycle(self):
        """Perform one analysis and execution cycle."""
        # Check daily trade limit from Redis
        trades_today = await self._get_trades_today()

        if trades_today >= self.max_trades_per_day:
            self.logger.debug(f"Daily trade limit reached: {trades_today}/{self.max_trades_per_day}")
            return

        # Get market data (refresh if needed)
        if self._ohlcv_data is None or len(self._ohlcv_data) < 50:
            await self._fetch_market_data()
            if self._ohlcv_data is None or len(self._ohlcv_data) < 50:
                self.logger.warning("Insufficient market data for analysis")
                return

        # Analyze market conditions
        market_conditions = self.market_analyzer.analyze_current_conditions(
            symbol=self.symbol,
            timeframe=self.timeframe,
            ohlcv_data=self._ohlcv_data
        )

        # Select best strategy using ML
        selection_result = self.strategy_selector.select_best_strategy(
            market_conditions,
            min_confidence=self.min_confidence_threshold
        )

        if selection_result is None:
            self.logger.debug("No strategy selected (confidence too low)")
            self._last_analysis = {
                'market_conditions': market_conditions.to_dict(),
                'selected_strategy': None,
                'confidence': 0.0
            }
            return

        selected_strategy_name, confidence, all_predictions = selection_result
        self._selected_strategy = selected_strategy_name

        # Get strategy instance
        strategy_instance = self._strategy_instances.get(selected_strategy_name)
        if not strategy_instance:
            self.logger.error(f"Strategy instance not found: {selected_strategy_name}")
            return

        # Generate signal from the selected strategy
        current_price = Decimal(str(self._ohlcv_data['close'].iloc[-1]))
        signal = await self._generate_strategy_signal(
            strategy_instance,
            current_price
        )

        if signal is None or signal.signal_type == SignalType.HOLD:
            self.logger.debug(f"No actionable signal from {selected_strategy_name}")
            self._last_analysis = {
                'market_conditions': market_conditions.to_dict(),
                'selected_strategy': selected_strategy_name,
                'confidence': confidence,
                'signal': None
            }
            return

        # Evaluate opportunity
        if not self._evaluate_opportunity(signal, confidence, market_conditions):
            self.logger.debug("Opportunity does not meet execution criteria")
            return

        # Execute trade
        if self.enable_auto_execution:
            await self._execute_opportunity(strategy_instance, signal, market_conditions)
        else:
            self.logger.info(f"Auto-execution disabled. Signal: {signal.signal_type.value}")

        # Store analysis result
        self._last_analysis = {
            'market_conditions': market_conditions.to_dict(),
            'selected_strategy': selected_strategy_name,
            'confidence': confidence,
            'signal': {
                'type': signal.signal_type.value,
                'strength': signal.strength,
                'price': float(signal.price) if signal.price else None
            }
        }

    async def _generate_strategy_signal(
        self,
        strategy: BaseStrategy,
        current_price: Decimal
    ) -> Optional[StrategySignal]:
        """Generate signal from strategy."""
        try:
            # Prepare market data for the strategy
            latest_candle = self._ohlcv_data.iloc[-1]
            data = {
                'open': latest_candle['open'],
                'high': latest_candle['high'],
                'low': latest_candle['low'],
                'volume': latest_candle['volume']
            }

            # Generate signal
            signal = await strategy.on_tick(
                symbol=self.symbol,
                price=current_price,
                timeframe=self.timeframe,
                data=data
            )

            if signal:
                # Process signal
                signal = strategy.on_signal(signal)

            return signal

        except Exception as e:
            self.logger.error(f"Error generating strategy signal: {e}")
            return None

    def _evaluate_opportunity(
        self,
        signal: StrategySignal,
        confidence: float,
        market_conditions: MarketConditions
    ) -> bool:
        """Evaluate if an opportunity meets execution criteria."""
        # Check confidence threshold
        if confidence < self.min_confidence_threshold:
            return False

        # Check signal strength
        if signal.strength < 0.5:
            return False

        return True

    async def _can_execute_order(
        self,
        side: OrderSide,
        quantity: Decimal,
        price: Decimal
    ) -> Tuple[bool, str]:
        """Pre-flight check if an order can be executed."""
        # 1. Check minimum order value ($1 USD)
        order_value = quantity * price
        if order_value < Decimal('1.0'):
            return False, f"Order value ${order_value:.2f} below minimum $1.00"

        # 2. For BUY: check sufficient balance (include fee buffer)
        if side == OrderSide.BUY:
            if self.paper_trading:
                balance = self.trading_engine.paper_trading.get_balance()
            else:
                balance = Decimal('0')  # Live trading balance check handled elsewhere

            # Add 1% fee buffer
            fee_estimate = order_value * Decimal('0.01')
            total_required = order_value + fee_estimate

            if balance < total_required:
                return False, f"Insufficient funds: ${balance:.2f} < ${total_required:.2f}"

        # 3. For SELL: check a position exists with sufficient quantity
        if side == OrderSide.SELL:
            if self.paper_trading:
                positions = self.trading_engine.paper_trading.get_positions()
                position = next((p for p in positions if p.symbol == self.symbol), None)

                if not position:
                    return False, f"No position to sell for {self.symbol}"

                if position.quantity < quantity:
                    return False, f"Insufficient position: {position.quantity} < {quantity}"

        return True, "OK"

    def _determine_order_type_and_price(
        self,
        side: OrderSide,
        signal_strength: float,
        current_price: Decimal,
        is_stop_loss: bool = False
    ) -> Tuple[OrderType, Optional[Decimal]]:
        """Determine optimal order type and price for execution."""
        # Strong signals (>80% strength) or stop-loss use MARKET for speed
        if signal_strength > 0.8 or is_stop_loss:
            reason = "stop-loss" if is_stop_loss else f"high strength ({signal_strength:.2f})"
            self.logger.debug(f"Using MARKET order ({reason})")
            return OrderType.MARKET, None

        # Normal signals use LIMIT for a better price
        # BUY: bid slightly below market (0.1% discount)
        # SELL (take-profit): ask slightly above market (0.1% premium)
        if side == OrderSide.BUY:
            limit_price = current_price * Decimal('0.999')  # 0.1% below
        else:
            limit_price = current_price * Decimal('1.001')  # 0.1% above

        # Round to reasonable precision (2 decimal places for USD pairs)
        limit_price = limit_price.quantize(Decimal('0.01'))

        return OrderType.LIMIT, limit_price

    async def _execute_opportunity(
        self,
        strategy: BaseStrategy,
        signal: StrategySignal,
        market_conditions: MarketConditions
    ):
        """Execute a trade opportunity."""
        try:
            current_price = signal.price or Decimal(str(self._ohlcv_data['close'].iloc[-1]))

            # Calculate position size
            balance = self.trading_engine.paper_trading.get_balance() if self.paper_trading else Decimal(0)
            quantity = strategy.calculate_position_size(signal, balance, current_price)

            if quantity <= 0:
                self.logger.warning("Invalid position size calculated")
                return

            # Determine order side
            if signal.signal_type == SignalType.BUY:
                side = OrderSide.BUY
            elif signal.signal_type == SignalType.SELL:
                side = OrderSide.SELL
            else:
                self.logger.warning(f"Unsupported signal type: {signal.signal_type}")
                return

            # Pre-flight validation - skip the order if it would fail
            can_execute, reason = await self._can_execute_order(side, quantity, current_price)
            if not can_execute:
                self.logger.info(f"Skipping order ({strategy.name}): {reason}")
                return

            # Determine optimal order type
            is_stop_loss = False
            if side == OrderSide.SELL and self.paper_trading:
                positions = self.trading_engine.paper_trading.get_positions()
                position = next((p for p in positions if p.symbol == self.symbol), None)
                if position:
                    is_stop_loss = current_price < position.entry_price

            order_type, limit_price = self._determine_order_type_and_price(
                side=side,
                signal_strength=signal.strength,
                current_price=current_price,
                is_stop_loss=is_stop_loss
            )

            # Execute order with smart order type
            order = await self.trading_engine.execute_order(
                exchange_id=self.exchange_id,
                strategy_id=None,  # Intelligent autopilot doesn't use strategy_id
                symbol=self.symbol,
                side=side,
                order_type=order_type,
                quantity=quantity,
                price=limit_price,  # None for MARKET, price for LIMIT
                paper_trading=self.paper_trading
            )

            order_type_str = "LIMIT" if order_type == OrderType.LIMIT else "MARKET"
            if order:
                self.logger.info(
                    f"Executed {side.value} {order_type_str} order: {quantity} {self.symbol} "
                    f"at {limit_price or current_price} (strategy: {strategy.name})"
                )

                # Increment Redis trade counter
                await self._increment_trades_today()

                # Record trade for ML training (async, don't wait)
                asyncio.create_task(self._record_trade_for_learning(
                    strategy.name,
                    market_conditions,
                    order
                ))
            else:
                self.logger.warning("Order execution failed")

        except Exception as e:
            self.logger.error(f"Error executing opportunity: {e}")

    async def _record_trade_for_learning(
        self,
        strategy_name: str,
        market_conditions: MarketConditions,
        order: Any
    ):
        """Record trade for ML learning (called after the trade completes)."""
        try:
            # Wait a bit for the order to complete
            await asyncio.sleep(5)

            trade_result = {
                'return_pct': 0.0,
                'sharpe_ratio': 0.0,
                'win_rate': 0.0,
                'max_drawdown': 0.0,
                'trade_count': 1
            }

            await self.performance_tracker.record_trade(
                strategy_name=strategy_name,
                market_conditions=market_conditions,
                trade_result=trade_result
            )

        except Exception as e:
            self.logger.error(f"Error recording trade for learning: {e}")

    async def _get_trades_today(self) -> int:
        """Get number of trades executed today from Redis."""
        try:
            redis_client = self.redis.get_client()
            count = await redis_client.get(self._trades_today_key)
            return int(count) if count else 0
        except Exception as e:
            self.logger.error(f"Error getting trade count from Redis: {e}")
            return 0

    async def _increment_trades_today(self):
        """Increment the daily trade counter in Redis."""
        try:
            redis_client = self.redis.get_client()
            key = self._trades_today_key

            # Increment
            await redis_client.incr(key)

            # Set expiry to 24 hours (if new key)
            if await redis_client.ttl(key) == -1:
                await redis_client.expire(key, 86400)

        except Exception as e:
            self.logger.error(f"Error incrementing trade count in Redis: {e}")

    def stop(self):
        """Stop the intelligent autopilot (signals distributed stop)."""
        # Synchronous method to trigger stop.
        # Since we need to delete the Redis key asynchronously, we create a task.
        asyncio.create_task(self._stop_async())

    async def _stop_async(self):
        """Async stop implementation."""
        self._local_running = False
        redis_client = self.redis.get_client()
        await redis_client.delete(self.key_running)
        self.logger.info("Intelligent autopilot stopping... (signal sent)")

    @property
    def is_running(self) -> bool:
        """Check if the autopilot is running (local check)."""
        return self._local_running

    def get_status(self) -> Dict[str, Any]:
        """Get current autopilot status (synchronous wrapper)."""
        return {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'running': self._local_running,
            'selected_strategy': self._selected_strategy,
            'trades_today': 0,  # Fetch async if needed via get_distributed_status
            'max_trades_per_day': self.max_trades_per_day,
            'min_confidence_threshold': self.min_confidence_threshold,
            'enable_auto_execution': self.enable_auto_execution,
            'last_analysis': self._last_analysis,
            'model_info': self.strategy_selector.get_model_info()
        }

    async def get_distributed_status(self) -> Dict[str, Any]:
        """Get full distributed status (async)."""
        try:
            redis_client = self.redis.get_client()
            is_running_distributed = await redis_client.exists(self.key_running)
            trades_today = await self._get_trades_today()

            status = self.get_status()
            status['running'] = bool(is_running_distributed)
            status['trades_today'] = trades_today
            return status
        except Exception:
            # Fall back to local status if Redis fails
            return self.get_status()


# Global instances (factory cache)
_intelligent_autopilots: Dict[str, IntelligentAutopilot] = {}


def get_intelligent_autopilot(
    symbol: str,
    exchange_id: int = 1,
    timeframe: str = "1h",
    **kwargs
) -> IntelligentAutopilot:
    """Get or create an intelligent autopilot instance."""
    key = f"{symbol}:{exchange_id}:{timeframe}"
    if key not in _intelligent_autopilots:
        _intelligent_autopilots[key] = IntelligentAutopilot(
            symbol=symbol,
            exchange_id=exchange_id,
            timeframe=timeframe,
            **kwargs
        )
        logger.info(f"Created new IntelligentAutopilot instance for {key}")

    return _intelligent_autopilots[key]


def stop_all_autopilots():
    """Stop all running IntelligentAutopilot instances."""
    for key, autopilot in _intelligent_autopilots.items():
        if autopilot.is_running:
            autopilot.stop()
            logger.info(f"Stopped IntelligentAutopilot for {key}")
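
A lifecycle sketch for the factory above; it assumes a running event loop, a reachable Redis instance (the running flag and the daily trade counter live there), and market data already collected for the symbol/timeframe. The symbol and the five-minute wait are illustrative:

    import asyncio
    from src.autopilot import get_intelligent_autopilot, stop_all_autopilots

    async def main():
        pilot = get_intelligent_autopilot("BTC/USD", exchange_id=1, timeframe="1h")
        # start() holds the loop: run it as a background task, then signal stop,
        # which deletes the Redis running key and lets the loop exit.
        task = asyncio.create_task(pilot.start())
        await asyncio.sleep(300)  # let a few analysis cycles run
        stop_all_autopilots()
        await task
        print(pilot.get_status()["selected_strategy"])

    asyncio.run(main())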
src/autopilot/market_analyzer.py (new file, 485 lines)
@@ -0,0 +1,485 @@
|
||||
"""Market condition analyzer for intelligent autopilot.
|
||||
|
||||
Analyzes real-time market conditions and extracts features for ML model.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from src.core.logger import get_logger
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MarketRegime(str, Enum):
|
||||
"""Market regime classification."""
|
||||
TRENDING_UP = "trending_up"
|
||||
TRENDING_DOWN = "trending_down"
|
||||
RANGING = "ranging"
|
||||
HIGH_VOLATILITY = "high_volatility"
|
||||
LOW_VOLATILITY = "low_volatility"
|
||||
BREAKOUT = "breakout"
|
||||
REVERSAL = "reversal"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketConditions:
|
||||
"""Market conditions snapshot."""
|
||||
symbol: str
|
||||
timeframe: str
|
||||
regime: MarketRegime
|
||||
features: Dict[str, float]
|
||||
timestamp: pd.Timestamp
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'regime': self.regime.value,
|
||||
'features': self.features,
|
||||
'timestamp': self.timestamp.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class MarketAnalyzer:
|
||||
"""Analyzes market conditions and extracts features for ML."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize market analyzer."""
|
||||
self.logger = get_logger(__name__)
|
||||
self.indicators = get_indicators()
|
||||
|
||||
def analyze_current_conditions(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
ohlcv_data: pd.DataFrame
|
||||
) -> MarketConditions:
|
||||
"""Analyze current market conditions.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe
|
||||
ohlcv_data: OHLCV DataFrame with columns: open, high, low, close, volume
|
||||
|
||||
Returns:
|
||||
MarketConditions object
|
||||
"""
|
||||
if len(ohlcv_data) < 50:
|
||||
self.logger.warning(f"Insufficient data for analysis: {len(ohlcv_data)} candles")
|
||||
return MarketConditions(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
regime=MarketRegime.UNKNOWN,
|
||||
features={},
|
||||
timestamp=pd.Timestamp.now()
|
||||
)
|
||||
|
||||
# Extract features
|
||||
features = self.extract_features(ohlcv_data)
|
||||
|
||||
# Classify market regime
|
||||
regime = self.classify_market_regime(features, ohlcv_data)
|
||||
|
||||
return MarketConditions(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
regime=regime,
|
||||
features=features,
|
||||
timestamp=pd.Timestamp.now()
|
||||
)
|
||||
|
||||
def extract_features(self, df: pd.DataFrame) -> Dict[str, float]:
|
||||
"""Extract comprehensive market features.
|
||||
|
||||
Args:
|
||||
df: OHLCV DataFrame
|
||||
|
||||
Returns:
|
||||
Dictionary of feature names to values
|
||||
"""
|
||||
features = {}
|
||||
|
||||
try:
|
||||
close = df['close']
|
||||
high = df['high']
|
||||
low = df['low']
|
||||
open_price = df['open']
|
||||
volume = df['volume']
|
||||
|
||||
# Trend Features
|
||||
features['sma_20'] = float(close.rolling(20).mean().iloc[-1])
|
||||
features['sma_50'] = float(close.rolling(50).mean().iloc[-1]) if len(close) >= 50 else features['sma_20']
|
||||
features['ema_12'] = float(self.indicators.ema(close, period=12).iloc[-1])
|
||||
features['ema_26'] = float(self.indicators.ema(close, period=26).iloc[-1]) if len(close) >= 26 else features['ema_12']
|
||||
|
||||
# Price position relative to MAs
|
||||
current_price = float(close.iloc[-1])
|
||||
features['price_vs_sma20'] = (current_price - features['sma_20']) / features['sma_20'] if features['sma_20'] > 0 else 0.0
|
||||
features['price_vs_sma50'] = (current_price - features['sma_50']) / features['sma_50'] if features['sma_50'] > 0 else 0.0
|
||||
features['sma20_vs_sma50'] = (features['sma_20'] - features['sma_50']) / features['sma_50'] if features['sma_50'] > 0 else 0.0
|
||||
|
||||
# Trend strength (ADX)
|
||||
if len(df) >= 14:
|
||||
adx = self.indicators.adx(high, low, close, period=14)
|
||||
features['adx'] = float(adx.iloc[-1]) if not pd.isna(adx.iloc[-1]) else 0.0
|
||||
else:
|
||||
features['adx'] = 0.0
|
||||
|
||||
# Momentum Features
|
||||
if len(close) >= 14:
|
||||
rsi = self.indicators.rsi(close, period=14)
|
||||
features['rsi'] = float(rsi.iloc[-1]) if not pd.isna(rsi.iloc[-1]) else 50.0
|
||||
else:
|
||||
features['rsi'] = 50.0
|
||||
|
||||
# MACD
|
||||
if len(close) >= 26:
|
||||
macd_result = self.indicators.macd(close, fast=12, slow=26, signal=9)
|
||||
features['macd'] = float(macd_result['macd'].iloc[-1]) if not pd.isna(macd_result['macd'].iloc[-1]) else 0.0
|
||||
features['macd_signal'] = float(macd_result['signal'].iloc[-1]) if not pd.isna(macd_result['signal'].iloc[-1]) else 0.0
|
||||
features['macd_histogram'] = float(macd_result['histogram'].iloc[-1]) if not pd.isna(macd_result['histogram'].iloc[-1]) else 0.0
|
||||
else:
|
||||
features['macd'] = 0.0
|
||||
features['macd_signal'] = 0.0
|
||||
features['macd_histogram'] = 0.0
|
||||
|
||||
# Volatility Features
|
||||
if len(df) >= 14:
|
||||
atr = self.indicators.atr(high, low, close, period=14)
|
||||
features['atr'] = float(atr.iloc[-1]) if not pd.isna(atr.iloc[-1]) else 0.0
|
||||
features['atr_percent'] = (features['atr'] / current_price) * 100 if current_price > 0 else 0.0
|
||||
else:
|
||||
features['atr'] = 0.0
|
||||
features['atr_percent'] = 0.0
|
||||
|
||||
# Bollinger Bands
|
||||
if len(close) >= 20:
|
||||
bb = self.indicators.bollinger_bands(close, period=20, std_dev=2)
|
||||
features['bb_upper'] = float(bb['upper'].iloc[-1]) if not pd.isna(bb['upper'].iloc[-1]) else current_price
|
||||
features['bb_lower'] = float(bb['lower'].iloc[-1]) if not pd.isna(bb['lower'].iloc[-1]) else current_price
|
||||
features['bb_middle'] = float(bb['middle'].iloc[-1]) if not pd.isna(bb['middle'].iloc[-1]) else current_price
|
||||
features['bb_width'] = (features['bb_upper'] - features['bb_lower']) / features['bb_middle'] if features['bb_middle'] > 0 else 0.0
|
||||
features['bb_position'] = (current_price - features['bb_lower']) / (features['bb_upper'] - features['bb_lower']) if (features['bb_upper'] - features['bb_lower']) > 0 else 0.5
|
||||
else:
|
||||
features['bb_upper'] = current_price
|
||||
features['bb_lower'] = current_price
|
||||
features['bb_middle'] = current_price
|
||||
features['bb_width'] = 0.0
|
||||
features['bb_position'] = 0.5
|
||||
|
||||
# Volume Features
|
||||
if len(volume) >= 20:
|
||||
volume_ma = volume.rolling(20).mean()
|
||||
features['volume_ratio'] = float(volume.iloc[-1] / volume_ma.iloc[-1]) if volume_ma.iloc[-1] > 0 else 1.0
|
||||
features['volume_trend'] = float((volume.iloc[-5:].mean() - volume.iloc[-20:-5].mean()) / volume.iloc[-20:-5].mean()) if volume.iloc[-20:-5].mean() > 0 else 0.0
|
||||
else:
|
||||
features['volume_ratio'] = 1.0
|
||||
features['volume_trend'] = 0.0
|
||||
|
||||
# OBV (On-Balance Volume)
|
||||
if len(close) >= 10:
|
||||
obv = self.indicators.obv(close, volume)
|
||||
features['obv_trend'] = float((obv.iloc[-1] - obv.iloc[-10]) / abs(obv.iloc[-10])) if obv.iloc[-10] != 0 else 0.0
|
||||
else:
|
||||
features['obv_trend'] = 0.0
|
||||
|
||||
# Price Action Features
|
||||
features['price_change_1'] = float((close.iloc[-1] - close.iloc[-2]) / close.iloc[-2]) if len(close) >= 2 and close.iloc[-2] > 0 else 0.0
|
||||
features['price_change_5'] = float((close.iloc[-1] - close.iloc[-6]) / close.iloc[-6]) if len(close) >= 6 and close.iloc[-6] > 0 else 0.0
|
||||
features['price_change_20'] = float((close.iloc[-1] - close.iloc[-21]) / close.iloc[-21]) if len(close) >= 21 and close.iloc[-21] > 0 else 0.0
|
||||
|
||||
# High-Low Range
|
||||
features['high_low_range'] = float((high.iloc[-1] - low.iloc[-1]) / current_price) if current_price > 0 else 0.0
|
||||
features['high_low_range_5'] = float((high.iloc[-5:].max() - low.iloc[-5:].min()) / current_price) if current_price > 0 else 0.0
|
||||
|
||||
# Candlestick Patterns (simplified)
|
||||
if len(df) >= 3:
|
||||
features['is_bullish_candle'] = 1.0 if close.iloc[-1] > open_price.iloc[-1] else 0.0
|
||||
features['candle_body_size'] = float(abs(close.iloc[-1] - open_price.iloc[-1]) / current_price) if current_price > 0 else 0.0
|
||||
features['candle_wick_upper'] = float((high.iloc[-1] - max(close.iloc[-1], open_price.iloc[-1])) / current_price) if current_price > 0 else 0.0
|
||||
features['candle_wick_lower'] = float((min(close.iloc[-1], open_price.iloc[-1]) - low.iloc[-1]) / current_price) if current_price > 0 else 0.0
|
||||
else:
|
||||
features['is_bullish_candle'] = 0.5
|
||||
features['candle_body_size'] = 0.0
|
||||
features['candle_wick_upper'] = 0.0
|
||||
features['candle_wick_lower'] = 0.0
|
||||
|
||||
# Support/Resistance Proximity (simplified - using recent highs/lows)
|
||||
if len(df) >= 20:
|
||||
recent_high = float(high.iloc[-20:].max())
|
||||
recent_low = float(low.iloc[-20:].min())
|
||||
features['distance_to_resistance'] = (recent_high - current_price) / current_price if current_price > 0 else 0.0
|
||||
features['distance_to_support'] = (current_price - recent_low) / current_price if current_price > 0 else 0.0
|
||||
else:
|
||||
features['distance_to_resistance'] = 0.0
|
||||
features['distance_to_support'] = 0.0
|
||||
|
||||
# ==========================================
|
||||
# NEW FEATURES FOR IMPROVED ML ACCURACY
|
||||
# ==========================================
|
||||
|
||||
# Volatility Percentile - current ATR vs 30-day rolling ATR
|
||||
if len(df) >= 30:
|
||||
atr_series = self.indicators.atr(high, low, close, period=14)
|
||||
atr_30_mean = float(atr_series.iloc[-30:].mean()) if len(atr_series) >= 30 else features.get('atr', 0)
|
||||
features['volatility_percentile'] = features.get('atr', 0) / atr_30_mean if atr_30_mean > 0 else 1.0
|
||||
else:
|
||||
features['volatility_percentile'] = 1.0
|
||||
|
||||
# Momentum Delta - 3-period vs 10-period price change (acceleration)
|
||||
if len(close) >= 11:
|
||||
momentum_3 = float((close.iloc[-1] - close.iloc[-4]) / close.iloc[-4]) if close.iloc[-4] > 0 else 0.0
|
||||
momentum_10 = float((close.iloc[-1] - close.iloc[-11]) / close.iloc[-11]) if close.iloc[-11] > 0 else 0.0
|
||||
features['momentum_3'] = momentum_3
|
||||
features['momentum_10'] = momentum_10
|
||||
features['momentum_delta'] = momentum_3 - momentum_10 # Positive = accelerating
|
||||
else:
|
||||
features['momentum_3'] = 0.0
|
||||
features['momentum_10'] = 0.0
|
||||
features['momentum_delta'] = 0.0
|
||||
|
||||
# Trend Strength Change - ADX rate of change
|
||||
if len(df) >= 20:
|
||||
adx_series = self.indicators.adx(high, low, close, period=14)
|
||||
if len(adx_series) >= 5:
|
||||
adx_current = float(adx_series.iloc[-1]) if not pd.isna(adx_series.iloc[-1]) else 0.0
|
||||
adx_prev = float(adx_series.iloc[-5]) if not pd.isna(adx_series.iloc[-5]) else adx_current
|
||||
features['adx_change'] = adx_current - adx_prev # Positive = strengthening trend
|
||||
else:
|
||||
features['adx_change'] = 0.0
|
||||
else:
|
||||
features['adx_change'] = 0.0
|
||||
|
||||
# Volume Climax Detection - extreme volume with reversal candle
|
||||
if len(df) >= 20:
|
||||
volume_std = float(volume.iloc[-20:].std())
|
||||
volume_mean = float(volume.iloc[-20:].mean())
|
||||
current_volume = float(volume.iloc[-1])
|
||||
z_score = (current_volume - volume_mean) / volume_std if volume_std > 0 else 0.0
|
||||
is_reversal = (close.iloc[-1] < open_price.iloc[-1]) != (close.iloc[-2] < open_price.iloc[-2])
|
||||
features['volume_z_score'] = z_score
|
||||
features['volume_climax'] = 1.0 if (z_score > 2.0 and is_reversal) else 0.0
|
||||
else:
|
||||
features['volume_z_score'] = 0.0
|
||||
features['volume_climax'] = 0.0
|
||||
|
||||
# RSI Extremes - overbought/oversold signals
|
||||
rsi_value = features.get('rsi', 50.0)
|
||||
features['rsi_oversold'] = 1.0 if rsi_value < 30 else 0.0
|
||||
features['rsi_overbought'] = 1.0 if rsi_value > 70 else 0.0
|
||||
features['rsi_extreme'] = 1.0 if (rsi_value < 25 or rsi_value > 75) else 0.0
|
||||
|
||||
# Price Position in Range - where is price in N-day range
|
||||
if len(df) >= 20:
|
||||
range_high = float(high.iloc[-20:].max())
|
||||
range_low = float(low.iloc[-20:].min())
|
||||
range_size = range_high - range_low
|
||||
features['price_in_range'] = (current_price - range_low) / range_size if range_size > 0 else 0.5
|
||||
else:
|
||||
features['price_in_range'] = 0.5
|
||||
|
||||
# Trend Alignment - short term vs medium term
|
||||
sma_10 = float(close.rolling(10).mean().iloc[-1]) if len(close) >= 10 else current_price
|
||||
features['sma_10'] = sma_10
|
||||
features['trend_aligned'] = 1.0 if (sma_10 > features.get('sma_20', sma_10) > features.get('sma_50', sma_10)) or \
|
||||
(sma_10 < features.get('sma_20', sma_10) < features.get('sma_50', sma_10)) else 0.0
|
||||
|
||||
# Consecutive Candles - streak of bullish/bearish candles
|
||||
if len(df) >= 5:
|
||||
bullish_streak = 0
|
||||
bearish_streak = 0
|
||||
for i in range(1, 6):
|
||||
if close.iloc[-i] > open_price.iloc[-i]:
|
||||
if bearish_streak == 0:
|
||||
bullish_streak += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
if bullish_streak == 0:
|
||||
bearish_streak += 1
|
||||
else:
|
||||
break
|
||||
features['bullish_streak'] = float(bullish_streak)
|
||||
features['bearish_streak'] = float(bearish_streak)
|
||||
else:
|
||||
features['bullish_streak'] = 0.0
|
||||
features['bearish_streak'] = 0.0
|
||||
|
||||
# MACD Crossover proximity
|
||||
macd_val = features.get('macd', 0.0)
|
||||
macd_sig = features.get('macd_signal', 0.0)
|
||||
if macd_sig != 0:
|
||||
features['macd_signal_ratio'] = macd_val / abs(macd_sig) if macd_sig != 0 else 0.0
|
||||
else:
|
||||
features['macd_signal_ratio'] = 0.0
|
||||
|
||||
# Bollinger Band squeeze detection
|
||||
bb_width = features.get('bb_width', 0.0)
|
||||
if len(df) >= 30:
|
||||
bb_history = []
|
||||
for i in range(30):
|
||||
if len(close) > i + 20:
|
||||
bb_temp = self.indicators.bollinger_bands(close.iloc[:-i-1] if i > 0 else close, period=20, std_dev=2)
|
||||
if 'upper' in bb_temp and 'lower' in bb_temp:
|
||||
width = (float(bb_temp['upper'].iloc[-1]) - float(bb_temp['lower'].iloc[-1])) / float(bb_temp['middle'].iloc[-1]) if not pd.isna(bb_temp['middle'].iloc[-1]) and bb_temp['middle'].iloc[-1] > 0 else 0
|
||||
bb_history.append(width)
|
||||
if bb_history:
|
||||
avg_width = np.mean(bb_history)
|
||||
features['bb_squeeze'] = 1.0 if bb_width < avg_width * 0.7 else 0.0
|
||||
else:
|
||||
features['bb_squeeze'] = 0.0
|
||||
else:
|
||||
features['bb_squeeze'] = 0.0
|
||||
|
||||
# ==========================================
|
||||
# SENTIMENT FEATURES
|
||||
# ==========================================
|
||||
|
||||
# Fear & Greed Index (0-100, 0=extreme fear, 100=extreme greed)
|
||||
# Fetch from API or use cached value
|
||||
fear_greed = self._get_fear_greed_index()
|
||||
features['fear_greed_index'] = fear_greed
|
||||
# Normalize to -1 to 1 range (fear negative, greed positive)
|
||||
features['fear_greed_normalized'] = (fear_greed - 50) / 50
|
||||
# Binary indicators
|
||||
features['is_extreme_fear'] = 1.0 if fear_greed < 25 else 0.0
|
||||
features['is_extreme_greed'] = 1.0 if fear_greed > 75 else 0.0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error extracting features: {e}")
|
||||
# Return empty features on error
|
||||
return {}
|
||||
|
||||
return features
|
||||
|
||||
def _get_fear_greed_index(self) -> float:
|
||||
"""Fetch the current Fear & Greed Index.
|
||||
|
||||
Uses alternative.me API for Bitcoin Fear & Greed Index.
|
||||
Falls back to neutral (50) if unavailable.
|
||||
|
||||
Returns:
|
||||
Fear & Greed value 0-100
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
# Check cache
|
||||
cache_key = '_fear_greed_cache'
|
||||
cache_time_key = '_fear_greed_cache_time'
|
||||
|
||||
if hasattr(self, cache_key) and hasattr(self, cache_time_key):
|
||||
# Cache for 1 hour
|
||||
cache_age = (pd.Timestamp.now() - getattr(self, cache_time_key)).total_seconds()
|
||||
if cache_age < 3600:
|
||||
return getattr(self, cache_key)
|
||||
|
||||
# Fetch from API
|
||||
response = requests.get(
|
||||
"https://api.alternative.me/fng/?limit=1",
|
||||
timeout=5
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('data') and len(data['data']) > 0:
|
||||
value = float(data['data'][0].get('value', 50))
|
||||
# Cache the value
|
||||
setattr(self, cache_key, value)
|
||||
setattr(self, cache_time_key, pd.Timestamp.now())
|
||||
return value
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Could not fetch Fear & Greed Index: {e}")
|
||||
|
||||
# Return neutral if unavailable
|
||||
return 50.0
|
||||
|
||||
    def classify_market_regime(
        self,
        features: Dict[str, float],
        df: pd.DataFrame
    ) -> MarketRegime:
        """Classify market regime based on features.

        Args:
            features: Extracted features
            df: OHLCV DataFrame

        Returns:
            MarketRegime classification
        """
        if not features:
            return MarketRegime.UNKNOWN

        try:
            close = df['close']
            current_price = float(close.iloc[-1])

            # Check for trending conditions
            adx = features.get('adx', 0.0)
            price_vs_sma20 = features.get('price_vs_sma20', 0.0)
            sma20_vs_sma50 = features.get('sma20_vs_sma50', 0.0)

            # Strong uptrend
            if adx > 25 and price_vs_sma20 > 0.01 and sma20_vs_sma50 > 0.01:
                return MarketRegime.TRENDING_UP

            # Strong downtrend
            if adx > 25 and price_vs_sma20 < -0.01 and sma20_vs_sma50 < -0.01:
                return MarketRegime.TRENDING_DOWN

            # Check volatility
            atr_percent = features.get('atr_percent', 0.0)
            bb_width = features.get('bb_width', 0.0)

            if atr_percent > 3.0 or bb_width > 0.05:
                return MarketRegime.HIGH_VOLATILITY

            if atr_percent < 1.0 and bb_width < 0.02:
                return MarketRegime.LOW_VOLATILITY

            # Check for ranging (low ADX, price oscillating)
            if adx < 20:
                # Check if price is oscillating around MAs
                price_change_20 = features.get('price_change_20', 0.0)
                if abs(price_change_20) < 0.05:  # Less than 5% change over 20 periods
                    return MarketRegime.RANGING

            # Check for breakout
            bb_position = features.get('bb_position', 0.5)
            volume_ratio = features.get('volume_ratio', 1.0)
            if (bb_position > 0.95 or bb_position < 0.05) and volume_ratio > 1.5:
                return MarketRegime.BREAKOUT

            # Check for reversal signals
            rsi = features.get('rsi', 50.0)
            macd_histogram = features.get('macd_histogram', 0.0)
            if (rsi > 70 and macd_histogram < 0) or (rsi < 30 and macd_histogram > 0):
                return MarketRegime.REVERSAL

            # Default to ranging if unclear
            return MarketRegime.RANGING

        except Exception as e:
            self.logger.error(f"Error classifying market regime: {e}")
            return MarketRegime.UNKNOWN


# Global instance
_market_analyzer: Optional[MarketAnalyzer] = None


def get_market_analyzer() -> MarketAnalyzer:
    """Get global market analyzer instance."""
    global _market_analyzer
    if _market_analyzer is None:
        _market_analyzer = MarketAnalyzer()
    return _market_analyzer
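The thresholds above (ADX above 25 for trends, ATR percent and Bollinger width for volatility, band position plus volume for breakouts) can be exercised in isolation. A minimal sketch assuming only the definitions in this file; the feature values and the tiny OHLCV frame are invented:

import pandas as pd

analyzer = get_market_analyzer()
df = pd.DataFrame({'close': [100.0, 101.0, 102.0, 103.0]})  # synthetic frame

# ADX > 25 with price above SMA20 and SMA20 above SMA50 -> TRENDING_UP
features = {'adx': 30.0, 'price_vs_sma20': 0.02, 'sma20_vs_sma50': 0.015}
assert analyzer.classify_market_regime(features, df) == MarketRegime.TRENDING_UP

# Overbought RSI with a falling MACD histogram -> REVERSAL
features = {'adx': 22.0, 'rsi': 75.0, 'macd_histogram': -0.5,
            'atr_percent': 2.0, 'bb_width': 0.03}
assert analyzer.classify_market_regime(features, df) == MarketRegime.REVERSAL
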
519
src/autopilot/models.py
Normal file
@@ -0,0 +1,519 @@
"""ML model definitions and persistence for strategy selection.

Implements automated model selection with:
- XGBoost
- LightGBM
- RandomForest
- MLP Neural Network

Uses TimeSeriesSplit cross-validation and walk-forward validation.
"""

import os
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import TimeSeriesSplit, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import joblib
import warnings

from src.core.logger import get_logger
from src.core.config import get_config

logger = get_logger(__name__)

# Try importing optional dependencies
try:
    from xgboost import XGBClassifier
    HAS_XGBOOST = True
except ImportError:
    HAS_XGBOOST = False
    logger.warning("XGBoost not installed. Run: pip install xgboost")

try:
    from lightgbm import LGBMClassifier
    HAS_LIGHTGBM = True
except ImportError:
    HAS_LIGHTGBM = False
    logger.warning("LightGBM not installed. Run: pip install lightgbm")


class ModelSelector:
    """Automated model selection for strategy prediction.

    Trains multiple models, compares via cross-validation, picks the best.
    """

    def __init__(self):
        """Initialize model selector."""
        self.logger = get_logger(__name__)
        self.config = get_config()

        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        self.feature_names: List[str] = []
        self.strategy_names: List[str] = []
        self.is_trained = False
        self.training_metadata: Dict[str, Any] = {}

        # Best model after training
        self.best_model = None
        self.best_model_name: str = ""

        # Model storage path
        self.model_dir = Path.home() / ".local" / "share" / "crypto_trader" / "models"
        self.model_dir.mkdir(parents=True, exist_ok=True)

    def _get_candidate_models(self) -> Dict[str, Any]:
        """Get dictionary of candidate models to compare.

        Returns:
            Dictionary of model name -> model instance
        """
        models = {}

        # RandomForest - stable, interpretable baseline
        models['random_forest'] = RandomForestClassifier(
            n_estimators=200,
            max_depth=15,
            min_samples_split=10,
            min_samples_leaf=4,
            class_weight='balanced',
            random_state=42,
            n_jobs=-1
        )

        # XGBoost - typically best for tabular data
        if HAS_XGBOOST:
            models['xgboost'] = XGBClassifier(
                n_estimators=200,
                max_depth=8,
                learning_rate=0.1,
                subsample=0.8,
                colsample_bytree=0.8,
                random_state=42,
                n_jobs=-1,
                use_label_encoder=False,
                eval_metric='mlogloss'
            )

        # LightGBM - fast, handles noisy data well
        if HAS_LIGHTGBM:
            models['lightgbm'] = LGBMClassifier(
                n_estimators=200,
                max_depth=10,
                learning_rate=0.1,
                subsample=0.8,
                colsample_bytree=0.8,
                class_weight='balanced',
                random_state=42,
                n_jobs=-1,
                verbose=-1
            )

        # MLP Neural Network - high ceiling with enough data
        models['mlp'] = MLPClassifier(
            hidden_layer_sizes=(128, 64, 32),
            activation='relu',
            solver='adam',
            alpha=0.01,  # L2 regularization
            batch_size='auto',
            learning_rate='adaptive',
            learning_rate_init=0.001,
            max_iter=500,
            early_stopping=True,
            validation_fraction=0.1,
            n_iter_no_change=20,
            random_state=42
        )

        return models

    def train(
        self,
        X: pd.DataFrame,
        y: np.ndarray,
        strategy_names: List[str],
        use_ensemble: bool = False,
        n_splits: int = 5,
        training_symbols: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Train and select best model via cross-validation.

        Args:
            X: Feature matrix (market conditions)
            y: Target values (strategy names)
            strategy_names: List of strategy names
            use_ensemble: If True, use a voting ensemble of top models instead of the single best
            n_splits: Number of cross-validation splits
            training_symbols: List of symbols used for training

        Returns:
            Training metrics dictionary
        """
        self.strategy_names = strategy_names
        self.feature_names = list(X.columns)
        self.training_symbols = training_symbols or []

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(y)

        # Scale features
        X_scaled = self.scaler.fit_transform(X)

        # Time series cross-validation
        tscv = TimeSeriesSplit(n_splits=n_splits)

        # Get candidate models
        models = self._get_candidate_models()

        self.logger.info(f"Training {len(models)} candidate models with {len(X)} samples...")

        # Evaluate each model
        model_scores = {}
        model_cv_results = {}

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for name, model in models.items():
                try:
                    self.logger.info(f"Evaluating {name}...")
                    scores = cross_val_score(
                        model, X_scaled, y_encoded,
                        cv=tscv, scoring='accuracy', n_jobs=-1
                    )
                    # Filter out nan scores before calculating mean
                    valid_scores = [s for s in scores if not np.isnan(s)]
                    if valid_scores:
                        mean_score = np.mean(valid_scores)
                        std_score = np.std(valid_scores) if len(valid_scores) > 1 else 0.0
                        model_scores[name] = mean_score
                        model_cv_results[name] = {
                            'mean': float(mean_score),
                            'std': float(std_score),
                            'scores': [float(s) if not np.isnan(s) else 0.0 for s in scores]
                        }
                        self.logger.info(f"  {name}: {mean_score:.3f} (+/- {std_score:.3f})")
                    else:
                        self.logger.warning(f"  {name}: All scores were nan, skipping")
                        model_scores[name] = 0.0
                except Exception as e:
                    self.logger.warning(f"  {name} failed: {e}")
                    model_scores[name] = 0.0

        # Filter out models with zero scores before selecting best
        valid_model_scores = {k: v for k, v in model_scores.items() if v > 0}

        # Select best model or create ensemble
        if not valid_model_scores:
            self.logger.warning("No models trained successfully, using RandomForest as fallback")
            self.best_model_name = "random_forest"
            self.best_model = models["random_forest"]
        elif use_ensemble and len(valid_model_scores) >= 2:
            # Use top 3 models in voting ensemble, but only if they have at least 10% accuracy
            MIN_ENSEMBLE_ACCURACY = 0.10
            good_models = {k: v for k, v in valid_model_scores.items() if v >= MIN_ENSEMBLE_ACCURACY}

            if len(good_models) >= 2:
                top_models = sorted(good_models.items(), key=lambda x: x[1], reverse=True)[:3]
                estimators = [(name, models[name]) for name, _ in top_models]

                self.best_model = VotingClassifier(estimators=estimators, voting='soft')
                self.best_model_name = "ensemble"
                self.logger.info(f"Using ensemble of: {[n for n, _ in top_models]} (excluding models below {MIN_ENSEMBLE_ACCURACY:.0%} accuracy)")
            else:
                # Not enough good models for ensemble, use single best
                self.best_model_name = max(valid_model_scores, key=valid_model_scores.get)
                self.best_model = models[self.best_model_name]
                self.logger.info(f"Not enough models for ensemble (min {MIN_ENSEMBLE_ACCURACY:.0%}), using best: {self.best_model_name}")
        else:
            # Use single best model
            self.best_model_name = max(valid_model_scores, key=valid_model_scores.get)
            self.best_model = models[self.best_model_name]
            self.logger.info(f"Best model: {self.best_model_name} ({valid_model_scores[self.best_model_name]:.3f})")

        # Train best model on full dataset
        self.best_model.fit(X_scaled, y_encoded)

        # Training accuracy on the full dataset
        train_pred = self.best_model.predict(X_scaled)
        train_acc = accuracy_score(y_encoded, train_pred)

        # Top-3 accuracy if we have probabilities
        top3_acc = None
        try:
            if hasattr(self.best_model, 'predict_proba'):
                proba = self.best_model.predict_proba(X_scaled)
                k = min(3, len(self.strategy_names))
                top3_acc = float(top_k_accuracy_score(y_encoded, proba, k=k))
        except Exception:
            pass

        # Run walk-forward validation
        wf_scores = self._walk_forward_validation(X_scaled, y_encoded)

        # Calculate estimated CV accuracy for ensemble (average of component models)
        if self.best_model_name == "ensemble" and len(valid_model_scores) > 0:
            # Use average of models in the ensemble for CV estimate
            ensemble_cv_acc = np.mean([v for v in valid_model_scores.values() if v >= 0.10])
        else:
            ensemble_cv_acc = model_scores.get(self.best_model_name, train_acc)

        metrics = {
            'train_accuracy': float(train_acc),
            'test_accuracy': float(ensemble_cv_acc),
            'cv_mean_accuracy': float(ensemble_cv_acc),
            'walk_forward_accuracy': float(np.mean(wf_scores)) if wf_scores else None,
            'top3_accuracy': top3_acc,
            'n_samples': len(X),
            'n_features': len(self.feature_names),
            'n_strategies': len(strategy_names),
            'best_model': self.best_model_name,
            'all_model_scores': model_cv_results
        }

        self.is_trained = True
        self.training_metadata = {
            'trained_at': datetime.utcnow().isoformat(),
            'metrics': metrics,
            'model_type': 'classifier',
            'best_model_name': self.best_model_name,
            'training_symbols': self.training_symbols
        }

        self.logger.info(f"Training complete. Best: {self.best_model_name}, "
                         f"CV accuracy: {metrics['cv_mean_accuracy']:.1%}")

        return metrics

    def _walk_forward_validation(
        self,
        X: np.ndarray,
        y: np.ndarray,
        train_ratio: float = 0.7,
        step_size: Optional[int] = None
    ) -> List[float]:
        """Perform walk-forward validation.

        Train on months 1-6, test on 7. Train on 1-7, test on 8. Etc.

        Args:
            X: Scaled feature matrix
            y: Encoded labels
            train_ratio: Initial training set ratio
            step_size: Step size for rolling window

        Returns:
            List of accuracy scores for each step
        """
        n_samples = len(X)
        if n_samples < 100:
            self.logger.warning("Insufficient data for walk-forward validation")
            return []

        initial_train_size = int(n_samples * train_ratio)
        if step_size is None:
            step_size = max(10, n_samples // 20)

        scores = []
        models = self._get_candidate_models()
        model = models.get(self.best_model_name, list(models.values())[0])

        for i in range(initial_train_size, n_samples - step_size, step_size):
            X_train = X[:i]
            y_train = y[:i]
            X_test = X[i:i + step_size]
            y_test = y[i:i + step_size]

            try:
                model_clone = type(model)(**model.get_params())
                model_clone.fit(X_train, y_train)
                pred = model_clone.predict(X_test)
                scores.append(accuracy_score(y_test, pred))
            except Exception as e:
                self.logger.warning(f"Walk-forward step failed: {e}")

        if scores:
            self.logger.info(f"Walk-forward validation: {np.mean(scores):.3f} "
                             f"(+/- {np.std(scores):.3f}) over {len(scores)} steps")

        return scores

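To make the expanding-window scheme concrete, here is how the train/test indices advance for 200 samples with the default 70% initial window and the derived step size. A sketch only, independent of the class:

# Sketch: index windows produced by the expanding walk-forward loop above.
n_samples = 200
initial_train_size = int(n_samples * 0.7)  # 140
step_size = max(10, n_samples // 20)       # 10

for i in range(initial_train_size, n_samples - step_size, step_size):
    train_idx = range(0, i)             # grows each step: 0..139, 0..149, ...
    test_idx = range(i, i + step_size)  # the next unseen block of 10 samples
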
    def predict(
        self,
        features: Dict[str, float]
    ) -> Tuple[str, float, Dict[str, float]]:
        """Predict best strategy for given market conditions.

        Args:
            features: Market condition features

        Returns:
            Tuple of (best_strategy_name, confidence_score, all_predictions)
        """
        if not self.is_trained:
            raise ValueError("Model not trained. Call train() first.")

        # Convert features to DataFrame
        feature_df = pd.DataFrame([features])

        # Ensure all required features are present
        for feat in self.feature_names:
            if feat not in feature_df.columns:
                feature_df[feat] = 0.0

        # Select only required features in correct order
        X = feature_df[self.feature_names]

        # Scale
        X_scaled = self.scaler.transform(X)

        # Get probabilities for each strategy
        if hasattr(self.best_model, 'predict_proba'):
            probabilities = self.best_model.predict_proba(X_scaled)[0]
        else:
            # Fallback for models without predict_proba
            pred = self.best_model.predict(X_scaled)[0]
            probabilities = np.zeros(len(self.strategy_names))
            probabilities[pred] = 1.0

        # Map probabilities to strategy names
        encoded_classes = self.label_encoder.classes_
        strategy_probs = {}
        for idx, prob in enumerate(probabilities):
            if idx < len(encoded_classes):
                strategy_name = encoded_classes[idx]
                strategy_probs[strategy_name] = float(prob)

        # Get best strategy
        best_idx = np.argmax(probabilities)
        best_strategy = encoded_classes[best_idx]
        confidence = float(probabilities[best_idx])

        return best_strategy, confidence, strategy_probs

    def get_feature_importance(self) -> Dict[str, float]:
        """Get feature importance scores.

        Returns:
            Dictionary of feature names to importance scores (JSON serializable)
        """
        if not self.is_trained or self.best_model is None:
            return {}

        try:
            if hasattr(self.best_model, 'feature_importances_'):
                importances = self.best_model.feature_importances_
                # Convert numpy types to native Python floats for JSON serialization
                return {name: float(val) for name, val in zip(self.feature_names, importances)}
            elif hasattr(self.best_model, 'coefs_'):
                # For MLP, use absolute mean of first layer weights
                importances = np.abs(self.best_model.coefs_[0]).mean(axis=1)
                return {name: float(val) for name, val in zip(self.feature_names, importances)}
        except Exception as e:
            self.logger.warning(f"Could not get feature importance: {e}")

        return {}

    def save(self, filename: Optional[str] = None) -> str:
        """Save model to disk.

        Args:
            filename: Optional custom filename

        Returns:
            Path to saved model
        """
        if filename is None:
            timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            filename = f"strategy_selector_{self.best_model_name}_{timestamp}.joblib"

        filepath = self.model_dir / filename

        model_data = {
            'model': self.best_model,
            'scaler': self.scaler,
            'label_encoder': self.label_encoder,
            'feature_names': self.feature_names,
            'strategy_names': self.strategy_names,
            'is_trained': self.is_trained,
            'training_metadata': self.training_metadata,
            'model_type': 'classifier',
            'best_model_name': self.best_model_name
        }

        joblib.dump(model_data, filepath)
        self.logger.info(f"Model saved to {filepath}")
        return str(filepath)

    def load(self, filepath: str) -> bool:
        """Load model from disk.

        Args:
            filepath: Path to model file

        Returns:
            True if loaded successfully
        """
        try:
            if not os.path.exists(filepath):
                self.logger.error(f"Model file not found: {filepath}")
                return False

            model_data = joblib.load(filepath)

            self.best_model = model_data['model']
            self.scaler = model_data['scaler']
            self.label_encoder = model_data.get('label_encoder', LabelEncoder())
            self.feature_names = model_data['feature_names']
            self.strategy_names = model_data['strategy_names']
            self.is_trained = model_data['is_trained']
            self.training_metadata = model_data.get('training_metadata', {})
            self.best_model_name = model_data.get('best_model_name', 'unknown')

            self.logger.info(f"Model loaded from {filepath}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")
            return False

    def get_latest_model_path(self) -> Optional[str]:
        """Get path to latest saved model.

        Returns:
            Path to latest model or None
        """
        model_files = list(self.model_dir.glob("strategy_selector_*.joblib"))
        if not model_files:
            return None

        # Sort by modification time
        latest = max(model_files, key=lambda p: p.stat().st_mtime)
        return str(latest)


# Backwards-compatible alias
class StrategySelectorModel(ModelSelector):
    """Backwards-compatible alias for ModelSelector.

    Now uses automated model selection instead of just RandomForest.
    """

    def __init__(self, model_type: str = "classifier"):
        """Initialize model.

        Args:
            model_type: Model type (only 'classifier' supported now)
        """
        super().__init__()
        self.model_type = model_type
        if model_type != "classifier":
            self.logger.warning("Only 'classifier' model type is supported. Using classifier.")
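A minimal end-to-end sketch of the selector above on synthetic data; the feature names, labels, and sizes are invented for illustration, and actual results depend on which optional boosters are installed:

import numpy as np
import pandas as pd

# Synthetic training set: 300 rows of 3 market features, 2 strategy labels.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(300, 3)), columns=['rsi', 'adx', 'atr_percent'])
y = np.where(X['adx'] > 0, 'trend_following', 'mean_reversion')

selector = ModelSelector()
metrics = selector.train(X, y, strategy_names=['trend_following', 'mean_reversion'])
print(metrics['best_model'], metrics['cv_mean_accuracy'])

# Predict for one set of market conditions; missing features default to 0.0.
best, confidence, probs = selector.predict({'rsi': 55.0, 'adx': 1.2})
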
351
src/autopilot/performance_tracker.py
Normal file
@@ -0,0 +1,351 @@
"""Strategy performance tracker for ML training data collection."""

from decimal import Decimal
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
import pandas as pd
import numpy as np

from src.core.database import get_database, Trade, Strategy
from src.core.logger import get_logger
from .market_analyzer import MarketConditions

logger = get_logger(__name__)


class PerformanceTracker:
    """Tracks strategy performance for ML training."""

    def __init__(self):
        """Initialize performance tracker."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    async def record_trade(
        self,
        strategy_name: str,
        market_conditions: MarketConditions,
        trade_result: Dict[str, Any]
    ) -> bool:
        """Record a trade for ML training.

        Args:
            strategy_name: Name of strategy used
            market_conditions: Market conditions at trade time
            trade_result: Trade result with performance metrics

        Returns:
            True if recorded successfully
        """
        try:
            async with self.db.get_session() as session:
                try:
                    # Store market conditions snapshot
                    from src.core.database import MarketConditionsSnapshot

                    snapshot = MarketConditionsSnapshot(
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        regime=market_conditions.regime.value,
                        features=market_conditions.features,
                        strategy_name=strategy_name,
                        timestamp=market_conditions.timestamp
                    )
                    session.add(snapshot)

                    # Store performance record
                    from src.core.database import StrategyPerformance

                    performance = StrategyPerformance(
                        strategy_name=strategy_name,
                        symbol=market_conditions.symbol,
                        timeframe=market_conditions.timeframe,
                        market_regime=market_conditions.regime.value,
                        return_pct=float(trade_result.get('return_pct', 0.0)),
                        sharpe_ratio=float(trade_result.get('sharpe_ratio', 0.0)),
                        win_rate=float(trade_result.get('win_rate', 0.0)),
                        max_drawdown=float(trade_result.get('max_drawdown', 0.0)),
                        trade_count=int(trade_result.get('trade_count', 1)),
                        timestamp=datetime.utcnow()
                    )
                    session.add(performance)

                    await session.commit()
                    return True

                except Exception as e:
                    await session.rollback()
                    self.logger.error(f"Failed to record trade: {e}")
                    return False

        except Exception as e:
            self.logger.error(f"Error recording trade: {e}")
            return False

    async def get_performance_history(
        self,
        strategy_name: Optional[str] = None,
        market_regime: Optional[str] = None,
        days: int = 30
    ) -> pd.DataFrame:
        """Get performance history for training.

        Args:
            strategy_name: Filter by strategy name
            market_regime: Filter by market regime
            days: Number of days to look back

        Returns:
            DataFrame with performance history
        """
        from sqlalchemy import select
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance, MarketConditionsSnapshot

                # Query performance records
                stmt = select(StrategyPerformance)

                if strategy_name:
                    stmt = stmt.where(StrategyPerformance.strategy_name == strategy_name)
                if market_regime:
                    stmt = stmt.where(StrategyPerformance.market_regime == market_regime)

                cutoff_date = datetime.utcnow() - timedelta(days=days)
                stmt = stmt.where(StrategyPerformance.timestamp >= cutoff_date)
                # Limit to prevent excessive queries - if we have lots of data, sample it
                stmt = stmt.order_by(StrategyPerformance.timestamp.desc()).limit(10000)

                result = await session.execute(stmt)
                records = result.scalars().all()

                if not records:
                    return pd.DataFrame()

                self.logger.info(f"Processing {len(records)} performance records for training data")

                # Convert to DataFrame - cache snapshot lookups to reduce the N+1 query problem
                data = []
                snapshot_cache = {}
                for record in records:
                    cache_key = f"{record.strategy_name}:{record.symbol}:{record.timeframe}:{record.timestamp.date()}"

                    if cache_key not in snapshot_cache:
                        # Get corresponding market conditions (only once per day per strategy)
                        snapshot_stmt = select(MarketConditionsSnapshot).filter_by(
                            strategy_name=record.strategy_name,
                            symbol=record.symbol,
                            timeframe=record.timeframe
                        ).where(
                            MarketConditionsSnapshot.timestamp <= record.timestamp
                        ).order_by(
                            MarketConditionsSnapshot.timestamp.desc()
                        ).limit(1)

                        snapshot_result = await session.execute(snapshot_stmt)
                        snapshot = snapshot_result.scalar_one_or_none()
                        snapshot_cache[cache_key] = snapshot
                    else:
                        snapshot = snapshot_cache[cache_key]

                    row = {
                        'strategy_name': record.strategy_name,
                        'symbol': record.symbol,
                        'timeframe': record.timeframe,
                        'market_regime': record.market_regime,
                        'return_pct': float(record.return_pct),
                        'sharpe_ratio': float(record.sharpe_ratio),
                        'win_rate': float(record.win_rate),
                        'max_drawdown': float(record.max_drawdown),
                        'trade_count': int(record.trade_count),
                        'timestamp': record.timestamp
                    }

                    # Add market features if available
                    if snapshot and snapshot.features:
                        row.update(snapshot.features)

                    data.append(row)

                return pd.DataFrame(data)

        except Exception as e:
            self.logger.error(f"Error getting performance history: {e}")
            return pd.DataFrame()

    async def calculate_metrics(
        self,
        strategy_name: str,
        period_days: int = 30
    ) -> Dict[str, Any]:
        """Calculate performance metrics for a strategy.

        Args:
            strategy_name: Strategy name
            period_days: Period to calculate metrics over

        Returns:
            Dictionary of performance metrics
        """
        from sqlalchemy import select
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=period_days)

                stmt = select(StrategyPerformance).where(
                    StrategyPerformance.strategy_name == strategy_name,
                    StrategyPerformance.timestamp >= cutoff_date
                )

                result = await session.execute(stmt)
                records = result.scalars().all()

                if not records:
                    return {
                        'total_trades': 0,
                        'win_rate': 0.0,
                        'avg_return': 0.0,
                        'total_return': 0.0,
                        'sharpe_ratio': 0.0,
                        'max_drawdown': 0.0
                    }

                returns = [float(r.return_pct) for r in records]
                wins = [r for r in returns if r > 0]

                total_return = sum(returns)
                avg_return = np.mean(returns) if returns else 0.0
                win_rate = len(wins) / len(returns) if returns else 0.0

                # Calculate Sharpe ratio (simplified, annualized over 252 periods)
                if len(returns) > 1:
                    sharpe = (avg_return / np.std(returns)) * np.sqrt(252) if np.std(returns) > 0 else 0.0
                else:
                    sharpe = 0.0

                # Max drawdown
                max_dd = max([float(r.max_drawdown) for r in records]) if records else 0.0

                return {
                    'total_trades': len(records),
                    'win_rate': float(win_rate),
                    'avg_return': float(avg_return),
                    'total_return': float(total_return),
                    'sharpe_ratio': float(sharpe),
                    'max_drawdown': float(max_dd)
                }

        except Exception as e:
            self.logger.error(f"Error calculating metrics: {e}")
            return {}

    async def prepare_training_data(
        self,
        min_samples_per_strategy: int = 10
    ) -> Optional[Dict[str, Any]]:
        """Prepare training data for ML model.

        Args:
            min_samples_per_strategy: Minimum samples required per strategy

        Returns:
            Dictionary with 'X' (features), 'y' (targets), and 'strategy_names'
        """
        try:
            # Get all performance history - extended to 365 days for better training
            df = await self.get_performance_history(days=365)

            if df.empty:
                self.logger.warning("No performance history available for training")
                return None

            # Filter strategies with enough samples
            strategy_counts = df['strategy_name'].value_counts()
            valid_strategies = strategy_counts[strategy_counts >= min_samples_per_strategy].index.tolist()

            if not valid_strategies:
                self.logger.warning(f"No strategies with at least {min_samples_per_strategy} samples")
                return None

            df = df[df['strategy_name'].isin(valid_strategies)]

            # Extract features (exclude metadata columns)
            feature_cols = [col for col in df.columns if col not in [
                'strategy_name', 'symbol', 'timeframe', 'market_regime',
                'return_pct', 'sharpe_ratio', 'win_rate', 'max_drawdown',
                'trade_count', 'timestamp'
            ]]

            # Coerce features to numeric and fill missing values with 0
            for col in feature_cols:
                df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)

            X = df[feature_cols]
            y = df['strategy_name'].values

            return {
                'X': X,
                'y': y,
                'strategy_names': valid_strategies,
                'feature_names': feature_cols,
                'training_symbols': sorted(df['symbol'].unique().tolist())
            }

        except Exception as e:
            self.logger.error(f"Error preparing training data: {e}")
            return None

    async def get_strategy_sample_counts(self, days: int = 365) -> Dict[str, int]:
        """Get sample counts per strategy.

        Args:
            days: Number of days to look back

        Returns:
            Dictionary mapping strategy names to sample counts
        """
        from sqlalchemy import select, func
        try:
            async with self.db.get_session() as session:
                from src.core.database import StrategyPerformance

                cutoff_date = datetime.utcnow() - timedelta(days=days)

                # Group by strategy name and count
                stmt = select(
                    StrategyPerformance.strategy_name,
                    func.count(StrategyPerformance.id)
                ).where(
                    StrategyPerformance.timestamp >= cutoff_date
                ).group_by(
                    StrategyPerformance.strategy_name
                )

                result = await session.execute(stmt)
                counts = dict(result.all())

                return counts

        except Exception as e:
            self.logger.error(f"Error getting strategy sample counts: {e}")
            return {}


# Global instance
_performance_tracker: Optional[PerformanceTracker] = None


def get_performance_tracker() -> PerformanceTracker:
    """Get global performance tracker instance."""
    global _performance_tracker
    if _performance_tracker is None:
        _performance_tracker = PerformanceTracker()
    return _performance_tracker
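The tracker is async throughout, so callers drive it from an event loop. A hedged sketch of the collection-to-training handoff; the `MarketConditions` value would come from the market analyzer, and the commented field values are placeholders:

import asyncio

async def main():
    tracker = get_performance_tracker()

    # conditions would come from get_market_analyzer().analyze_current_conditions(...)
    # and trade_result mirrors the keys read by record_trade():
    # await tracker.record_trade("rsi", conditions, {
    #     'return_pct': 0.012, 'sharpe_ratio': 1.1,
    #     'win_rate': 0.6, 'max_drawdown': 0.05, 'trade_count': 3,
    # })

    data = await tracker.prepare_training_data(min_samples_per_strategy=10)
    if data is not None:
        print(data['X'].shape, len(data['strategy_names']))

asyncio.run(main())
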
190
src/autopilot/strategy_groups.py
Normal file
@@ -0,0 +1,190 @@
"""Strategy grouping for improved ML accuracy.

Groups 14 individual strategies into 5 logical groups to reduce classification
complexity and improve model accuracy.
"""

from typing import Dict, List, Optional, Tuple
from enum import Enum

from src.core.logger import get_logger

logger = get_logger(__name__)


class StrategyGroup(str, Enum):
    """Strategy group classifications."""
    TREND_FOLLOWING = "trend_following"
    MEAN_REVERSION = "mean_reversion"
    MOMENTUM = "momentum"
    MARKET_MAKING = "market_making"
    SENTIMENT_BASED = "sentiment_based"


# Map each strategy to its group
STRATEGY_TO_GROUP: Dict[str, StrategyGroup] = {
    # Trend Following: Follow established trends
    "moving_average": StrategyGroup.TREND_FOLLOWING,
    "macd": StrategyGroup.TREND_FOLLOWING,
    "confirmed": StrategyGroup.TREND_FOLLOWING,

    # Mean Reversion: Bet on price returning to mean
    "rsi": StrategyGroup.MEAN_REVERSION,
    "bollinger_mean_reversion": StrategyGroup.MEAN_REVERSION,
    "grid": StrategyGroup.MEAN_REVERSION,
    "divergence": StrategyGroup.MEAN_REVERSION,

    # Momentum: Capture fast price moves
    "momentum": StrategyGroup.MOMENTUM,
    "volatility_breakout": StrategyGroup.MOMENTUM,

    # Market Making: Profit from bid-ask spread
    "market_making": StrategyGroup.MARKET_MAKING,
    "dca": StrategyGroup.MARKET_MAKING,

    # Sentiment Based: Use external signals
    "sentiment": StrategyGroup.SENTIMENT_BASED,
    "pairs_trading": StrategyGroup.SENTIMENT_BASED,
    "consensus": StrategyGroup.SENTIMENT_BASED,
}

# Reverse mapping: group -> list of strategies
GROUP_TO_STRATEGIES: Dict[StrategyGroup, List[str]] = {}
for strategy, group in STRATEGY_TO_GROUP.items():
    if group not in GROUP_TO_STRATEGIES:
        GROUP_TO_STRATEGIES[group] = []
    GROUP_TO_STRATEGIES[group].append(strategy)


def get_strategy_group(strategy_name: str) -> Optional[StrategyGroup]:
    """Get the group a strategy belongs to.

    Args:
        strategy_name: Name of the strategy (case-insensitive)

    Returns:
        StrategyGroup or None if strategy not found
    """
    return STRATEGY_TO_GROUP.get(strategy_name.lower())


def get_strategies_in_group(group: StrategyGroup) -> List[str]:
    """Get all strategies belonging to a group.

    Args:
        group: Strategy group

    Returns:
        List of strategy names in the group
    """
    return GROUP_TO_STRATEGIES.get(group, [])


def get_all_groups() -> List[StrategyGroup]:
    """Get all available strategy groups.

    Returns:
        List of all strategy groups
    """
    return list(StrategyGroup)


def get_best_strategy_in_group(
    group: StrategyGroup,
    market_features: Dict[str, float],
    available_strategies: Optional[List[str]] = None
) -> Tuple[str, float]:
    """Select the best individual strategy within a group based on market conditions.

    Uses rule-based heuristics to pick the optimal strategy when the ML model
    has predicted a group. See the usage sketch after this function.

    Args:
        group: The predicted strategy group
        market_features: Current market condition features
        available_strategies: Optional list of available strategies (filters choices)

    Returns:
        Tuple of (best_strategy_name, confidence_score)
    """
    strategies = get_strategies_in_group(group)

    # Filter by availability if provided
    if available_strategies:
        strategies = [s for s in strategies if s in available_strategies]

    if not strategies:
        # Fallback if no strategies available in group
        logger.warning(f"No strategies available in group {group}")
        return ("rsi", 0.5)  # Safe default

    # If only one option, return it
    if len(strategies) == 1:
        return (strategies[0], 0.7)

    # Rule-based selection within group based on market features
    rsi = market_features.get("rsi", 50.0)
    adx = market_features.get("adx", 25.0)
    volatility = market_features.get("atr_percent", 2.0)
    volume_ratio = market_features.get("volume_ratio", 1.0)

    if group == StrategyGroup.TREND_FOLLOWING:
        # Strong trends: use confirmed, moderate: moving_average, diverging: macd
        if adx > 30:
            return ("confirmed", 0.75) if "confirmed" in strategies else ("moving_average", 0.7)
        elif adx > 20:
            return ("moving_average", 0.7) if "moving_average" in strategies else ("macd", 0.65)
        else:
            return ("macd", 0.6) if "macd" in strategies else (strategies[0], 0.55)

    elif group == StrategyGroup.MEAN_REVERSION:
        # Extreme RSI: rsi strategy, tight range: bollinger, low vol: grid
        if rsi < 30 or rsi > 70:
            return ("rsi", 0.75) if "rsi" in strategies else ("bollinger_mean_reversion", 0.7)
        elif volatility < 1.5:
            return ("grid", 0.7) if "grid" in strategies else ("bollinger_mean_reversion", 0.65)
        else:
            return ("bollinger_mean_reversion", 0.65) if "bollinger_mean_reversion" in strategies else (strategies[0], 0.6)

    elif group == StrategyGroup.MOMENTUM:
        # High volume spike: volatility_breakout, otherwise momentum
        if volume_ratio > 1.5:
            return ("volatility_breakout", 0.75) if "volatility_breakout" in strategies else ("momentum", 0.7)
        else:
            return ("momentum", 0.7) if "momentum" in strategies else (strategies[0], 0.65)

    elif group == StrategyGroup.MARKET_MAKING:
        # Low volatility: market_making, otherwise dca
        if volatility < 2.0:
            return ("market_making", 0.7) if "market_making" in strategies else ("dca", 0.65)
        else:
            return ("dca", 0.7) if "dca" in strategies else (strategies[0], 0.6)

    elif group == StrategyGroup.SENTIMENT_BASED:
        # Default to sentiment, fall back to consensus
        if "sentiment" in strategies:
            return ("sentiment", 0.65)
        elif "consensus" in strategies:
            return ("consensus", 0.6)
        else:
            return (strategies[0], 0.55)

    # Fallback
    return (strategies[0], 0.5)


def convert_strategy_to_group_label(strategy_name: str) -> str:
    """Convert a strategy name to its group label for ML training.

    Args:
        strategy_name: Individual strategy name

    Returns:
        Group label string (e.g., "trend_following")
    """
    group = get_strategy_group(strategy_name)
    if group:
        return group.value
    else:
        logger.warning(f"Strategy '{strategy_name}' not in any group, using as-is")
        return strategy_name
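A short sketch of both directions of the mapping, using only the definitions above; the label lists and feature values are illustrative:

# Collapsing individual strategy labels into group labels before training:
raw_labels = ['rsi', 'macd', 'grid', 'momentum', 'dca']
group_labels = [convert_strategy_to_group_label(s) for s in raw_labels]
# -> ['mean_reversion', 'trend_following', 'mean_reversion',
#     'momentum', 'market_making']

# Resolving a predicted group back to a concrete strategy:
best, conf = get_best_strategy_in_group(
    StrategyGroup.MEAN_REVERSION,
    {'rsi': 25.0, 'atr_percent': 1.2},   # extreme RSI -> ('rsi', 0.75)
    available_strategies=['rsi', 'grid'],
)
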
581
src/autopilot/strategy_selector.py
Normal file
@@ -0,0 +1,581 @@
"""ML-based strategy selector for intelligent autopilot."""

import asyncio
from typing import Dict, Any, Optional, List, Tuple

from src.core.logger import get_logger
from src.core.config import get_config
from .market_analyzer import MarketConditions, MarketRegime, get_market_analyzer
from .models import StrategySelectorModel
from .performance_tracker import get_performance_tracker
from .strategy_groups import (
    StrategyGroup, get_strategy_group, get_strategies_in_group,
    get_best_strategy_in_group, convert_strategy_to_group_label, get_all_groups
)
from src.strategies import get_strategy_registry

logger = get_logger(__name__)


class StrategySelector:
    """ML-based strategy selector."""

    def __init__(self):
        """Initialize strategy selector."""
        self.logger = get_logger(__name__)
        self.config = get_config()
        self.market_analyzer = get_market_analyzer()
        self.performance_tracker = get_performance_tracker()
        self.model = StrategySelectorModel(model_type="classifier")
        self.strategy_registry = get_strategy_registry()

        # Bootstrap configuration - improved defaults for better ML accuracy
        self.bootstrap_days = self.config.get("autopilot.intelligent.bootstrap.days", 365)
        self.bootstrap_timeframe = self.config.get("autopilot.intelligent.bootstrap.timeframe", "1h")
        self.min_samples_per_strategy = self.config.get("autopilot.intelligent.bootstrap.min_samples_per_strategy", 10)
        self.bootstrap_symbols = self.config.get("autopilot.intelligent.bootstrap.symbols", ["BTC/USD", "ETH/USD"])
        self.bootstrap_timeframes = ["15m", "1h", "4h"]  # Multi-timeframe for more data

        self._available_strategies: List[str] = []
        self._load_available_strategies()

        self._last_model_ts = 0.0

        # Auto-load persisted model if available
        self._try_load_saved_model()

    def _try_load_saved_model(self):
        """Try to load a previously saved model."""
        import os
        try:
            latest_model = self.model.get_latest_model_path()
            if latest_model:
                mtime = os.path.getmtime(latest_model)
                # Only load if it's new or we haven't loaded anything
                if mtime > self._last_model_ts:
                    if self.model.load(latest_model):
                        self._last_model_ts = mtime
                        self.logger.info(f"Loaded persisted model: {latest_model}")
                    else:
                        self.logger.warning(f"Failed to load model {latest_model}")
            else:
                if self._last_model_ts == 0:
                    self.logger.info("No saved model found, model needs training")
        except Exception as e:
            self.logger.warning(f"Failed to load saved model: {e}")

    def _load_available_strategies(self):
        """Load list of available strategies."""
        self._available_strategies = self.strategy_registry.list_available()
        self.logger.info(f"Available strategies: {self._available_strategies}")

    async def train_model(
        self,
        min_samples_per_strategy: int = 10,
        force_retrain: bool = False
    ) -> Dict[str, Any]:
        """Train the ML model on historical performance data.

        Args:
            min_samples_per_strategy: Minimum samples required per strategy
            force_retrain: Force retraining even if model exists

        Returns:
            Training metrics dictionary
        """
        # Try to load existing model first
        if not force_retrain:
            latest_model = self.model.get_latest_model_path()
            if latest_model and self.model.load(latest_model):
                self.logger.info("Loaded existing trained model")
                return self.model.training_metadata.get('metrics', {})

        # Prepare training data
        training_data = await self.performance_tracker.prepare_training_data(
            min_samples_per_strategy=min_samples_per_strategy
        )

        if training_data is None:
            self.logger.warning("Insufficient training data. Using fallback rule-based selection.")
            return {}

        X = training_data['X']
        y = training_data['y']
        strategy_names = training_data['strategy_names']

        # Convert individual strategy names to group labels for better ML accuracy
        # This reduces the number of classes from 14 to 5
        y_groups = [convert_strategy_to_group_label(name) for name in y]
        group_names = [g.value for g in get_all_groups()]

        self.logger.info(f"Training with {len(set(y_groups))} strategy groups (from {len(strategy_names)} strategies)")

        if len(X) < min_samples_per_strategy * len(group_names):
            self.logger.warning("Insufficient training data. Using fallback rule-based selection.")
            return {}

        # Train model with ensemble mode for better accuracy
        # Ensemble combines XGBoost + LightGBM + RandomForest via voting
        # Now predicts GROUPS instead of individual strategies
        metrics = self.model.train(
            X,
            y_groups,  # Use group labels
            group_names,  # Use group names
            use_ensemble=True,
            training_symbols=training_data.get('training_symbols', [])
        )

        # Save model
        self.model.save()

        return metrics

    def select_best_strategy(
        self,
        market_conditions: MarketConditions,
        min_confidence: float = 0.5
    ) -> Optional[Tuple[str, float, Dict[str, float]]]:
        """Select best strategy for current market conditions.

        The ML model predicts a strategy GROUP, then we use rule-based logic
        to select the best individual strategy within that group.

        Args:
            market_conditions: Current market conditions
            min_confidence: Minimum confidence threshold

        Returns:
            Tuple of (strategy_name, confidence, all_predictions) or None
        """
        # If model not trained, use fallback
        if not self.model.is_trained:
            self.logger.debug("Model not trained, using rule-based fallback")
            return self._fallback_strategy_selection(market_conditions)

        try:
            # Predict strategy GROUP using ML model
            predicted_group, group_confidence, all_group_predictions = self.model.predict(
                market_conditions.features
            )

            if group_confidence < min_confidence:
                self.logger.debug(
                    f"ML group confidence {group_confidence:.2f} below threshold {min_confidence}, "
                    "using fallback"
                )
                return self._fallback_strategy_selection(market_conditions)

            # Convert group string to enum
            try:
                group_enum = StrategyGroup(predicted_group)
            except ValueError:
                self.logger.warning(f"Unknown group '{predicted_group}', using fallback")
                return self._fallback_strategy_selection(market_conditions)

            # Select best individual strategy within the predicted group
            best_strategy, strategy_confidence = get_best_strategy_in_group(
                group_enum,
                market_conditions.features,
                available_strategies=self._available_strategies
            )

            # Combined confidence: group confidence * strategy confidence
            combined_confidence = group_confidence * strategy_confidence

            # Build all_predictions dict with individual strategies
            all_predictions = {}
            for group_name, group_score in all_group_predictions.items():
                try:
                    group_e = StrategyGroup(group_name)
                    strategies_in_group = get_strategies_in_group(group_e)
                    for strat in strategies_in_group:
                        if strat in self._available_strategies:
                            all_predictions[strat] = group_score
                except ValueError:
                    pass

            self.logger.info(
                f"ML selected group: {predicted_group} (conf: {group_confidence:.2f}) "
                f"-> strategy: {best_strategy} (combined conf: {combined_confidence:.2f})"
            )

            return best_strategy, combined_confidence, all_predictions

        except Exception as e:
            self.logger.error(f"Error in ML prediction: {e}")
            return self._fallback_strategy_selection(market_conditions)

    def _fallback_strategy_selection(
        self,
        market_conditions: MarketConditions
    ) -> Optional[Tuple[str, float, Dict[str, float]]]:
        """Fallback rule-based strategy selection.

        Args:
            market_conditions: Current market conditions

        Returns:
            Tuple of (strategy_name, confidence, all_predictions)
        """
        features = market_conditions.features
        regime = market_conditions.regime

        # Rule-based selection based on market regime
        strategy_rules = {
            MarketRegime.TRENDING_UP: "moving_average",
            MarketRegime.TRENDING_DOWN: "moving_average",
            MarketRegime.RANGING: "rsi",
            MarketRegime.HIGH_VOLATILITY: "momentum",
            MarketRegime.LOW_VOLATILITY: "grid",
            MarketRegime.BREAKOUT: "momentum",
            MarketRegime.REVERSAL: "rsi"
        }

        # Map regime enum to string
        regime_str = regime.value if hasattr(regime, 'value') else str(regime)

        # Select strategy based on regime
        selected_strategy = strategy_rules.get(regime, "rsi")

        # Check if strategy is available
        if selected_strategy not in self._available_strategies:
            # Fallback to first available
            if self._available_strategies:
                selected_strategy = self._available_strategies[0]
            else:
                return None

        # Calculate confidence based on regime clarity
        rsi = features.get('rsi', 50.0)
        adx = features.get('adx', 0.0)

        # Higher confidence for clear signals
        if regime in [MarketRegime.TRENDING_UP, MarketRegime.TRENDING_DOWN] and adx > 25:
            confidence = 0.7
        elif regime == MarketRegime.RANGING and (rsi < 30 or rsi > 70):
            confidence = 0.65
        else:
            confidence = 0.5

        all_predictions = {selected_strategy: confidence}

        self.logger.info(
            f"Fallback selected strategy: {selected_strategy} "
            f"(confidence: {confidence:.2f}, regime: {regime_str})"
        )

        return selected_strategy, confidence, all_predictions

    def get_strategy_rankings(
        self,
        market_conditions: MarketConditions
    ) -> List[Tuple[str, float]]:
        """Get all strategies ranked by expected performance.

        Args:
            market_conditions: Current market conditions

        Returns:
            List of (strategy_name, score) tuples, sorted by score descending
        """
        result = self.select_best_strategy(market_conditions, min_confidence=0.0)

        if result is None:
            return []

        _, _, all_predictions = result

        # Sort by score descending
        rankings = sorted(
            all_predictions.items(),
            key=lambda x: x[1],
            reverse=True
        )

        return rankings

    def update_model(self, trade_result: Dict[str, Any]) -> bool:
        """Update model with new trade result (incremental learning).

        Args:
            trade_result: Trade result with performance metrics

        Returns:
            True if update successful
        """
        # For now, we'll retrain periodically rather than incremental updates
        # This can be enhanced later with online learning algorithms
        self.logger.debug("Model update requested (will retrain on next cycle)")
        return True

    async def bootstrap_training_data(
        self,
        symbol: str = "BTC/USD",
        timeframe: Optional[str] = None,
        days: Optional[int] = None,
        exchange_name: str = "Binance Public"
    ) -> Dict[str, Any]:
        """Bootstrap training data by running backtests on historical data.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe (defaults to config value)
            days: Number of days of historical data (defaults to config value)
            exchange_name: Exchange name

        Returns:
            Dictionary with bootstrap results
        """
        # Use config values as defaults
        if timeframe is None:
            timeframe = self.bootstrap_timeframe
        if days is None:
            days = self.bootstrap_days

        from src.backtesting.engine import BacktestingEngine
        from src.exchanges.public_data import PublicDataAdapter
        from src.data.collector import get_data_collector
        from src.core.database import get_database, MarketData
        from datetime import datetime, timedelta
        from decimal import Decimal
        from sqlalchemy import select, func

        self.logger.info(f"Bootstrapping training data for {symbol} ({timeframe})")

        db = get_database()

        # Step 1: Check and fetch historical data
        async with db.get_session() as session:
            try:
                end_date = datetime.utcnow()
                start_date = end_date - timedelta(days=days)

                # Check if we have data
                stmt = select(func.count()).select_from(MarketData).where(
                    MarketData.exchange == exchange_name,
                    MarketData.symbol == symbol,
                    MarketData.timeframe == timeframe,
                    MarketData.timestamp >= start_date
                )
                result = await session.execute(stmt)
                existing_data = result.scalar()

                if existing_data < 100:  # Need at least 100 candles
                    self.logger.info(f"Fetching historical data ({days} days)...")
                    adapter = PublicDataAdapter()
                    if await adapter.connect():
                        collector = get_data_collector()
                        current_date = start_date
                        chunk_days = 30

                        while current_date < end_date:
                            chunk_end = min(current_date + timedelta(days=chunk_days), end_date)
                            ohlcv = await adapter.get_ohlcv(
                                symbol=symbol,
                                timeframe=timeframe,
                                since=current_date,
                                limit=1000
                            )
                            if ohlcv:
                                await collector.store_ohlcv(exchange_name, symbol, timeframe, ohlcv)
                            current_date = chunk_end
                            # Throttle API requests without blocking the event loop
                            await asyncio.sleep(1)

                        await adapter.disconnect()
                        self.logger.info("Historical data fetched")
                    else:
                        self.logger.error("Failed to connect to exchange")
                        return {"error": "Failed to fetch historical data"}
            except Exception as e:
                self.logger.error(f"Error checking/fetching historical data: {e}")
                return {"error": f"Database error: {e}"}

        # Step 2: Run backtests for each strategy
        backtest_engine = BacktestingEngine()
        bootstrap_results = []

        self.logger.info(f"Available strategies for bootstrap: {self._available_strategies}")

        total_strategies = len(self._available_strategies)
        for strategy_idx, strategy_name in enumerate(self._available_strategies):
            try:
                strategy_class = self.strategy_registry._strategies.get(strategy_name.lower())
                if not strategy_class:
                    continue

                strategy = strategy_class(
                    name=strategy_name,
                    parameters={},
                    timeframes=[timeframe]
                )
                strategy.enabled = True

                self.logger.info(f"Running backtest for {strategy_name} ({strategy_idx + 1}/{total_strategies})...")

                # Run backtest
                results = await backtest_engine.run_backtest(
                    strategy=strategy,
                    symbol=symbol,
                    exchange=exchange_name,
                    timeframe=timeframe,
                    start_date=start_date,
                    end_date=end_date,
                    initial_capital=Decimal("10000.0"),
                    slippage=0.001
                )

                if "error" in results:
                    self.logger.warning(f"Backtest failed for {strategy_name}: {results['error']}")
                    continue

                self.logger.info(f"Backtest successful for {strategy_name}, return: {results.get('total_return', 0.0)}%")

                # Get market conditions for the period
                market_data = await backtest_engine._get_historical_data(
                    exchange_name, symbol, timeframe, start_date, end_date
                )

                if len(market_data) > 0:
                    # =====================================================
                    # IMPROVED SAMPLING: Regime-change detection + time spacing
                    # =====================================================
                    # This creates more diverse, independent training samples
                    # by sampling at meaningful market transitions rather than
                    # every N candles (which creates nearly-identical samples)

                    # Minimum time spacing between samples (in candles)
                    # For 1h: 24 candles = 24 hours
                    # For 15m: 96 candles = 24 hours
                    # For 4h: 6 candles = 24 hours
                    timeframe_to_spacing = {
                        "1m": 1440,  # 24 hours
                        "5m": 288,
                        "15m": 96,
                        "30m": 48,
                        "1h": 24,
                        "4h": 6,
                        "1d": 1
                    }
                    min_spacing = timeframe_to_spacing.get(timeframe, 24)

                    samples_recorded = 0
                    last_regime = None
                    last_sample_idx = -min_spacing  # Allow first sample immediately

                    # Limit processing to prevent excessive computation
                    # For small datasets (5 days), process all points but yield periodically
                    data_points = len(market_data) - 50
                    self.logger.info(f"Processing {data_points} data points for {strategy_name}...")

                    # Need at least 50 candles for feature calculation
                    for i in range(50, len(market_data)):
                        # Yield control periodically to prevent blocking (every 10 iterations)
                        if i % 10 == 0:
                            await asyncio.sleep(0)  # Yield to event loop

                        sample_data = market_data.iloc[i-50:i]
                        conditions = self.market_analyzer.analyze_current_conditions(
                            symbol, timeframe, sample_data
                        )
                        current_regime = conditions.regime

                        # Determine if we should sample at this point
                        # 1. Sample on regime change (market transition = valuable data)
                        # 2. Sample after minimum time spacing (ensure independence)
                        regime_changed = (last_regime is not None and current_regime != last_regime)
                        time_elapsed = (i - last_sample_idx) >= min_spacing

                        should_sample = regime_changed or (time_elapsed and last_regime is None) or (time_elapsed and i == 50)

                        # Also sample periodically even without regime change (every 2x min_spacing)
                        periodic_sample = (i - last_sample_idx) >= (min_spacing * 2)

                        if should_sample or periodic_sample:
                            # Record as training data
                            trade_result = {
                                'return_pct': results.get('total_return', 0.0) / 100.0,
                                'sharpe_ratio': results.get('sharpe_ratio', 0.0),
                                'win_rate': results.get('win_rate', 0.5),
                                'max_drawdown': abs(results.get('max_drawdown', 0.0)) / 100.0,
                                'trade_count': results.get('total_trades', 0)
                            }

                            await self.performance_tracker.record_trade(
                                strategy_name=strategy_name,
                                market_conditions=conditions,
                                trade_result=trade_result
                            )
                            samples_recorded += 1
                            last_sample_idx = i

                            if regime_changed:
                                self.logger.debug(
                                    f"Sampled at regime change: {last_regime} -> {current_regime}"
                                )

                        last_regime = current_regime

                    self.logger.info(f"Recorded {samples_recorded} samples for {strategy_name}")

                    bootstrap_results.append({
                        'strategy': strategy_name,
                        'trades_recorded': samples_recorded,
                        'backtest_return': results.get('total_return', 0.0)
                    })

            except Exception as e:
                self.logger.error(f"Error bootstrapping {strategy_name}: {e}", exc_info=True)
                continue

        total_samples = sum(r['trades_recorded'] for r in bootstrap_results)
        self.logger.info(f"Bootstrap complete: {total_samples} training samples created")

        return {
            'status': 'success',
            'strategies': bootstrap_results,
            'total_samples': total_samples
        }

def get_model_info(self) -> Dict[str, Any]:
|
||||
"""Get information about the current model.
|
||||
|
||||
Returns:
|
||||
Dictionary with model information
|
||||
"""
|
||||
# Always check if a newer model is available on disk
|
||||
self._try_load_saved_model()
|
||||
|
||||
info = {
|
||||
'is_trained': self.model.is_trained,
|
||||
'model_type': self.model.model_type,
|
||||
'available_strategies': self._available_strategies,
|
||||
'feature_count': len(self.model.feature_names) if self.model.is_trained else 0
|
||||
}
|
||||
|
||||
if self.model.is_trained:
|
||||
info.update({
|
||||
'training_metadata': self.model.training_metadata,
|
||||
'feature_importance': dict(
|
||||
sorted(
|
||||
self.model.get_feature_importance().items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:10] # Top 10 features
|
||||
)
|
||||
})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
# Global instance
|
||||
_strategy_selector: Optional[StrategySelector] = None
|
||||
|
||||
|
||||
def get_strategy_selector() -> StrategySelector:
|
||||
"""Get global strategy selector instance."""
|
||||
global _strategy_selector
|
||||
if _strategy_selector is None:
|
||||
_strategy_selector = StrategySelector()
|
||||
return _strategy_selector
|
||||
|
||||
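Editor's note: the sampling decision above mixes a regime-change trigger, a first-sample case, and a periodic fallback, which is easy to misread in-line. The sketch below restates just that decision logic as a standalone function with synthetic regime labels; the function name and data are illustrative, not part of this commit, and the 50-candle warm-up offset of the real loop is dropped for clarity.

def select_sample_indices(regimes, min_spacing):
    """Return the indices the bootstrap loop would record, one regime label per candle."""
    selected = []
    last_regime = None
    last_idx = -min_spacing
    for i, regime in enumerate(regimes):
        regime_changed = last_regime is not None and regime != last_regime
        first_sample = last_regime is None and (i - last_idx) >= min_spacing
        periodic = (i - last_idx) >= min_spacing * 2
        if regime_changed or first_sample or periodic:
            selected.append(i)
            last_idx = i
        last_regime = regime
    return selected

# 'ranging' for 30 candles, then 'trending': with min_spacing=24 this samples at
# index 0 (first sample), 30 (regime change), and 78 (periodic fallback at 2x spacing).
print(select_sample_indices(["ranging"] * 30 + ["trending"] * 70, min_spacing=24))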
0
src/backtesting/__init__.py
Normal file
51
src/backtesting/data_provider.py
Normal file
@@ -0,0 +1,51 @@
"""Historical data management for backtesting."""

from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, MarketData
from src.core.logger import get_logger

logger = get_logger(__name__)


class DataProvider:
    """Provides historical data for backtesting."""

    def __init__(self):
        """Initialize data provider."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    def get_data(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime
    ) -> List[MarketData]:
        """Get historical data.

        Args:
            exchange: Exchange name
            symbol: Trading symbol
            timeframe: Timeframe
            start_date: Start date
            end_date: End date

        Returns:
            List of MarketData objects
        """
        session = self.db.get_session()
        try:
            return session.query(MarketData).filter(
                MarketData.exchange == exchange,
                MarketData.symbol == symbol,
                MarketData.timeframe == timeframe,
                MarketData.timestamp >= start_date,
                MarketData.timestamp <= end_date
            ).order_by(MarketData.timestamp).all()
        finally:
            session.close()
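Editor's note: a minimal usage sketch for DataProvider; the symbol and date range are illustrative. One caveat: this module uses the synchronous Query API, while Database.get_session() elsewhere in this commit returns an AsyncSession, so a synchronous session factory is assumed for this to run as written.

from datetime import datetime
from src.backtesting.data_provider import DataProvider

provider = DataProvider()
candles = provider.get_data(
    exchange="kraken",
    symbol="BTC/USD",
    timeframe="1h",
    start_date=datetime(2024, 1, 1),
    end_date=datetime(2024, 1, 31),
)
# Each row is a MarketData ORM object, ordered by timestamp ascending
for md in candles[:3]:
    print(md.timestamp, md.open, md.close)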
207
src/backtesting/engine.py
Normal file
@@ -0,0 +1,207 @@
"""Backtesting engine with historical data replay and performance metrics."""

import pandas as pd
from decimal import Decimal
from datetime import datetime
from typing import Dict, List, Optional, Any
from sqlalchemy import select
from src.core.database import get_database, MarketData, OrderType
from src.core.logger import get_logger
from src.strategies.base import BaseStrategy
from src.trading.paper_trading import PaperTradingSimulator
from src.exchanges.factory import get_exchange
from .metrics import BacktestMetrics
from .slippage import FeeModel

logger = get_logger(__name__)


class BacktestingEngine:
    """Backtesting engine for strategy evaluation."""

    def __init__(self):
        """Initialize backtesting engine."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    async def run_backtest(
        self,
        strategy: BaseStrategy,
        symbol: str,
        exchange: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime,
        initial_capital: Decimal = Decimal("100.0"),
        slippage: float = 0.001,  # 0.1% slippage
        fee_model: Optional[FeeModel] = None,
        exchange_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run backtest on strategy.

        Args:
            strategy: Strategy instance
            symbol: Trading symbol
            exchange: Exchange name
            timeframe: Timeframe
            start_date: Start date
            end_date: End date
            initial_capital: Initial capital
            slippage: Slippage percentage
            fee_model: FeeModel instance (if None, creates default)
            exchange_id: Exchange ID for fee structure retrieval

        Returns:
            Backtest results dictionary
        """
        # Get historical data
        data = await self._get_historical_data(exchange, symbol, timeframe, start_date, end_date)

        if len(data) == 0:
            return {"error": "No historical data available"}

        # Initialize fee model
        if fee_model is None:
            # Try to get exchange adapter for fee structure
            exchange_adapter = None
            if exchange_id:
                try:
                    exchange_adapter = await get_exchange(exchange_id)
                except Exception as e:
                    self.logger.warning(f"Could not get exchange adapter for fees: {e}")

            fee_model = FeeModel(exchange_adapter=exchange_adapter)

        # Initialize paper trading simulator
        simulator = PaperTradingSimulator(initial_capital)
        await simulator.initialize()
        strategy.enabled = True

        # Run backtest
        trades = []
        total_fees = Decimal(0)

        for i, row in data.iterrows():
            price = Decimal(str(row['close']))

            # Generate signal
            signal = await strategy.on_tick(
                symbol,
                price,
                timeframe,
                {'open': row['open'], 'high': row['high'], 'low': row['low'], 'volume': row['volume']}
            )

            if signal and strategy.should_execute(signal):
                # Determine order type (default to MARKET for backtesting)
                # In real trading, strategies could specify limit orders
                order_type = OrderType.MARKET
                is_maker = (order_type == OrderType.LIMIT)

                # Execute trade with slippage and fees
                slippage_multiplier = (
                    (Decimal("1") + Decimal(str(slippage)))
                    if signal.signal_type.value == "buy"
                    else (Decimal("1") - Decimal(str(slippage)))
                )
                fill_price = price * slippage_multiplier
                quantity = signal.quantity or strategy.calculate_position_size(signal, simulator.get_balance(), fill_price)

                # Calculate fee using FeeModel with maker/taker distinction
                fee = fee_model.calculate_fee(quantity, fill_price, is_maker=is_maker)
                total_fees += fee

                # Create and execute order
                from src.core.database import Order as DBOrder, OrderSide
                from src.trading.order_manager import get_order_manager

                order_manager = get_order_manager()
                order = await order_manager.create_order(
                    exchange_id=exchange_id or 1,  # Use provided exchange_id or placeholder
                    strategy_id=None,
                    symbol=symbol,
                    order_type=order_type,
                    side=OrderSide.BUY if signal.signal_type.value == "buy" else OrderSide.SELL,
                    quantity=quantity,
                    paper_trading=True
                )

                if order and await simulator.execute_order(order, fill_price, fee):
                    trades.append({
                        'timestamp': i,
                        'price': float(fill_price),
                        'quantity': float(quantity),
                        'side': signal.signal_type.value,
                        'fee': float(fee),
                    })

        # Calculate metrics
        metrics = BacktestMetrics()
        results = metrics.calculate_metrics(simulator, trades, initial_capital, total_fees)

        # Add fee information to results
        results['total_fees'] = float(total_fees)
        results['fee_percentage'] = float((total_fees / initial_capital) * 100) if initial_capital > 0 else 0.0

        return results

    async def _get_historical_data(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime
    ) -> pd.DataFrame:
        """Get historical OHLCV data.

        Args:
            exchange: Exchange name
            symbol: Trading symbol
            timeframe: Timeframe
            start_date: Start date
            end_date: End date

        Returns:
            DataFrame with OHLCV data
        """
        async with self.db.get_session() as session:
            try:
                stmt = select(MarketData).filter(
                    MarketData.exchange == exchange,
                    MarketData.symbol == symbol,
                    MarketData.timeframe == timeframe,
                    MarketData.timestamp >= start_date,
                    MarketData.timestamp <= end_date
                ).order_by(MarketData.timestamp)

                result = await session.execute(stmt)
                market_data = result.scalars().all()

                if len(market_data) == 0:
                    return pd.DataFrame()

                data = {
                    'timestamp': [md.timestamp for md in market_data],
                    'open': [float(md.open) for md in market_data],
                    'high': [float(md.high) for md in market_data],
                    'low': [float(md.low) for md in market_data],
                    'close': [float(md.close) for md in market_data],
                    'volume': [float(md.volume) for md in market_data],
                }

                df = pd.DataFrame(data)
                df.set_index('timestamp', inplace=True)
                return df
            except Exception as e:
                logger.error(f"Failed to get historical data: {e}")
                return pd.DataFrame()


# Global backtesting engine
_backtesting_engine: Optional[BacktestingEngine] = None


def get_backtest_engine() -> BacktestingEngine:
    """Get global backtesting engine instance."""
    global _backtesting_engine
    if _backtesting_engine is None:
        _backtesting_engine = BacktestingEngine()
    return _backtesting_engine
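Editor's note: a minimal end-to-end sketch of running one backtest through the engine above. It assumes `strategy` is any BaseStrategy instance from this repo's registry; the symbol, exchange, and dates are illustrative.

import asyncio
from datetime import datetime
from decimal import Decimal
from src.backtesting.engine import get_backtest_engine

async def main(strategy):  # strategy: a configured BaseStrategy instance
    engine = get_backtest_engine()
    results = await engine.run_backtest(
        strategy=strategy,
        symbol="BTC/USD",
        exchange="kraken",
        timeframe="1h",
        start_date=datetime(2024, 1, 1),
        end_date=datetime(2024, 3, 1),
        initial_capital=Decimal("10000.0"),
        slippage=0.001,
    )
    if "error" in results:
        print("backtest failed:", results["error"])
    else:
        print(f"net return {results['total_return']:.2f}% "
              f"over {results['total_trades']} trades, fees {results['total_fees']:.2f}")

# asyncio.run(main(my_strategy))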
85
src/backtesting/metrics.py
Normal file
@@ -0,0 +1,85 @@
"""Performance metrics for backtesting."""

from decimal import Decimal
from typing import Dict, List, Any, Optional
from src.trading.paper_trading import PaperTradingSimulator


class BacktestMetrics:
    """Calculates backtest performance metrics."""

    def calculate_metrics(
        self,
        simulator: PaperTradingSimulator,
        trades: List[Dict[str, Any]],
        initial_capital: Decimal,
        total_fees: Optional[Decimal] = None
    ) -> Dict[str, Any]:
        """Calculate backtest metrics, including fee-adjusted metrics.

        Args:
            simulator: Paper trading simulator
            trades: List of executed trades (may include 'fee' field)
            initial_capital: Initial capital
            total_fees: Total fees paid (if None, calculated from trades)

        Returns:
            Dictionary of metrics
        """
        final_value = simulator.get_portfolio_value()

        # Calculate total fees if not provided
        if total_fees is None:
            total_fees = sum(
                Decimal(str(trade.get('fee', 0))) for trade in trades
            )

        # Net return (after fees) - fees are already deducted in the
        # simulator's portfolio value
        net_return = ((final_value - initial_capital) / initial_capital) * 100 if initial_capital > 0 else 0

        # Gross return (before fees): what the return would have been with
        # zero fees, for comparison
        value_without_fees = final_value + total_fees
        gross_return_without_fees = ((value_without_fees - initial_capital) / initial_capital) * 100 if initial_capital > 0 else 0

        # Fee impact (reduces to total_fees / initial_capital, as a percentage)
        fee_impact = gross_return_without_fees - net_return

        # Calculate win rate from trades
        win_rate = self._calculate_win_rate(trades)

        return {
            "initial_capital": float(initial_capital),
            "final_capital": float(final_value),
            "total_return": float(net_return),  # Net return (after fees)
            "gross_return": float(gross_return_without_fees),  # Gross return (before fees)
            "total_fees": float(total_fees),
            "fee_impact_percent": float(fee_impact),
            "fee_percentage": float((total_fees / initial_capital) * 100) if initial_capital > 0 else 0.0,
            "total_trades": len(trades),
            "win_rate": win_rate,
            "sharpe_ratio": 0.0,  # Placeholder - would need returns series
            "max_drawdown": 0.0,  # Placeholder - would need equity curve
        }

    def _calculate_win_rate(self, trades: List[Dict[str, Any]]) -> float:
        """Calculate win rate from trades.

        Args:
            trades: List of trades

        Returns:
            Win rate (0.0 to 1.0)
        """
        if len(trades) < 2:
            return 0.0

        # Simplified placeholder: a full implementation would pair buys with
        # sells per symbol and track positions; for now, return a neutral 0.5
        return 0.5
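Editor's note: the fee-adjusted return arithmetic above, worked on round numbers. The fee impact reduces algebraically to total_fees / initial_capital (as a percentage), which is a useful sanity check on calculate_metrics output; the figures below are illustrative.

from decimal import Decimal

initial = Decimal("10000")
final = Decimal("10450")  # simulator value, with fees already deducted
fees = Decimal("50")

net = (final - initial) / initial * 100            # 4.50% net return
gross = (final + fees - initial) / initial * 100   # 5.00% return had fees been zero
fee_impact = gross - net                           # 0.50% == fees / initial * 100
assert fee_impact == fees / initial * 100
print(net, gross, fee_impact)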
175
src/backtesting/slippage.py
Normal file
@@ -0,0 +1,175 @@
"""Slippage modeling for realistic backtesting."""

from decimal import Decimal
from typing import Dict, Optional, Any
from src.core.logger import get_logger
from src.exchanges.base import BaseExchangeAdapter

logger = get_logger(__name__)


class SlippageModel:
    """Models slippage for realistic backtesting."""

    def __init__(self, slippage_rate: float = 0.001):
        """Initialize slippage model.

        Args:
            slippage_rate: Slippage rate (0.001 = 0.1%)
        """
        self.slippage_rate = slippage_rate
        self.logger = get_logger(__name__)

    def calculate_fill_price(
        self,
        order_price: Decimal,
        side: str,  # "buy" or "sell"
        order_type: str,  # "market" or "limit"
        market_price: Decimal,
        volume: Decimal = Decimal(0)
    ) -> Decimal:
        """Calculate fill price with slippage.

        Args:
            order_price: Order price
            side: Buy or sell
            order_type: Market or limit
            market_price: Current market price
            volume: Order volume (for market impact)

        Returns:
            Fill price with slippage
        """
        if order_type == "limit":
            # Limit orders fill at order price (if market reaches it)
            return order_price

        # Market orders have slippage
        if side == "buy":
            # Buy orders pay more (slippage up)
            slippage = market_price * Decimal(str(self.slippage_rate))
            # Add market impact based on volume
            impact = market_price * Decimal(str(volume)) * Decimal("0.0001")  # Simplified
            return market_price + slippage + impact
        else:
            # Sell orders receive less (slippage down)
            slippage = market_price * Decimal(str(self.slippage_rate))
            impact = market_price * Decimal(str(volume)) * Decimal("0.0001")
            return market_price - slippage - impact


class FeeModel:
    """Models exchange fees with support for dynamic fee retrieval."""

    def __init__(
        self,
        maker_fee: Optional[float] = None,
        taker_fee: Optional[float] = None,
        exchange_adapter: Optional[BaseExchangeAdapter] = None,
        minimum_fee: float = 0.0
    ):
        """Initialize fee model.

        Args:
            maker_fee: Maker fee rate (if None, retrieved from exchange or default)
            taker_fee: Taker fee rate (if None, retrieved from exchange or default)
            exchange_adapter: Exchange adapter for dynamic fee retrieval
            minimum_fee: Minimum fee amount
        """
        self.exchange_adapter = exchange_adapter
        self.minimum_fee = Decimal(str(minimum_fee))
        self.logger = get_logger(__name__)

        # Set fees from parameters or retrieve from exchange
        if maker_fee is not None and taker_fee is not None:
            self.maker_fee = maker_fee
            self.taker_fee = taker_fee
        else:
            fee_structure = self._get_fee_structure()
            self.maker_fee = maker_fee if maker_fee is not None else fee_structure.get('maker', 0.001)
            self.taker_fee = taker_fee if taker_fee is not None else fee_structure.get('taker', 0.001)

    def _get_fee_structure(self) -> Dict[str, Any]:
        """Get fee structure from exchange adapter or defaults.

        Returns:
            Fee structure dictionary
        """
        if self.exchange_adapter:
            try:
                return self.exchange_adapter.get_fee_structure()
            except Exception as e:
                self.logger.warning(f"Failed to get fee structure from exchange: {e}")

        return {
            'maker': 0.001,  # 0.1%
            'taker': 0.001,  # 0.1%
            'minimum': 0.0
        }

    def calculate_fee(
        self,
        quantity: Decimal,
        price: Decimal,
        is_maker: bool = False
    ) -> Decimal:
        """Calculate trading fee.

        Args:
            quantity: Trade quantity
            price: Trade price
            is_maker: True if maker order

        Returns:
            Trading fee
        """
        if quantity <= 0 or price <= 0:
            return Decimal(0)

        trade_value = quantity * price
        fee_rate = self.maker_fee if is_maker else self.taker_fee
        fee = trade_value * Decimal(str(fee_rate))

        # Apply minimum fee
        if self.minimum_fee > 0 and fee < self.minimum_fee:
            fee = self.minimum_fee

        return fee

    def estimate_round_trip_fee(
        self,
        quantity: Decimal,
        price: Decimal
    ) -> Decimal:
        """Estimate total fees for a round-trip trade (buy + sell).

        Args:
            quantity: Trade quantity
            price: Trade price

        Returns:
            Total estimated round-trip fee
        """
        buy_fee = self.calculate_fee(quantity, price, is_maker=False)
        sell_fee = self.calculate_fee(quantity, price, is_maker=False)
        return buy_fee + sell_fee

    def get_minimum_profit_threshold(
        self,
        quantity: Decimal,
        price: Decimal,
        multiplier: float = 2.0
    ) -> Decimal:
        """Calculate minimum profit threshold needed to break even after fees.

        Args:
            quantity: Trade quantity
            price: Trade price
            multiplier: Multiplier for minimum profit (default 2.0 = 2x fees)

        Returns:
            Minimum profit threshold
        """
        round_trip_fee = self.estimate_round_trip_fee(quantity, price)
        return round_trip_fee * Decimal(str(multiplier))
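Editor's note: a short sketch combining the two models above -- the fill price for a market buy, then the break-even profit floor implied by round-trip taker fees. The quantity, price, and rates are illustrative defaults (0.1% slippage, 0.1% taker fee).

from decimal import Decimal
from src.backtesting.slippage import SlippageModel, FeeModel

slippage = SlippageModel(slippage_rate=0.001)
fees = FeeModel(maker_fee=0.001, taker_fee=0.001)

qty = Decimal("0.5")
market_price = Decimal("40000")

fill = slippage.calculate_fill_price(
    order_price=market_price, side="buy", order_type="market",
    market_price=market_price, volume=qty,
)
# Round-trip taker fees on this position, and the 2x-fee profit floor
round_trip = fees.estimate_round_trip_fee(qty, fill)
threshold = fees.get_minimum_profit_threshold(qty, fill, multiplier=2.0)
print(f"fill={fill} round_trip_fee={round_trip} min_profit={threshold}")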
0
src/core/__init__.py
Normal file
256
src/core/config.py
Normal file
@@ -0,0 +1,256 @@
"""Configuration management system with YAML and environment variables."""

import os
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


class Config:
    """Configuration manager with XDG directory support."""

    def __init__(self, config_file: Optional[str] = None):
        """Initialize configuration manager.

        Args:
            config_file: Optional path to config file. If None, uses XDG default.
        """
        self._setup_xdg_directories()

        # Determine config file priority:
        # 1. Explicit argument
        # 2. Local project config (dev mode)
        # 3. XDG config (user mode)

        local_config = Path(__file__).parent.parent.parent / "config" / "config.yaml"

        if config_file:
            self.config_file = Path(config_file)
        elif local_config.exists():
            self.config_file = local_config
        else:
            self.config_file = self.config_dir / "config.yaml"

        self._config: Dict[str, Any] = {}
        self._load_config()

    def _setup_xdg_directories(self):
        """Set up XDG Base Directory Specification directories."""
        home = Path.home()

        # XDG_CONFIG_HOME or default
        xdg_config = os.getenv("XDG_CONFIG_HOME", home / ".config")
        self.config_dir = Path(xdg_config) / "crypto_trader"
        self.config_dir.mkdir(parents=True, exist_ok=True)

        # XDG_DATA_HOME or default
        xdg_data = os.getenv("XDG_DATA_HOME", home / ".local" / "share")
        self.data_dir = Path(xdg_data) / "crypto_trader"
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Create subdirectories
        (self.data_dir / "historical").mkdir(exist_ok=True)
        (self.data_dir / "backups").mkdir(exist_ok=True)
        (self.data_dir / "logs").mkdir(exist_ok=True)

        # XDG_CACHE_HOME or default
        xdg_cache = os.getenv("XDG_CACHE_HOME", home / ".cache")
        self.cache_dir = Path(xdg_cache) / "crypto_trader"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load_config(self):
        """Load configuration from YAML file and environment variables."""
        # Load defaults
        default_config = self._get_default_config()
        self._config = default_config.copy()

        # Load from file if it exists (shallow merge over defaults)
        if self.config_file.exists():
            with open(self.config_file, 'r') as f:
                file_config = yaml.safe_load(f) or {}
                self._config.update(file_config)

        # Override with environment variables
        self._load_from_env()

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration."""
        return {
            "app": {
                "name": "Crypto Trader",
                "version": "0.1.0",
            },
            "database": {
                "type": "postgresql",
                "url": None,  # For PostgreSQL
            },
            "logging": {
                "level": os.getenv("LOG_LEVEL", "INFO"),
                "dir": str(self.data_dir / "logs"),
                "retention_days": 30,
                "rotation": "daily",
            },
            "paper_trading": {
                "enabled": True,
                "default_capital": float(os.getenv("PAPER_TRADING_CAPITAL", "100.0")),
            },
            "updates": {
                "check_on_startup": os.getenv("UPDATE_CHECK_ON_STARTUP", "true").lower() == "true",
                "repository_url": os.getenv("UPDATE_REPOSITORY_URL", ""),
            },
            "exchanges": {},
            "strategies": {
                "default_timeframe": "1h",
            },
            "risk": {
                "max_drawdown_percent": 20.0,
                "daily_loss_limit_percent": 5.0,
                "position_size_percent": 2.0,
            },
            "trading": {
                "default_fees": {
                    "maker": 0.001,  # 0.1%
                    "taker": 0.001,  # 0.1%
                    "minimum": 0.0,
                },
                "exchanges": {},
            },
            "data_providers": {
                "primary": [
                    {"name": "kraken", "enabled": True, "priority": 1},
                    {"name": "coinbase", "enabled": True, "priority": 2},
                    {"name": "binance", "enabled": True, "priority": 3},
                ],
                "fallback": {
                    "name": "coingecko",
                    "enabled": True,
                    "api_key": "",
                },
                "caching": {
                    "ticker_ttl": 2,  # seconds
                    "ohlcv_ttl": 60,  # seconds
                    "max_cache_size": 1000,
                },
                "websocket": {
                    "enabled": True,
                    "reconnect_interval": 5,  # seconds
                    "ping_interval": 30,  # seconds
                },
            },
            "redis": {
                "host": os.getenv("REDIS_HOST", "127.0.0.1"),
                "port": int(os.getenv("REDIS_PORT", 6379)),
                "db": int(os.getenv("REDIS_DB", 0)),
                "password": os.getenv("REDIS_PASSWORD", None),
                "socket_connect_timeout": 5,
            },
            "celery": {
                "broker_url": os.getenv("CELERY_BROKER_URL", "redis://127.0.0.1:6379/0"),
                "result_backend": os.getenv("CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0"),
            },
        }

    def _load_from_env(self):
        """Load configuration from environment variables."""
        # Database
        if db_url := os.getenv("DATABASE_URL"):
            self._config["database"]["url"] = db_url
            self._config["database"]["type"] = "postgresql"

        # Logging
        if log_level := os.getenv("LOG_LEVEL"):
            self._config["logging"]["level"] = log_level
        if log_dir := os.getenv("LOG_DIR"):
            self._config["logging"]["dir"] = log_dir

        # Paper trading
        if capital := os.getenv("PAPER_TRADING_CAPITAL"):
            self._config["paper_trading"]["default_capital"] = float(capital)

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value using dot notation.

        Args:
            key: Configuration key (e.g., "database.url")
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        keys = key.split(".")
        value = self._config
        for k in keys:
            if isinstance(value, dict):
                value = value.get(k)
                if value is None:
                    return default
            else:
                return default
        return value

    def set(self, key: str, value: Any):
        """Set configuration value using dot notation.

        Args:
            key: Configuration key (e.g., "database.url")
            value: Value to set
        """
        keys = key.split(".")
        config = self._config
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]
        config[keys[-1]] = value

    def save(self):
        """Save configuration to file."""
        with open(self.config_file, 'w') as f:
            yaml.dump(self._config, f, default_flow_style=False, sort_keys=False)

    @property
    def config_dir(self) -> Path:
        """Get config directory path."""
        return self._config_dir

    @config_dir.setter
    def config_dir(self, value: Path):
        """Set config directory."""
        self._config_dir = value

    @property
    def data_dir(self) -> Path:
        """Get data directory path."""
        return self._data_dir

    @data_dir.setter
    def data_dir(self, value: Path):
        """Set data directory."""
        self._data_dir = value

    @property
    def cache_dir(self) -> Path:
        """Get cache directory path."""
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, value: Path):
        """Set cache directory."""
        self._cache_dir = value


# Global config instance
_config_instance: Optional[Config] = None


def get_config() -> Config:
    """Get global configuration instance."""
    global _config_instance
    if _config_instance is None:
        _config_instance = Config()
    return _config_instance
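Editor's note: a minimal sketch of the dot-notation accessors above, assuming no config file overrides the defaults. get() falls back to the supplied default when any path segment is missing; the taker-fee override value is illustrative.

from src.core.config import get_config

cfg = get_config()
print(cfg.get("risk.max_drawdown_percent"))          # 20.0, from the defaults
print(cfg.get("exchanges.kraken.api_key", "unset"))  # "unset": path not present

cfg.set("trading.default_fees.taker", 0.0026)        # illustrative override
cfg.save()                                           # persists to the resolved config.yaml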
416
src/core/database.py
Normal file
@@ -0,0 +1,416 @@
"""Database connection and models using SQLAlchemy."""

from datetime import datetime
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Optional
from sqlalchemy import (
    create_engine, Column, Integer, String, Float, Boolean, DateTime,
    Text, ForeignKey, JSON, Enum as SQLEnum, Numeric
)
# declarative_base lives in sqlalchemy.orm as of SQLAlchemy 1.4; the old
# sqlalchemy.ext.declarative import is deprecated
from sqlalchemy.orm import sessionmaker, relationship, Session, declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .config import get_config

Base = declarative_base()


class OrderType(str, Enum):
    """Order type enumeration."""
    MARKET = "market"
    LIMIT = "limit"
    STOP_LOSS = "stop_loss"
    TAKE_PROFIT = "take_profit"
    TRAILING_STOP = "trailing_stop"
    OCO = "oco"
    ICEBERG = "iceberg"


class OrderSide(str, Enum):
    """Order side enumeration."""
    BUY = "buy"
    SELL = "sell"


class OrderStatus(str, Enum):
    """Order status enumeration."""
    PENDING = "pending"
    OPEN = "open"
    PARTIALLY_FILLED = "partially_filled"
    FILLED = "filled"
    CANCELLED = "cancelled"
    REJECTED = "rejected"
    EXPIRED = "expired"


class TradeType(str, Enum):
    """Trade type enumeration."""
    SPOT = "spot"
    FUTURES = "futures"
    MARGIN = "margin"


class Exchange(Base):
    """Exchange configuration and credentials."""
    __tablename__ = "exchanges"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)
    api_key_encrypted = Column(Text)  # Encrypted API key
    api_secret_encrypted = Column(Text)  # Encrypted API secret
    sandbox = Column(Boolean, default=False)
    read_only = Column(Boolean, default=True)  # Read-only mode
    enabled = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    trades = relationship("Trade", back_populates="exchange")
    orders = relationship("Order", back_populates="exchange")
    positions = relationship("Position", back_populates="exchange")


class Strategy(Base):
    """Strategy definitions and parameters."""
    __tablename__ = "strategies"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    strategy_type = Column(String(50))  # technical, momentum, grid, dca, etc.
    class_name = Column(String(100))  # Python class name
    parameters = Column(JSON)  # Strategy parameters
    timeframes = Column(JSON)  # Multi-timeframe configuration
    enabled = Column(Boolean, default=False)  # Available to Autopilot
    running = Column(Boolean, default=False)  # Currently running manually
    paper_trading = Column(Boolean, default=True)
    version = Column(String(20), default="1.0.0")
    schedule = Column(JSON)  # Scheduling configuration
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    trades = relationship("Trade", back_populates="strategy")
    backtest_results = relationship("BacktestResult", back_populates="strategy")


class Order(Base):
    """Order history with state tracking."""
    __tablename__ = "orders"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
    exchange_order_id = Column(String(100))  # Exchange's order ID
    symbol = Column(String(20), nullable=False)
    order_type = Column(SQLEnum(OrderType), nullable=False)
    side = Column(SQLEnum(OrderSide), nullable=False)
    status = Column(SQLEnum(OrderStatus), default=OrderStatus.PENDING)
    quantity = Column(Numeric(20, 8), nullable=False)
    price = Column(Numeric(20, 8))  # For limit orders
    filled_quantity = Column(Numeric(20, 8), default=0)
    average_fill_price = Column(Numeric(20, 8))
    fee = Column(Numeric(20, 8), default=0)
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    leverage = Column(Integer, default=1)  # For futures/margin
    paper_trading = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    filled_at = Column(DateTime)

    # Relationships
    exchange = relationship("Exchange", back_populates="orders")
    strategy = relationship("Strategy")
    trades = relationship("Trade", back_populates="order")


class Trade(Base):
    """All executed trades (paper and live)."""
    __tablename__ = "trades"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
    order_id = Column(Integer, ForeignKey("orders.id"), nullable=True)
    symbol = Column(String(20), nullable=False)
    side = Column(SQLEnum(OrderSide), nullable=False)
    quantity = Column(Numeric(20, 8), nullable=False)
    price = Column(Numeric(20, 8), nullable=False)
    fee = Column(Numeric(20, 8), default=0)
    total = Column(Numeric(20, 8), nullable=False)  # quantity * price + fee
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    paper_trading = Column(Boolean, default=True)
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)

    # Relationships
    exchange = relationship("Exchange", back_populates="trades")
    strategy = relationship("Strategy", back_populates="trades")
    order = relationship("Order", back_populates="trades")


class Position(Base):
    """Current open positions (spot and futures)."""
    __tablename__ = "positions"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    symbol = Column(String(20), nullable=False)
    side = Column(String(10))  # long, short
    quantity = Column(Numeric(20, 8), nullable=False)
    entry_price = Column(Numeric(20, 8), nullable=False)
    current_price = Column(Numeric(20, 8))
    unrealized_pnl = Column(Numeric(20, 8), default=0)
    realized_pnl = Column(Numeric(20, 8), default=0)
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    leverage = Column(Integer, default=1)
    margin = Column(Numeric(20, 8))
    paper_trading = Column(Boolean, default=True)
    opened_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    exchange = relationship("Exchange", back_populates="positions")


class PortfolioSnapshot(Base):
    """Historical portfolio values."""
    __tablename__ = "portfolio_snapshots"

    id = Column(Integer, primary_key=True)
    total_value = Column(Numeric(20, 8), nullable=False)
    cash = Column(Numeric(20, 8), nullable=False)
    positions_value = Column(Numeric(20, 8), nullable=False)
    unrealized_pnl = Column(Numeric(20, 8), default=0)
    realized_pnl = Column(Numeric(20, 8), default=0)
    paper_trading = Column(Boolean, default=True)
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)


class MarketData(Base):
    """OHLCV historical data (multiple timeframes)."""
    __tablename__ = "market_data"

    id = Column(Integer, primary_key=True)
    exchange = Column(String(50), nullable=False)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)  # 1m, 5m, 15m, 1h, 1d, etc.
    timestamp = Column(DateTime, nullable=False, index=True)
    open = Column(Numeric(20, 8), nullable=False)
    high = Column(Numeric(20, 8), nullable=False)
    low = Column(Numeric(20, 8), nullable=False)
    close = Column(Numeric(20, 8), nullable=False)
    volume = Column(Numeric(20, 8), nullable=False)


class BacktestResult(Base):
    """Backtesting results and metrics."""
    __tablename__ = "backtest_results"

    id = Column(Integer, primary_key=True)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=False)
    start_date = Column(DateTime, nullable=False)
    end_date = Column(DateTime, nullable=False)
    initial_capital = Column(Numeric(20, 8), nullable=False)
    final_capital = Column(Numeric(20, 8), nullable=False)
    total_return = Column(Numeric(10, 4))  # Percentage
    sharpe_ratio = Column(Numeric(10, 4))
    sortino_ratio = Column(Numeric(10, 4))
    max_drawdown = Column(Numeric(10, 4))
    win_rate = Column(Numeric(10, 4))
    total_trades = Column(Integer, default=0)
    parameters = Column(JSON)  # Parameters used in backtest
    metrics = Column(JSON)  # Additional metrics
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    strategy = relationship("Strategy", back_populates="backtest_results")


class RiskLimit(Base):
    """Risk management configuration."""
    __tablename__ = "risk_limits"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    limit_type = Column(String(50))  # max_drawdown, daily_loss, position_size, etc.
    value = Column(Numeric(10, 4), nullable=False)
    enabled = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class Alert(Base):
    """Alert definitions and history."""
    __tablename__ = "alerts"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    alert_type = Column(String(50))  # price, indicator, risk, system
    condition = Column(JSON)  # Alert condition configuration
    enabled = Column(Boolean, default=True)
    triggered = Column(Boolean, default=False)
    triggered_at = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class RebalancingEvent(Base):
    """Portfolio rebalancing history."""
    __tablename__ = "rebalancing_events"

    id = Column(Integer, primary_key=True)
    trigger_type = Column(String(50))  # time, threshold, manual
    target_allocations = Column(JSON)  # Target portfolio allocations
    before_allocations = Column(JSON)  # Allocations before rebalancing
    after_allocations = Column(JSON)  # Allocations after rebalancing
    orders_placed = Column(JSON)  # Orders placed for rebalancing
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)


class AppState(Base):
    """Application state for recovery."""
    __tablename__ = "app_state"

    id = Column(Integer, primary_key=True)
    key = Column(String(100), unique=True, nullable=False)
    value = Column(JSON)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class AuditLog(Base):
    """Security and action audit trail."""
    __tablename__ = "audit_log"

    id = Column(Integer, primary_key=True)
    action = Column(String(100), nullable=False)
    entity_type = Column(String(50))  # exchange, strategy, order, etc.
    entity_id = Column(Integer)
    details = Column(JSON)
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)


class MarketConditionsSnapshot(Base):
    """Market conditions snapshot for ML training."""
    __tablename__ = "market_conditions_snapshot"

    id = Column(Integer, primary_key=True)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)
    regime = Column(String(50))  # Market regime classification
    features = Column(JSON)  # Market condition features
    strategy_name = Column(String(100))  # Strategy used at this time
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)


class StrategyPerformance(Base):
    """Strategy performance records for ML training."""
    __tablename__ = "strategy_performance"

    id = Column(Integer, primary_key=True)
    strategy_name = Column(String(100), nullable=False, index=True)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)
    market_regime = Column(String(50), index=True)  # Market regime when trade executed
    return_pct = Column(Numeric(10, 4))  # Return percentage
    sharpe_ratio = Column(Numeric(10, 4))  # Sharpe ratio
    win_rate = Column(Numeric(5, 2))  # Win rate (0-100)
    max_drawdown = Column(Numeric(10, 4))  # Maximum drawdown
    trade_count = Column(Integer, default=1)  # Number of trades in this period
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)


class MLModelMetadata(Base):
    """ML model metadata and versions."""
    __tablename__ = "ml_model_metadata"

    id = Column(Integer, primary_key=True)
    model_name = Column(String(100), nullable=False)
    model_type = Column(String(50))  # classifier, regressor
    version = Column(String(20))
    file_path = Column(String(500))  # Path to saved model file
    training_metrics = Column(JSON)  # Training metrics (accuracy, MSE, etc.)
    feature_names = Column(JSON)  # List of feature names
    strategy_names = Column(JSON)  # List of strategy names
    training_samples = Column(Integer)  # Number of training samples
    trained_at = Column(DateTime, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)


class Database:
    """Database connection manager."""

    def __init__(self):
        """Initialize database connection."""
        self.config = get_config()
        self.engine = self._create_engine()
        self.SessionLocal = async_sessionmaker(
            bind=self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )
        # Tables are created via create_tables() or a migration tool (alembic)
        # rather than at construction time, since creation is async.

    def _create_engine(self):
        """Create database engine."""
        db_type = self.config.get("database.type", "postgresql")

        if db_type == "postgresql":
            db_url = self.config.get("database.url")
            if not db_url:
                raise ValueError("PostgreSQL URL not configured")
            # Ensure URL uses the async driver (e.g. postgresql+asyncpg)
            if "postgresql://" in db_url and "postgresql+asyncpg://" not in db_url:
                # Naive replacement; a production setup should parse the URL properly
                db_url = db_url.replace("postgresql://", "postgresql+asyncpg://")
            # Add connection timeouts to prevent hanging; the asyncpg connect
            # timeout is passed through connect_args
            return create_async_engine(
                db_url,
                echo=False,
                connect_args={
                    "server_settings": {"application_name": "crypto_trader"},
                    "timeout": 5,  # 5 second connection timeout
                },
                pool_pre_ping=True,  # Verify connections before using
                pool_recycle=3600,  # Recycle connections after 1 hour
                pool_timeout=5,  # Timeout when getting connection from pool
            )
        else:
            raise ValueError(f"Unsupported database type: {db_type}. Only 'postgresql' is supported.")

    async def create_tables(self):
        """Create all database tables."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    def get_session(self) -> AsyncSession:
        """Get a database session."""
        return self.SessionLocal()

    async def close(self):
        """Close database connection."""
        await self.engine.dispose()


# Global database instance
_db_instance: Optional[Database] = None


def get_database() -> Database:
    """Get global database instance."""
    global _db_instance
    if _db_instance is None:
        _db_instance = Database()
    return _db_instance
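Editor's note: a minimal async usage sketch for the Database manager above -- create tables once, then query recent candles. It assumes DATABASE_URL points at a PostgreSQL instance reachable with the asyncpg driver; the symbol is illustrative.

import asyncio
from sqlalchemy import select
from src.core.database import get_database, MarketData

async def main():
    db = get_database()
    await db.create_tables()  # idempotent; normally handled by a migration step
    async with db.get_session() as session:
        stmt = (
            select(MarketData)
            .filter(MarketData.symbol == "BTC/USD", MarketData.timeframe == "1h")
            .order_by(MarketData.timestamp.desc())
            .limit(5)
        )
        rows = (await session.execute(stmt)).scalars().all()
        for md in rows:
            print(md.timestamp, md.close)
    await db.close()

asyncio.run(main())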
128
src/core/logger.py
Normal file
@@ -0,0 +1,128 @@
"""Configurable logging system with XDG directory support."""

import logging
import logging.handlers
from pathlib import Path
from typing import Optional
from .config import get_config


class LoggingConfig:
    """Logging configuration manager."""

    def __init__(self):
        """Initialize logging configuration."""
        self.config = get_config()
        self.log_dir = Path(self.config.get("logging.dir", "~/.local/share/crypto_trader/logs")).expanduser()
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.retention_days = self.config.get("logging.retention_days", 30)
        self._setup_logging()

    def _setup_logging(self):
        """Set up logging configuration."""
        log_level = self.config.get("logging.level", "INFO")
        level = getattr(logging, log_level.upper(), logging.INFO)

        # Root logger
        root_logger = logging.getLogger()
        root_logger.setLevel(level)

        # Clear existing handlers
        root_logger.handlers.clear()

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        console_formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        console_handler.setFormatter(console_formatter)
        root_logger.addHandler(console_handler)

        # File handler with rotation
        log_file = self.log_dir / "crypto_trader.log"
        file_handler = logging.handlers.TimedRotatingFileHandler(
            log_file,
            when='midnight',
            interval=1,
            backupCount=self.retention_days,
            encoding='utf-8'
        )
        file_handler.setLevel(logging.DEBUG)
        file_formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(name)s:%(lineno)d: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_formatter)
        root_logger.addHandler(file_handler)

        # Compress old logs
        self._setup_log_compression()

    def _setup_log_compression(self):
        """Compress rotated log files that have not been compressed yet."""
        import gzip

        log_files = list(self.log_dir.glob("crypto_trader.log.*"))
        for log_file in log_files:
            if not log_file.name.endswith('.gz'):
                try:
                    with open(log_file, 'rb') as f_in:
                        with gzip.open(f"{log_file}.gz", 'wb') as f_out:
                            f_out.writelines(f_in)
                    log_file.unlink()
                except Exception:
                    pass  # Skip if compression fails

    def get_logger(self, name: str) -> logging.Logger:
        """Get a logger with the specified name.

        Args:
            name: Logger name (typically module name)

        Returns:
            Logger instance
        """
        logger = logging.getLogger(name)
        return logger

    def set_level(self, level: str):
        """Set logging level.

        Args:
            level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        """
        log_level = getattr(logging, level.upper(), logging.INFO)
        logging.getLogger().setLevel(log_level)
        for handler in logging.getLogger().handlers:
            handler.setLevel(log_level)


# Global logging config instance
_logging_config: Optional[LoggingConfig] = None


def get_logger(name: str) -> logging.Logger:
    """Get a logger instance.

    Args:
        name: Logger name (typically __name__)

    Returns:
        Logger instance
    """
    global _logging_config
    if _logging_config is None:
        _logging_config = LoggingConfig()
    return _logging_config.get_logger(name)


def setup_logging():
    """Set up logging system."""
    global _logging_config
    _logging_config = LoggingConfig()
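Editor's note: typical module-level usage of the logging helpers above. The first get_logger() call builds the LoggingConfig singleton, which installs the console and rotating file handlers as a side effect; the log message is illustrative.

from src.core.logger import get_logger, setup_logging

setup_logging()             # optional: force (re)initialization at startup
log = get_logger(__name__)
log.info("engine started")  # goes to console (INFO+) and the rotating file (DEBUG+)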
261
src/core/pubsub.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Redis Pub/Sub for real-time event broadcasting across workers."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable, Dict, Any, Optional, List
|
||||
from src.core.redis import get_redis_client
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Channel names
|
||||
CHANNEL_MARKET_EVENTS = "crypto_trader:market_events"
|
||||
CHANNEL_TRADE_EVENTS = "crypto_trader:trade_events"
|
||||
CHANNEL_SYSTEM_EVENTS = "crypto_trader:system_events"
|
||||
CHANNEL_AUTOPILOT_EVENTS = "crypto_trader:autopilot_events"
|
||||
|
||||
|
||||
class RedisPubSub:
|
||||
"""Redis Pub/Sub handler for real-time event broadcasting."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Redis Pub/Sub."""
|
||||
self.redis = get_redis_client()
|
||||
self._subscribers: Dict[str, List[Callable]] = {}
|
||||
self._pubsub = None
|
||||
self._running = False
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def publish(self, channel: str, event_type: str, data: Dict[str, Any]) -> int:
|
||||
"""Publish an event to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name
|
||||
event_type: Type of event (e.g., 'price_update', 'trade_executed')
|
||||
data: Event data
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
"""
|
||||
message = {
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
"timestamp": asyncio.get_event_loop().time()
|
||||
}
|
||||
|
||||
try:
|
||||
client = self.redis.get_client()
|
||||
count = await client.publish(channel, json.dumps(message))
|
||||
logger.debug(f"Published {event_type} to {channel} ({count} subscribers)")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish to {channel}: {e}")
|
||||
return 0
|
||||
|
||||
async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], None]) -> None:
|
||||
"""Subscribe to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name
|
||||
callback: Async function to call when message received
|
||||
"""
|
||||
if channel not in self._subscribers:
|
||||
self._subscribers[channel] = []
|
||||
self._subscribers[channel].append(callback)
|
||||
|
||||
logger.info(f"Subscribed to channel: {channel}")
|
||||
|
||||
# Start listener if not running
|
||||
if not self._running:
|
||||
await self._start_listener()
|
||||
|
||||
async def unsubscribe(self, channel: str, callback: Callable = None) -> None:
|
||||
"""Unsubscribe from a channel.
|
||||
|
||||
Args:
|
||||
            channel: Channel name
            callback: Specific callback to remove, or None to remove all
        """
        if channel in self._subscribers:
            if callback:
                self._subscribers[channel] = [c for c in self._subscribers[channel] if c != callback]
            else:
                del self._subscribers[channel]

        logger.info(f"Unsubscribed from channel: {channel}")

    async def _start_listener(self) -> None:
        """Start the Pub/Sub listener."""
        if self._running:
            return

        self._running = True
        self._listen_task = asyncio.create_task(self._listen())
        logger.info("Started Redis Pub/Sub listener")

    async def _listen(self) -> None:
        """Listen for messages on subscribed channels."""
        try:
            client = self.redis.get_client()
            self._pubsub = client.pubsub()

            # Subscribe to all registered channels
            channels = list(self._subscribers.keys())
            if channels:
                await self._pubsub.subscribe(*channels)

            while self._running:
                try:
                    message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

                    if message and message['type'] == 'message':
                        channel = message['channel']
                        if isinstance(channel, bytes):
                            channel = channel.decode('utf-8')

                        data = message['data']
                        if isinstance(data, bytes):
                            data = data.decode('utf-8')

                        try:
                            parsed = json.loads(data)
                        except json.JSONDecodeError:
                            parsed = {"raw": data}

                        # Call all subscribers for this channel
                        callbacks = self._subscribers.get(channel, [])
                        for callback in callbacks:
                            try:
                                if asyncio.iscoroutinefunction(callback):
                                    await callback(parsed)
                                else:
                                    callback(parsed)
                            except Exception as e:
                                logger.error(f"Subscriber callback error: {e}")

                    await asyncio.sleep(0.01)  # Prevent busy loop

                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Pub/Sub listener error: {e}")
                    await asyncio.sleep(1)

        finally:
            if self._pubsub:
                await self._pubsub.close()
            self._running = False

    async def stop(self) -> None:
        """Stop the Pub/Sub listener."""
        self._running = False
        if self._listen_task:
            self._listen_task.cancel()
            try:
                await self._listen_task
            except asyncio.CancelledError:
                pass
        logger.info("Stopped Redis Pub/Sub listener")

    # Convenience methods for common event types

    async def publish_price_update(self, symbol: str, price: float, bid: Optional[float] = None, ask: Optional[float] = None) -> int:
        """Publish a price update event.

        Args:
            symbol: Trading symbol
            price: Current price
            bid: Bid price
            ask: Ask price

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_MARKET_EVENTS, "price_update", {
            "symbol": symbol,
            "price": price,
            "bid": bid,
            "ask": ask
        })

    async def publish_trade_executed(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float,
        order_id: Optional[str] = None
    ) -> int:
        """Publish a trade execution event.

        Args:
            symbol: Trading symbol
            side: 'buy' or 'sell'
            quantity: Trade quantity
            price: Execution price
            order_id: Order ID

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_TRADE_EVENTS, "trade_executed", {
            "symbol": symbol,
            "side": side,
            "quantity": quantity,
            "price": price,
            "order_id": order_id
        })

    async def publish_autopilot_status(
        self,
        symbol: str,
        status: str,
        action: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None
    ) -> int:
        """Publish an autopilot status event.

        Args:
            symbol: Trading symbol
            status: 'started', 'stopped', 'error', 'signal'
            action: Optional action taken
            details: Additional details

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_AUTOPILOT_EVENTS, "autopilot_status", {
            "symbol": symbol,
            "status": status,
            "action": action,
            "details": details or {}
        })

    async def publish_system_event(self, event_type: str, message: str, severity: str = "info") -> int:
        """Publish a system event.

        Args:
            event_type: Event type (e.g., 'startup', 'shutdown', 'error')
            message: Event message
            severity: 'info', 'warning', 'error'

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_SYSTEM_EVENTS, event_type, {
            "message": message,
            "severity": severity
        })


# Global Pub/Sub instance
_redis_pubsub: Optional[RedisPubSub] = None


def get_redis_pubsub() -> RedisPubSub:
    """Get global Redis Pub/Sub instance."""
    global _redis_pubsub
    if _redis_pubsub is None:
        _redis_pubsub = RedisPubSub()
    return _redis_pubsub
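A quick usage sketch of the Pub/Sub class above (not part of the commit). The subscribe() method and the CHANNEL_MARKET_EVENTS constant are defined earlier in this file; the module path and the subscribe(channel, callback) signature are assumptions:

import asyncio

from src.core.pubsub import get_redis_pubsub, CHANNEL_MARKET_EVENTS  # assumed module path


async def demo():
    pubsub = get_redis_pubsub()

    async def on_market_event(event):
        # Payload shape is determined by publish(), defined earlier in the file
        print("market event:", event)

    pubsub.subscribe(CHANNEL_MARKET_EVENTS, on_market_event)
    notified = await pubsub.publish_price_update("BTC/USD", 64250.0, bid=64249.5, ask=64250.5)
    print(f"{notified} subscriber(s) notified")
    await pubsub.stop()


asyncio.run(demo())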
70
src/core/redis.py
Normal file
@@ -0,0 +1,70 @@
"""Redis client wrapper."""

import redis.asyncio as redis
from typing import Optional
from src.core.config import get_config
from src.core.logger import get_logger

logger = get_logger(__name__)


class RedisClient:
    """Redis client wrapper with automatic connection handling."""

    def __init__(self):
        """Initialize Redis client."""
        self.config = get_config()
        self._client: Optional[redis.Redis] = None
        self._pool: Optional[redis.ConnectionPool] = None

    def get_client(self) -> redis.Redis:
        """Get or create Redis client.

        Returns:
            Async Redis client
        """
        if self._client is None:
            self._connect()
        return self._client

    def _connect(self):
        """Connect to Redis."""
        redis_config = self.config.get("redis", {})
        host = redis_config.get("host", "localhost")
        port = redis_config.get("port", 6379)
        db = redis_config.get("db", 0)
        password = redis_config.get("password")

        logger.info(f"Connecting to Redis at {host}:{port}/{db}")

        try:
            self._pool = redis.ConnectionPool(
                host=host,
                port=port,
                db=db,
                password=password,
                decode_responses=True,
                socket_connect_timeout=redis_config.get("socket_connect_timeout", 5)
            )
            self._client = redis.Redis(connection_pool=self._pool)
        except Exception as e:
            logger.error(f"Failed to create Redis client: {e}")
            raise

    async def close(self):
        """Close Redis connection."""
        if self._client:
            await self._client.aclose()
            logger.info("Redis connection closed")


# Global instance
_redis_client: Optional[RedisClient] = None


def get_redis_client() -> RedisClient:
    """Get global Redis client instance."""
    global _redis_client
    if _redis_client is None:
        _redis_client = RedisClient()
    return _redis_client
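For reference, a minimal sketch (not part of the commit) of how the wrapper is driven; it relies only on the methods above plus standard redis-py asyncio commands:

import asyncio

from src.core.redis import get_redis_client  # assumed module path


async def demo():
    wrapper = get_redis_client()
    client = wrapper.get_client()
    await client.set("heartbeat", "ok", ex=10)  # plain redis-py asyncio API
    print(await client.get("heartbeat"))
    await wrapper.close()


asyncio.run(demo())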
99
src/core/repositories.py
Normal file
@@ -0,0 +1,99 @@
"""Database repositories for data access."""

from typing import Optional, List, Sequence
from datetime import datetime
from decimal import Decimal
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from .database import Order, Position, OrderStatus, OrderSide, OrderType, MarketData


class BaseRepository:
    """Base repository."""

    def __init__(self, session: AsyncSession):
        """Initialize repository."""
        self.session = session


class OrderRepository(BaseRepository):
    """Order repository."""

    async def create(self, order: Order) -> Order:
        """Create new order."""
        self.session.add(order)
        await self.session.commit()
        await self.session.refresh(order)
        return order

    async def get_by_id(self, order_id: int) -> Optional[Order]:
        """Get order by ID."""
        result = await self.session.execute(
            select(Order).where(Order.id == order_id)
        )
        return result.scalar_one_or_none()

    async def get_all(self, limit: int = 100, offset: int = 0) -> Sequence[Order]:
        """Get all orders."""
        result = await self.session.execute(
            select(Order).limit(limit).offset(offset).order_by(Order.created_at.desc())
        )
        return result.scalars().all()

    async def update_status(
        self,
        order_id: int,
        status: OrderStatus,
        exchange_order_id: Optional[str] = None,
        fee: Optional[Decimal] = None
    ) -> Optional[Order]:
        """Update order status."""
        values = {"status": status, "updated_at": datetime.utcnow()}
        if exchange_order_id:
            values["exchange_order_id"] = exchange_order_id
        if fee is not None:
            values["fee"] = fee

        await self.session.execute(
            update(Order)
            .where(Order.id == order_id)
            .values(**values)
        )
        await self.session.commit()
        return await self.get_by_id(order_id)

    async def get_open_orders(self, paper_trading: bool = True) -> Sequence[Order]:
        """Get open orders."""
        result = await self.session.execute(
            select(Order).where(
                Order.paper_trading == paper_trading,
                Order.status.in_([OrderStatus.PENDING, OrderStatus.OPEN, OrderStatus.PARTIALLY_FILLED])
            )
        )
        return result.scalars().all()

    async def delete(self, order_id: int) -> bool:
        """Delete order."""
        result = await self.session.execute(
            delete(Order).where(Order.id == order_id)
        )
        await self.session.commit()
        return result.rowcount > 0


class PositionRepository(BaseRepository):
    """Position repository."""

    async def get_all(self, paper_trading: bool = True) -> Sequence[Position]:
        """Get all positions."""
        result = await self.session.execute(
            select(Position).where(Position.paper_trading == paper_trading)
        )
        return result.scalars().all()

    async def get_by_symbol(self, symbol: str, paper_trading: bool = True) -> Optional[Position]:
        """Get position by symbol."""
        result = await self.session.execute(
            select(Position).where(
                Position.symbol == symbol,
                Position.paper_trading == paper_trading
            )
        )
        return result.scalar_one_or_none()
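A sketch of repository usage (not part of the commit). async_session_factory is a hypothetical name for the async sessionmaker in src/core/database.py, which is not shown in this diff:

import asyncio

from src.core.database import async_session_factory  # hypothetical name
from src.core.repositories import OrderRepository


async def demo():
    async with async_session_factory() as session:
        repo = OrderRepository(session)
        for order in await repo.get_open_orders(paper_trading=True):
            print(order.id, order.symbol, order.status)


asyncio.run(demo())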
27
src/data/__init__.py
Normal file
@@ -0,0 +1,27 @@
"""Data collection and storage module.

Provides:
- DataCollector: Real-time market data collection
- NewsCollector: Crypto news headline aggregation for sentiment analysis
- TechnicalIndicators: Technical analysis indicators
- DataStorage: Data persistence utilities
- DataQualityManager: Data quality checks
"""

from .collector import DataCollector
from .news_collector import NewsCollector, NewsItem, NewsSource, get_news_collector
from .indicators import TechnicalIndicators, get_indicators
from .storage import DataStorage
from .quality import DataQualityManager

__all__ = [
    "DataCollector",
    "NewsCollector",
    "NewsItem",
    "NewsSource",
    "get_news_collector",
    "TechnicalIndicators",
    "get_indicators",
    "DataStorage",
    "DataQualityManager",
]
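With this facade in place, callers can import from the package root rather than from individual modules, e.g. (illustrative only):

from src.data import get_indicators, get_news_collector

indicators = get_indicators()
news = get_news_collector()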
221
src/data/cache_manager.py
Normal file
@@ -0,0 +1,221 @@
"""Intelligent caching system for pricing data."""

import time
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timedelta
from collections import OrderedDict
from src.core.logger import get_logger

logger = get_logger(__name__)


class CacheEntry:
    """Cache entry with TTL support."""

    def __init__(self, data: Any, ttl: float):
        """Initialize cache entry.

        Args:
            data: Cached data
            ttl: Time to live in seconds
        """
        self.data = data
        self.expires_at = time.time() + ttl
        self.created_at = time.time()
        self.access_count = 0
        self.last_accessed = time.time()

    def is_expired(self) -> bool:
        """Check if entry is expired.

        Returns:
            True if expired
        """
        return time.time() > self.expires_at

    def touch(self):
        """Update access statistics."""
        self.access_count += 1
        self.last_accessed = time.time()

    def age(self) -> float:
        """Get age of entry in seconds.

        Returns:
            Age in seconds
        """
        return time.time() - self.created_at


class CacheManager:
    """Intelligent cache manager with TTL and size limits.

    Implements LRU (Least Recently Used) eviction when size limit is reached.
    """

    def __init__(
        self,
        default_ttl: float = 60.0,
        max_size: int = 1000,
        ticker_ttl: float = 2.0,
        ohlcv_ttl: float = 60.0
    ):
        """Initialize cache manager.

        Args:
            default_ttl: Default TTL in seconds
            max_size: Maximum number of cache entries
            ticker_ttl: TTL for ticker data in seconds
            ohlcv_ttl: TTL for OHLCV data in seconds
        """
        self.default_ttl = default_ttl
        self.max_size = max_size
        self.ticker_ttl = ticker_ttl
        self.ohlcv_ttl = ohlcv_ttl

        # Use OrderedDict for LRU eviction
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._hits = 0
        self._misses = 0
        self._evictions = 0

        self.logger = get_logger(__name__)

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        # Clean expired entries
        self._cleanup_expired()

        if key not in self._cache:
            self._misses += 1
            return None

        entry = self._cache[key]

        if entry.is_expired():
            # Remove expired entry
            del self._cache[key]
            self._misses += 1
            return None

        # Update access (move to end for LRU)
        entry.touch()
        self._cache.move_to_end(key)
        self._hits += 1
        return entry.data

    def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[float] = None,
        cache_type: Optional[str] = None
    ):
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (uses type-specific or default if None)
            cache_type: Type of cache ('ticker' or 'ohlcv') for type-specific TTL
        """
        # Determine TTL
        if ttl is None:
            if cache_type == 'ticker':
                ttl = self.ticker_ttl
            elif cache_type == 'ohlcv':
                ttl = self.ohlcv_ttl
            else:
                ttl = self.default_ttl

        # Check if we need to evict
        if len(self._cache) >= self.max_size and key not in self._cache:
            self._evict_lru()

        # Create entry
        entry = CacheEntry(value, ttl)

        # Add or update
        if key in self._cache:
            self._cache.move_to_end(key)
        self._cache[key] = entry

    def delete(self, key: str) -> bool:
        """Delete entry from cache.

        Args:
            key: Cache key

        Returns:
            True if entry was deleted, False if not found
        """
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self):
        """Clear all cache entries."""
        self._cache.clear()
        self.logger.info("Cache cleared")

    def _cleanup_expired(self):
        """Remove expired entries from cache."""
        expired_keys = [
            key for key, entry in self._cache.items()
            if entry.is_expired()
        ]
        for key in expired_keys:
            del self._cache[key]

    def _evict_lru(self):
        """Evict least recently used entry."""
        if self._cache:
            # Remove oldest (first) entry
            self._cache.popitem(last=False)
            self._evictions += 1

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        total_requests = self._hits + self._misses
        hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0

        # Calculate average age
        if self._cache:
            avg_age = sum(entry.age() for entry in self._cache.values()) / len(self._cache)
        else:
            avg_age = 0.0

        return {
            'size': len(self._cache),
            'max_size': self.max_size,
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate': round(hit_rate, 2),
            'evictions': self._evictions,
            'avg_age_seconds': round(avg_age, 2),
        }

    def invalidate_pattern(self, pattern: str):
        """Invalidate entries matching a pattern.

        Args:
            pattern: String pattern to match (simple substring match)
        """
        keys_to_delete = [key for key in self._cache.keys() if pattern in key]
        for key in keys_to_delete:
            del self._cache[key]

        if keys_to_delete:
            self.logger.info(f"Invalidated {len(keys_to_delete)} cache entries matching '{pattern}'")
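A short sketch (not part of the commit) showing the type-specific TTLs and the LRU limit in action, using only the methods above:

from src.data.cache_manager import CacheManager

cache = CacheManager(max_size=2, ticker_ttl=2.0, ohlcv_ttl=60.0)
cache.set("ticker:BTC/USD", {"last": 64250.0}, cache_type="ticker")
cache.set("ohlcv:BTC/USD:1h", [[1700000000000, 1.0, 2.0, 0.5, 1.5, 10.0]], cache_type="ohlcv")
cache.set("ticker:ETH/USD", {"last": 3150.0}, cache_type="ticker")  # evicts the least recently used key
print(cache.get("ticker:ETH/USD"))
print(cache.get_stats())  # size, hits, misses, hit_rate, evictions, avg_age_seconds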
139
src/data/collector.py
Normal file
@@ -0,0 +1,139 @@
"""Real-time data collection system with WebSocket support."""

import asyncio
from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy import select
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
from .pricing_service import get_pricing_service

logger = get_logger(__name__)


class DataCollector:
    """Collects real-time market data using the unified pricing service."""

    def __init__(self):
        """Initialize data collector."""
        self.db = get_database()
        self.logger = get_logger(__name__)
        self._callbacks: Dict[str, List[Callable]] = {}
        self._running = False
        self._pricing_service = get_pricing_service()

    def subscribe(
        self,
        exchange_id: Optional[int] = None,
        symbol: str = "",
        callback: Optional[Callable] = None
    ):
        """Subscribe to real-time data.

        Args:
            exchange_id: Exchange ID (deprecated, kept for backward compatibility)
            symbol: Trading symbol
            callback: Callback function(data)
        """
        if not symbol or not callback:
            logger.warning("subscribe called without symbol or callback")
            return

        key = f"pricing:{symbol}"
        if key not in self._callbacks:
            self._callbacks[key] = []
        self._callbacks[key].append(callback)

        # Subscribe via pricing service
        def wrapped_callback(data):
            """Wrap callback to maintain backward compatibility."""
            for cb in self._callbacks.get(key, []):
                try:
                    cb(data)
                except Exception as e:
                    logger.error(f"Callback error for {symbol}: {e}")

        self._pricing_service.subscribe_ticker(symbol, wrapped_callback)

    async def store_ohlcv(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        ohlcv_data: List[List]
    ):
        """Store OHLCV data in database.

        Args:
            exchange: Exchange name (can be provider name like 'CCXT-Kraken' or 'CoinGecko')
            symbol: Trading symbol
            timeframe: Timeframe
            ohlcv_data: List of [timestamp, open, high, low, close, volume]
        """
        async with self.db.get_session() as session:
            try:
                for candle in ohlcv_data:
                    timestamp = datetime.fromtimestamp(candle[0] / 1000)

                    # Check if already exists
                    stmt = select(MarketData).filter_by(
                        exchange=exchange,
                        symbol=symbol,
                        timeframe=timeframe,
                        timestamp=timestamp
                    )
                    result = await session.execute(stmt)
                    existing = result.scalar_one_or_none()

                    if not existing:
                        market_data = MarketData(
                            exchange=exchange,
                            symbol=symbol,
                            timeframe=timeframe,
                            timestamp=timestamp,
                            open=Decimal(str(candle[1])),
                            high=Decimal(str(candle[2])),
                            low=Decimal(str(candle[3])),
                            close=Decimal(str(candle[4])),
                            volume=Decimal(str(candle[5]))
                        )
                        session.add(market_data)

                await session.commit()
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to store OHLCV data: {e}")

    def get_ohlcv_from_pricing_service(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data from pricing service.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            since: Start datetime
            limit: Number of candles

        Returns:
            List of OHLCV candles
        """
        return self._pricing_service.get_ohlcv(symbol, timeframe, since, limit)


# Global data collector
_data_collector: Optional[DataCollector] = None


def get_data_collector() -> DataCollector:
    """Get global data collector instance."""
    global _data_collector
    if _data_collector is None:
        _data_collector = DataCollector()
    return _data_collector
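A usage sketch (not part of the commit); the shape of the ticker payload comes from the pricing service, which is outside this diff:

import asyncio

from src.data.collector import get_data_collector


def on_tick(data):
    print("tick:", data)


async def demo():
    collector = get_data_collector()
    collector.subscribe(symbol="BTC/USD", callback=on_tick)
    # One candle: [timestamp_ms, open, high, low, close, volume]
    await collector.store_ohlcv(
        "CCXT-Kraken", "BTC/USD", "1h",
        [[1700000000000, 34000, 34500, 33800, 34400, 12.5]]
    )


asyncio.run(demo())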
317
src/data/health_monitor.py
Normal file
@@ -0,0 +1,317 @@
"""Health monitoring and failover management for pricing providers."""

import time
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import deque
from src.core.logger import get_logger

logger = get_logger(__name__)


class HealthStatus(Enum):
    """Provider health status."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"


@dataclass
class HealthMetrics:
    """Health metrics for a provider."""
    status: HealthStatus = HealthStatus.UNKNOWN
    last_check: Optional[datetime] = None
    last_success: Optional[datetime] = None
    last_failure: Optional[datetime] = None
    success_count: int = 0
    failure_count: int = 0
    response_times: deque = field(default_factory=lambda: deque(maxlen=100))
    consecutive_failures: int = 0
    circuit_breaker_open: bool = False
    circuit_breaker_opened_at: Optional[datetime] = None

    def record_success(self, response_time: float):
        """Record a successful operation.

        Args:
            response_time: Response time in seconds
        """
        self.status = HealthStatus.HEALTHY
        self.last_check = datetime.utcnow()
        self.last_success = datetime.utcnow()
        self.success_count += 1
        self.response_times.append(response_time)
        self.consecutive_failures = 0
        self.circuit_breaker_open = False
        self.circuit_breaker_opened_at = None

    def record_failure(self):
        """Record a failed operation."""
        self.last_check = datetime.utcnow()
        self.last_failure = datetime.utcnow()
        self.failure_count += 1
        self.consecutive_failures += 1

        # Open circuit breaker after 5 consecutive failures
        if self.consecutive_failures >= 5:
            if not self.circuit_breaker_open:
                self.circuit_breaker_open = True
                self.circuit_breaker_opened_at = datetime.utcnow()
                logger.warning(f"Circuit breaker opened after {self.consecutive_failures} failures")

        # Update status based on failure rate
        total = self.success_count + self.failure_count
        if total > 0:
            failure_rate = self.failure_count / total
            if failure_rate > 0.5:
                self.status = HealthStatus.UNHEALTHY
            elif failure_rate > 0.2:
                self.status = HealthStatus.DEGRADED

    def get_avg_response_time(self) -> float:
        """Get average response time.

        Returns:
            Average response time in seconds, or 0.0 if no data
        """
        if not self.response_times:
            return 0.0
        return sum(self.response_times) / len(self.response_times)

    def should_attempt(self, circuit_breaker_timeout: int = 60) -> bool:
        """Check if we should attempt to use this provider.

        Args:
            circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens

        Returns:
            True if we should attempt, False otherwise
        """
        if not self.circuit_breaker_open:
            return True

        # Check if timeout has passed
        if self.circuit_breaker_opened_at:
            elapsed = (datetime.utcnow() - self.circuit_breaker_opened_at).total_seconds()
            if elapsed >= circuit_breaker_timeout:
                # Half-open state: allow one attempt
                logger.info("Circuit breaker half-open, allowing attempt")
                return True

        return False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses."""
        return {
            'status': self.status.value,
            'last_check': self.last_check.isoformat() if self.last_check else None,
            'last_success': self.last_success.isoformat() if self.last_success else None,
            'last_failure': self.last_failure.isoformat() if self.last_failure else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'avg_response_time': round(self.get_avg_response_time(), 3),
            'consecutive_failures': self.consecutive_failures,
            'circuit_breaker_open': self.circuit_breaker_open,
            'circuit_breaker_opened_at': (
                self.circuit_breaker_opened_at.isoformat()
                if self.circuit_breaker_opened_at else None
            ),
        }


class HealthMonitor:
    """Monitors health of pricing providers and manages failover."""

    def __init__(
        self,
        circuit_breaker_timeout: int = 60,
        min_success_rate: float = 0.8,
        max_avg_response_time: float = 5.0
    ):
        """Initialize health monitor.

        Args:
            circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
            min_success_rate: Minimum success rate to be considered healthy (0.0-1.0)
            max_avg_response_time: Maximum average response time in seconds to be considered healthy
        """
        self.circuit_breaker_timeout = circuit_breaker_timeout
        self.min_success_rate = min_success_rate
        self.max_avg_response_time = max_avg_response_time

        self._metrics: Dict[str, HealthMetrics] = {}
        self.logger = get_logger(__name__)

    def record_success(self, provider_name: str, response_time: float):
        """Record a successful operation for a provider.

        Args:
            provider_name: Name of the provider
            response_time: Response time in seconds
        """
        if provider_name not in self._metrics:
            self._metrics[provider_name] = HealthMetrics()

        self._metrics[provider_name].record_success(response_time)
        self.logger.debug(f"Recorded success for {provider_name} ({response_time:.3f}s)")

    def record_failure(self, provider_name: str):
        """Record a failed operation for a provider.

        Args:
            provider_name: Name of the provider
        """
        if provider_name not in self._metrics:
            self._metrics[provider_name] = HealthMetrics()

        self._metrics[provider_name].record_failure()
        self.logger.warning(f"Recorded failure for {provider_name} "
                            f"(consecutive: {self._metrics[provider_name].consecutive_failures})")

    def is_healthy(self, provider_name: str) -> bool:
        """Check if a provider is healthy.

        Args:
            provider_name: Name of the provider

        Returns:
            True if provider is healthy
        """
        if provider_name not in self._metrics:
            return True  # Assume healthy if no metrics yet

        metrics = self._metrics[provider_name]

        # Check circuit breaker
        if not metrics.should_attempt(self.circuit_breaker_timeout):
            return False

        # Check status
        if metrics.status == HealthStatus.UNHEALTHY:
            return False

        # Check success rate
        total = metrics.success_count + metrics.failure_count
        if total > 10:  # Need minimum data points
            success_rate = metrics.success_count / total
            if success_rate < self.min_success_rate:
                return False

        # Check response time
        if metrics.response_times:
            avg_response_time = metrics.get_avg_response_time()
            if avg_response_time > self.max_avg_response_time:
                return False

        return True

    def get_health_status(self, provider_name: str) -> HealthStatus:
        """Get health status for a provider.

        Args:
            provider_name: Name of the provider

        Returns:
            Health status
        """
        if provider_name not in self._metrics:
            return HealthStatus.UNKNOWN

        return self._metrics[provider_name].status

    def get_metrics(self, provider_name: str) -> Optional[HealthMetrics]:
        """Get health metrics for a provider.

        Args:
            provider_name: Name of the provider

        Returns:
            Health metrics or None if not found
        """
        return self._metrics.get(provider_name)

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Get all provider metrics.

        Returns:
            Dictionary mapping provider names to their metrics
        """
        return {
            name: metrics.to_dict()
            for name, metrics in self._metrics.items()
        }

    def select_healthiest(self, provider_names: List[str]) -> Optional[str]:
        """Select the healthiest provider from a list.

        Args:
            provider_names: List of provider names to choose from

        Returns:
            Name of healthiest provider, or None if none are healthy
        """
        healthy_providers = [
            name for name in provider_names
            if self.is_healthy(name)
        ]

        if not healthy_providers:
            return None

        # Sort by health metrics (better providers first)
        def health_score(name: str) -> tuple:
            metrics = self._metrics.get(name)
            if not metrics:
                return (1, 0, 0)  # Unknown providers get lowest priority

            # Score: (status_weight, -avg_response_time, success_rate)
            status_weights = {
                HealthStatus.HEALTHY: 3,
                HealthStatus.DEGRADED: 2,
                HealthStatus.UNHEALTHY: 1,
                HealthStatus.UNKNOWN: 0,
            }

            success_rate = (
                metrics.success_count / (metrics.success_count + metrics.failure_count)
                if (metrics.success_count + metrics.failure_count) > 0
                else 0.0
            )

            return (
                status_weights.get(metrics.status, 0),
                -metrics.get_avg_response_time(),
                success_rate
            )

        sorted_providers = sorted(healthy_providers, key=health_score, reverse=True)
        return sorted_providers[0] if sorted_providers else None

    def reset_circuit_breaker(self, provider_name: str):
        """Manually reset circuit breaker for a provider.

        Args:
            provider_name: Name of the provider
        """
        if provider_name in self._metrics:
            self._metrics[provider_name].circuit_breaker_open = False
            self._metrics[provider_name].circuit_breaker_opened_at = None
            self._metrics[provider_name].consecutive_failures = 0
            self.logger.info(f"Circuit breaker reset for {provider_name}")

    def reset_metrics(self, provider_name: Optional[str] = None):
        """Reset metrics for a provider or all providers.

        Args:
            provider_name: Name of provider to reset, or None to reset all
        """
        if provider_name:
            if provider_name in self._metrics:
                del self._metrics[provider_name]
                self.logger.info(f"Reset metrics for {provider_name}")
        else:
            self._metrics.clear()
            self.logger.info("Reset all provider metrics")
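A failover loop sketch (not part of the commit): record each call's outcome and always route through select_healthiest(). fetch_price() is a stand-in for a real provider call; note that a provider is dropped once its failure rate marks it unhealthy:

import time

from src.data.health_monitor import HealthMonitor

monitor = HealthMonitor(circuit_breaker_timeout=60)
providers = ["CCXT-Kraken", "CoinGecko"]


def fetch_price(provider: str) -> float:
    raise TimeoutError("stand-in failure")


for _ in range(4):
    name = monitor.select_healthiest(providers)
    if name is None:
        # Every provider is unhealthy or circuit-broken; back off instead
        print("no healthy provider available")
        break
    start = time.monotonic()
    try:
        fetch_price(name)
        monitor.record_success(name, time.monotonic() - start)
    except Exception:
        monitor.record_failure(name)

print(monitor.get_all_metrics())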
569
src/data/indicators.py
Normal file
@@ -0,0 +1,569 @@
"""Comprehensive technical indicator library using pandas-ta and TA-Lib."""

import pandas as pd
import numpy as np
from typing import Optional, Dict, Any, List

# Try to import pandas_ta, but handle if numba is missing
try:
    import pandas_ta as ta
    PANDAS_TA_AVAILABLE = True
except ImportError:
    PANDAS_TA_AVAILABLE = False
    ta = None
    import warnings
    warnings.warn("pandas-ta not available (numba issue), using basic implementations")

try:
    import talib
    TALIB_AVAILABLE = True
except ImportError:
    TALIB_AVAILABLE = False

from src.core.logger import get_logger

logger = get_logger(__name__)


class TechnicalIndicators:
    """Technical indicators library."""

    def __init__(self):
        """Initialize indicators library."""
        self.talib_available = TALIB_AVAILABLE

    # Trend Indicators

    def sma(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Simple Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.sma(data, length=period)
        return data.rolling(window=period).mean()

    def ema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.ema(data, length=period)
        return data.ewm(span=period, adjust=False).mean()

    def wma(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Weighted Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.wma(data, length=period)
        # Basic WMA implementation
        weights = np.arange(1, period + 1)
        return data.rolling(window=period).apply(lambda x: np.dot(x, weights) / weights.sum(), raw=True)

    def dema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Double Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.dema(data, length=period)
        ema1 = self.ema(data, period)
        return 2 * ema1 - self.ema(ema1, period)

    def tema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Triple Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.tema(data, length=period)
        ema1 = self.ema(data, period)
        ema2 = self.ema(ema1, period)
        ema3 = self.ema(ema2, period)
        return 3 * ema1 - 3 * ema2 + ema3

    # Momentum Indicators

    def rsi(self, data: pd.Series, period: int = 14) -> pd.Series:
        """Relative Strength Index."""
        if self.talib_available:
            return pd.Series(talib.RSI(data.values, timeperiod=period), index=data.index)
        if PANDAS_TA_AVAILABLE:
            return ta.rsi(data, length=period)
        # Basic RSI implementation
        delta = data.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))

    def macd(
        self,
        data: pd.Series,
        fast: int = 12,
        slow: int = 26,
        signal: int = 9
    ) -> Dict[str, pd.Series]:
        """MACD (Moving Average Convergence Divergence)."""
        if self.talib_available:
            macd, signal_line, histogram = talib.MACD(
                data.values, fastperiod=fast, slowperiod=slow, signalperiod=signal
            )
            return {
                'macd': pd.Series(macd, index=data.index),
                'signal': pd.Series(signal_line, index=data.index),
                'histogram': pd.Series(histogram, index=data.index),
            }
        if not PANDAS_TA_AVAILABLE or ta is None:
            # Basic MACD implementation fallback
            ema_fast = self.ema(data, fast)
            ema_slow = self.ema(data, slow)
            macd_line = ema_fast - ema_slow
            signal_line = self.ema(macd_line.dropna(), signal)
            histogram = macd_line - signal_line
            return {
                'macd': macd_line,
                'signal': signal_line,
                'histogram': histogram,
            }
        result = ta.macd(data, fast=fast, slow=slow, signal=signal)
        return {
            'macd': result[f'MACD_{fast}_{slow}_{signal}'],
            'signal': result[f'MACDs_{fast}_{slow}_{signal}'],
            'histogram': result[f'MACDh_{fast}_{slow}_{signal}'],
        }

    def stochastic(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        k_period: int = 14,
        d_period: int = 3
    ) -> Dict[str, pd.Series]:
        """Stochastic Oscillator."""
        if self.talib_available:
            slowk, slowd = talib.STOCH(
                high.values, low.values, close.values,
                fastk_period=k_period, slowk_period=d_period, slowd_period=d_period
            )
            return {
                'k': pd.Series(slowk, index=close.index),
                'd': pd.Series(slowd, index=close.index),
            }
        if PANDAS_TA_AVAILABLE and ta is not None:
            result = ta.stoch(high, low, close, k=k_period, d=d_period)
            return {
                'k': result[f'STOCHk_{k_period}_{d_period}_{d_period}'],
                'd': result[f'STOCHd_{k_period}_{d_period}_{d_period}'],
            }
        # Basic Stochastic implementation
        lowest_low = low.rolling(window=k_period).min()
        highest_high = high.rolling(window=k_period).max()
        k = 100 * ((close - lowest_low) / (highest_high - lowest_low))
        d = k.rolling(window=d_period).mean()
        return {'k': k, 'd': d}

    def cci(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 20
    ) -> pd.Series:
        """Commodity Channel Index."""
        if self.talib_available:
            return pd.Series(
                talib.CCI(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE:
            return ta.cci(high, low, close, length=period)
        # Basic CCI implementation
        tp = (high + low + close) / 3
        sma_tp = tp.rolling(window=period).mean()
        mad = tp.rolling(window=period).apply(lambda x: np.abs(x - x.mean()).mean(), raw=True)
        return (tp - sma_tp) / (0.015 * mad)

    def williams_r(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Williams %R."""
        if self.talib_available:
            return pd.Series(
                talib.WILLR(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE:
            return ta.willr(high, low, close, length=period)
        # Basic Williams %R implementation
        highest_high = high.rolling(window=period).max()
        lowest_low = low.rolling(window=period).min()
        return -100 * ((highest_high - close) / (highest_high - lowest_low))

    # Volatility Indicators

    def bollinger_bands(
        self,
        data: pd.Series,
        period: int = 20,
        std_dev: float = 2.0
    ) -> Dict[str, pd.Series]:
        """Bollinger Bands."""
        if self.talib_available:
            upper, middle, lower = talib.BBANDS(
                data.values, timeperiod=period, nbdevup=std_dev, nbdevdn=std_dev
            )
            return {
                'upper': pd.Series(upper, index=data.index),
                'middle': pd.Series(middle, index=data.index),
                'lower': pd.Series(lower, index=data.index),
            }
        if PANDAS_TA_AVAILABLE:
            result = ta.bbands(data, length=period, std=std_dev)
            return {
                'upper': result[f'BBU_{period}_{std_dev}'],
                'middle': result[f'BBM_{period}_{std_dev}'],
                'lower': result[f'BBL_{period}_{std_dev}'],
            }
        # Basic Bollinger Bands implementation
        middle = self.sma(data, period)
        std = data.rolling(window=period).std()
        return {
            'upper': middle + (std * std_dev),
            'middle': middle,
            'lower': middle - (std * std_dev),
        }

    def atr(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Average True Range."""
        if self.talib_available:
            return pd.Series(
                talib.ATR(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.atr(high, low, close, length=period)
        # Basic ATR implementation
        prev_close = close.shift(1)
        tr1 = high - low
        tr2 = abs(high - prev_close)
        tr3 = abs(low - prev_close)
        tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
        return tr.rolling(window=period).mean()

    def keltner_channels(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 20,
        multiplier: float = 2.0
    ) -> Dict[str, pd.Series]:
        """Keltner Channels."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.kc(high, low, close, length=period, scalar=multiplier)
        # Basic Keltner Channels implementation
        middle = self.ema(close, period)
        atr_val = self.atr(high, low, close, period)
        return {
            'lower': middle - (multiplier * atr_val),
            'middle': middle,
            'upper': middle + (multiplier * atr_val),
        }

    # Volume Indicators

    def obv(self, close: pd.Series, volume: pd.Series) -> pd.Series:
        """On-Balance Volume."""
        if self.talib_available:
            return pd.Series(
                talib.OBV(close.values, volume.values),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.obv(close, volume)
        # Basic OBV implementation
        price_change = close.diff()
        obv = pd.Series(index=close.index, dtype=float)
        obv.iloc[0] = 0
        for i in range(1, len(close)):
            if price_change.iloc[i] > 0:
                obv.iloc[i] = obv.iloc[i-1] + volume.iloc[i]
            elif price_change.iloc[i] < 0:
                obv.iloc[i] = obv.iloc[i-1] - volume.iloc[i]
            else:
                obv.iloc[i] = obv.iloc[i-1]
        return obv

    def vwap(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        volume: pd.Series
    ) -> pd.Series:
        """Volume Weighted Average Price."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.vwap(high, low, close, volume)
        # Basic VWAP implementation
        typical_price = (high + low + close) / 3
        cumulative_tp_vol = (typical_price * volume).cumsum()
        cumulative_vol = volume.cumsum()
        return cumulative_tp_vol / cumulative_vol

    # Advanced Indicators

    def ichimoku(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        tenkan: int = 9,
        kijun: int = 26,
        senkou: int = 52
    ) -> Dict[str, pd.Series]:
        """Ichimoku Cloud."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            # pandas-ta returns (lines, forward span); column names encode the periods
            result, _ = ta.ichimoku(high, low, close, tenkan=tenkan, kijun=kijun, senkou=senkou)
            return {
                'tenkan': result[f'ITS_{tenkan}'],
                'kijun': result[f'IKS_{kijun}'],
                'senkou_a': result[f'ISA_{tenkan}'],
                'senkou_b': result[f'ISB_{kijun}'],
                'chikou': result[f'ICS_{kijun}'],
            }
        # Basic Ichimoku implementation
        tenkan_sen = (high.rolling(window=tenkan).max() + low.rolling(window=tenkan).min()) / 2
        kijun_sen = (high.rolling(window=kijun).max() + low.rolling(window=kijun).min()) / 2
        senkou_a = ((tenkan_sen + kijun_sen) / 2).shift(kijun)
        senkou_b = ((high.rolling(window=senkou).max() + low.rolling(window=senkou).min()) / 2).shift(kijun)
        chikou = close.shift(-kijun)
        return {
            'tenkan': tenkan_sen,
            'kijun': kijun_sen,
            'senkou_a': senkou_a,
            'senkou_b': senkou_b,
            'chikou': chikou,
        }

    def adx(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Average Directional Index."""
        if self.talib_available:
            return pd.Series(
                talib.ADX(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            # pandas-ta returns a DataFrame (ADX/DMP/DMN); keep only the ADX column
            return ta.adx(high, low, close, length=period)[f'ADX_{period}']
        # Basic ADX implementation (simplified directional movement)
        plus_dm = high.diff()
        minus_dm = -low.diff()  # positive when the low moves down
        plus_dm[plus_dm < 0] = 0
        minus_dm[minus_dm < 0] = 0

        atr_val = self.atr(high, low, close, period)
        plus_di = 100 * (plus_dm.rolling(window=period).mean() / atr_val)
        minus_di = 100 * (minus_dm.rolling(window=period).mean() / atr_val)

        dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di)
        adx = dx.rolling(window=period).mean()
        return adx

    def detect_divergence(
        self,
        prices: pd.Series,
        indicator: pd.Series,
        lookback: int = 20,
        min_swings: int = 2
    ) -> Dict[str, Any]:
        """Detect divergence between price and indicator.

        Divergence occurs when price makes new highs/lows but indicator doesn't,
        or vice versa. This is a powerful reversal signal.

        Args:
            prices: Price series
            indicator: Indicator series (e.g., RSI, MACD)
            lookback: Lookback period for finding swings
            min_swings: Minimum number of swings to detect divergence

        Returns:
            Dictionary with divergence information:
            {
                'type': 'bullish', 'bearish', or None
                'confidence': 0.0 to 1.0
                'price_swing_high': price at high swing
                'price_swing_low': price at low swing
                'indicator_swing_high': indicator at high swing
                'indicator_swing_low': indicator at low swing
            }
        """
        if len(prices) < lookback * 2 or len(indicator) < lookback * 2:
            return {
                'type': None,
                'confidence': 0.0,
                'price_swing_high': None,
                'price_swing_low': None,
                'indicator_swing_high': None,
                'indicator_swing_low': None
            }

        # Find local extrema (swings)
        def find_swings(series: pd.Series, lookback: int):
            """Find local maxima and minima."""
            highs = []
            lows = []

            for i in range(lookback, len(series) - lookback):
                window = series.iloc[i-lookback:i+lookback+1]
                center = series.iloc[i]

                # Local maximum
                if center == window.max():
                    highs.append((i, center))
                # Local minimum
                elif center == window.min():
                    lows.append((i, center))

            return highs, lows

        price_highs, price_lows = find_swings(prices, lookback)
        indicator_highs, indicator_lows = find_swings(indicator, lookback)

        # Need at least min_swings swings to detect divergence
        if len(price_highs) < min_swings or len(price_lows) < min_swings:
            return {
                'type': None,
                'confidence': 0.0,
                'price_swing_high': None,
                'price_swing_low': None,
                'indicator_swing_high': None,
                'indicator_swing_low': None
            }

        # Check for bearish divergence (price makes higher high, indicator makes lower high)
        if len(price_highs) >= 2 and len(indicator_highs) >= 2:
            recent_price_high = price_highs[-1][1]
            prev_price_high = price_highs[-2][1]
            recent_indicator_high = indicator_highs[-1][1]
            prev_indicator_high = indicator_highs[-2][1]

            # Price higher high but indicator lower high = bearish divergence
            if recent_price_high > prev_price_high and recent_indicator_high < prev_indicator_high:
                confidence = min(1.0, abs(recent_price_high - prev_price_high) / prev_price_high * 10)
                return {
                    'type': 'bearish',
                    'confidence': confidence,
                    'price_swing_high': (price_highs[-2][0], price_highs[-1][0]),
                    'price_swing_low': None,
                    'indicator_swing_high': (indicator_highs[-2][0], indicator_highs[-1][0]),
                    'indicator_swing_low': None
                }

        # Check for bullish divergence (price makes lower low, indicator makes higher low)
        if len(price_lows) >= 2 and len(indicator_lows) >= 2:
            recent_price_low = price_lows[-1][1]
            prev_price_low = price_lows[-2][1]
            recent_indicator_low = indicator_lows[-1][1]
            prev_indicator_low = indicator_lows[-2][1]

            # Price lower low but indicator higher low = bullish divergence
            if recent_price_low < prev_price_low and recent_indicator_low > prev_indicator_low:
                confidence = min(1.0, abs(prev_price_low - recent_price_low) / prev_price_low * 10)
                return {
                    'type': 'bullish',
                    'confidence': confidence,
                    'price_swing_high': None,
                    'price_swing_low': (price_lows[-2][0], price_lows[-1][0]),
                    'indicator_swing_high': None,
                    'indicator_swing_low': (indicator_lows[-2][0], indicator_lows[-1][0])
                }

        return {
            'type': None,
            'confidence': 0.0,
            'price_swing_high': None,
            'price_swing_low': None,
            'indicator_swing_high': None,
            'indicator_swing_low': None
        }
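Since detect_divergence() only needs two aligned series, it can be exercised on synthetic data; a sketch (not part of the commit):

import numpy as np
import pandas as pd

from src.data.indicators import get_indicators

ind = get_indicators()
rng = np.random.default_rng(0)
close = pd.Series(100 + np.cumsum(rng.normal(0, 1, 300)))

rsi = ind.rsi(close, period=14).fillna(50)  # fill warm-up NaNs so swing detection sees numbers
div = ind.detect_divergence(close, rsi, lookback=10)
print(div['type'], round(div['confidence'], 3))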

    def calculate_all(
        self,
        df: pd.DataFrame,
        indicators: Optional[List[str]] = None
    ) -> pd.DataFrame:
        """Calculate multiple indicators at once.

        Args:
            df: DataFrame with OHLCV data (columns: open, high, low, close, volume)
            indicators: List of indicator names to calculate (None = all)

        Returns:
            DataFrame with added indicator columns
        """
        result = df.copy()

        if 'close' not in result.columns:
            raise ValueError("DataFrame must have 'close' column")

        close = result['close']
        high = result.get('high', close)
        low = result.get('low', close)
        volume = result.get('volume', pd.Series(1, index=close.index))

        # Default indicators if none specified
        if indicators is None:
            indicators = [
                'sma_20', 'ema_20', 'rsi', 'macd', 'bollinger_bands',
                'atr', 'obv', 'adx'
            ]

        for indicator in indicators:
            try:
                if indicator.startswith('sma_'):
                    period = int(indicator.split('_')[1])
                    result[f'SMA_{period}'] = self.sma(close, period)
                elif indicator.startswith('ema_'):
                    period = int(indicator.split('_')[1])
                    result[f'EMA_{period}'] = self.ema(close, period)
                elif indicator == 'rsi':
                    result['RSI'] = self.rsi(close)
                elif indicator == 'macd':
                    macd_data = self.macd(close)
                    result['MACD'] = macd_data['macd']
                    result['MACD_Signal'] = macd_data['signal']
                    result['MACD_Histogram'] = macd_data['histogram']
                elif indicator == 'bollinger_bands':
                    bb_data = self.bollinger_bands(close)
                    result['BB_Upper'] = bb_data['upper']
                    result['BB_Middle'] = bb_data['middle']
                    result['BB_Lower'] = bb_data['lower']
                elif indicator == 'atr':
                    result['ATR'] = self.atr(high, low, close)
                elif indicator == 'obv':
                    result['OBV'] = self.obv(close, volume)
                elif indicator == 'adx':
                    result['ADX'] = self.adx(high, low, close)
            except Exception as e:
                logger.warning(f"Failed to calculate indicator {indicator}: {e}")

        return result


# Global indicators instance
_indicators: Optional[TechnicalIndicators] = None


def get_indicators() -> TechnicalIndicators:
    """Get global technical indicators instance."""
    global _indicators
    if _indicators is None:
        _indicators = TechnicalIndicators()
    return _indicators
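An end-to-end sketch (not part of the commit): enrich a synthetic OHLCV frame with the default indicator set:

import numpy as np
import pandas as pd

from src.data.indicators import get_indicators

rng = np.random.default_rng(1)
close = pd.Series(100 + np.cumsum(rng.normal(0, 1, 200)))
df = pd.DataFrame({
    'open': close.shift(1).fillna(close),
    'high': close + rng.uniform(0, 1, 200),
    'low': close - rng.uniform(0, 1, 200),
    'close': close,
    'volume': rng.uniform(1, 10, 200),
})

enriched = get_indicators().calculate_all(df)
print(enriched[['close', 'RSI', 'MACD', 'BB_Upper', 'ADX']].tail())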
447
src/data/news_collector.py
Normal file
@@ -0,0 +1,447 @@
"""News collector for crypto sentiment analysis.

Collects headlines from multiple sources:
- RSS feeds (CoinDesk, CoinTelegraph, Decrypt, etc.)
- CryptoPanic API (optional, requires API key)

Headlines are cached and refreshed periodically to avoid rate limits.
"""

import asyncio
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum

from src.core.logger import get_logger

logger = get_logger(__name__)


class NewsSource(str, Enum):
    """Supported news sources."""
    COINDESK = "coindesk"
    COINTELEGRAPH = "cointelegraph"
    DECRYPT = "decrypt"
    BITCOIN_MAGAZINE = "bitcoin_magazine"
    THE_BLOCK = "the_block"
    MESSARI = "messari"
    CRYPTOPANIC = "cryptopanic"


@dataclass
class NewsItem:
    """A single news item."""
    title: str
    source: NewsSource
    published: datetime
    url: Optional[str] = None
    summary: Optional[str] = None
    symbols: List[str] = field(default_factory=list)


# RSS feed URLs for major crypto news sources
RSS_FEEDS: Dict[NewsSource, str] = {
    NewsSource.COINDESK: "https://www.coindesk.com/arc/outboundfeeds/rss/",
    NewsSource.COINTELEGRAPH: "https://cointelegraph.com/rss",
    NewsSource.DECRYPT: "https://decrypt.co/feed",
    NewsSource.BITCOIN_MAGAZINE: "https://bitcoinmagazine.com/.rss/full/",
    NewsSource.THE_BLOCK: "https://www.theblock.co/rss.xml",
    NewsSource.MESSARI: "https://messari.io/rss",
}

# Common crypto symbols to detect in headlines
CRYPTO_SYMBOLS = {
    "BTC": ["bitcoin", "btc"],
    "ETH": ["ethereum", "eth", "ether"],
    "SOL": ["solana", "sol"],
    "XRP": ["ripple", "xrp"],
    "ADA": ["cardano", "ada"],
    "DOGE": ["dogecoin", "doge"],
    "DOT": ["polkadot", "dot"],
    "AVAX": ["avalanche", "avax"],
    "MATIC": ["polygon", "matic"],
    "LINK": ["chainlink", "link"],
    "UNI": ["uniswap", "uni"],
    "ATOM": ["cosmos", "atom"],
    "LTC": ["litecoin", "ltc"],
}


class NewsCollector:
    """Collects crypto news headlines for sentiment analysis.

    Features:
    - Aggregates news from multiple RSS feeds
    - Caches headlines to reduce network requests
    - Filters by crypto symbol
    - Optional CryptoPanic integration for more sources

    Usage:
        collector = NewsCollector()
        headlines = await collector.fetch_news()
        # Or filter by symbol
        btc_headlines = await collector.fetch_news(symbols=["BTC"])
    """

    # Minimum time between fetches (in seconds)
    MIN_FETCH_INTERVAL = 300  # 5 minutes

    def __init__(
        self,
        sources: Optional[List[NewsSource]] = None,
        cryptopanic_api_key: Optional[str] = None,
        cache_duration: int = 600,  # 10 minutes
        max_headlines: int = 50
    ):
        """Initialize NewsCollector.

        Args:
            sources: List of news sources to use. Defaults to all RSS feeds.
            cryptopanic_api_key: Optional API key for CryptoPanic.
            cache_duration: How long to cache headlines (seconds).
            max_headlines: Maximum headlines to keep in cache.
        """
        self.sources = sources or list(RSS_FEEDS.keys())
        self.cryptopanic_api_key = cryptopanic_api_key
        self.cache_duration = cache_duration
        self.max_headlines = max_headlines

        self._cache: List[NewsItem] = []
        self._last_fetch: Optional[datetime] = None
        self._fetching = False

        self.logger = get_logger(__name__)

        # Check if feedparser is available
        try:
            import feedparser
            self._feedparser = feedparser
            self._feedparser_available = True
        except ImportError:
            self._feedparser_available = False
            self.logger.warning(
                "feedparser not installed. Install with: pip install feedparser"
            )

    def _extract_symbols(self, text: str) -> List[str]:
        """Extract crypto symbols mentioned in text.

        Args:
            text: Text to search (headline, summary)

        Returns:
            List of detected symbol codes (e.g., ["BTC", "ETH"])
        """
        text_lower = text.lower()
        detected = []

        for symbol, keywords in CRYPTO_SYMBOLS.items():
            for keyword in keywords:
                if keyword in text_lower:
                    detected.append(symbol)
                    break

        return detected
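The keyword matching is easy to sanity-check in isolation (sketch, not part of the commit):

from src.data.news_collector import NewsCollector

collector = NewsCollector()
print(collector._extract_symbols("Bitcoin rallies as Ethereum upgrade ships"))
# ['BTC', 'ETH']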
|
||||
async def _fetch_rss_feed(self, source: NewsSource) -> List[NewsItem]:
|
||||
"""Fetch and parse a single RSS feed.
|
||||
|
||||
Args:
|
||||
source: News source to fetch
|
||||
|
||||
Returns:
|
||||
List of NewsItems from the feed
|
||||
"""
|
||||
if not self._feedparser_available:
|
||||
return []
|
||||
|
||||
url = RSS_FEEDS.get(source)
|
||||
if not url:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Run feedparser in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
feed = await loop.run_in_executor(
|
||||
None,
|
||||
self._feedparser.parse,
|
||||
url
|
||||
)
|
||||
|
||||
items = []
|
||||
for entry in feed.entries[:20]: # Limit entries per feed
|
||||
# Parse publication date
|
||||
published = datetime.now()
|
||||
if hasattr(entry, 'published_parsed') and entry.published_parsed:
|
||||
try:
|
||||
published = datetime(*entry.published_parsed[:6])
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
title = entry.get('title', '')
|
||||
summary = entry.get('summary', '')
|
||||
|
||||
# Clean HTML from summary
|
||||
summary = re.sub(r'<[^>]+>', '', summary)[:200]
|
||||
|
||||
item = NewsItem(
|
||||
title=title,
|
||||
source=source,
|
||||
published=published,
|
||||
url=entry.get('link'),
|
||||
summary=summary,
|
||||
symbols=self._extract_symbols(f"{title} {summary}")
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
self.logger.debug(f"Fetched {len(items)} items from {source.value}")
|
||||
return items
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to fetch {source.value} RSS: {e}")
|
||||
return []
|
||||
|
||||
    async def _fetch_cryptopanic(self, symbols: Optional[List[str]] = None) -> List[NewsItem]:
        """Fetch news from CryptoPanic API.

        Args:
            symbols: Optional list of symbols to filter

        Returns:
            List of NewsItems from CryptoPanic
        """
        if not self.cryptopanic_api_key:
            return []

        try:
            import aiohttp

            url = "https://cryptopanic.com/api/v1/posts/"
            params = {
                "auth_token": self.cryptopanic_api_key,
                "public": "true",
            }

            if symbols:
                params["currencies"] = ",".join(symbols)

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url, params=params, timeout=aiohttp.ClientTimeout(total=10)
                ) as response:
                    if response.status != 200:
                        self.logger.warning(f"CryptoPanic API error: {response.status}")
                        return []

                    data = await response.json()

                    items = []
                    for post in data.get("results", [])[:20]:
                        published = datetime.now()
                        if post.get("published_at"):
                            try:
                                # Parse the ISO timestamp, then drop tzinfo so all
                                # NewsItem timestamps stay naive and comparable to
                                # datetime.now() in fetch_headlines()
                                published = datetime.fromisoformat(
                                    post["published_at"].replace("Z", "+00:00")
                                ).replace(tzinfo=None)
                            except (ValueError, TypeError):
                                pass

                        item = NewsItem(
                            title=post.get("title", ""),
                            source=NewsSource.CRYPTOPANIC,
                            published=published,
                            url=post.get("url"),
                            symbols=[c["code"] for c in post.get("currencies", [])]
                        )
                        items.append(item)

                    self.logger.debug(f"Fetched {len(items)} items from CryptoPanic")
                    return items

        except ImportError:
            self.logger.warning("aiohttp not installed for CryptoPanic API")
            return []
        except Exception as e:
            self.logger.warning(f"Failed to fetch CryptoPanic: {e}")
            return []

    async def fetch_news(
        self,
        symbols: Optional[List[str]] = None,
        force_refresh: bool = False
    ) -> List[NewsItem]:
        """Fetch news items from all sources.

        Args:
            symbols: Optional list of symbols to filter (e.g., ["BTC", "ETH"])
            force_refresh: Force a refresh even if cache is valid

        Returns:
            List of NewsItems sorted by publication date (newest first)
        """
        now = datetime.now()

        # Check cache validity
        cache_valid = (
            self._last_fetch is not None and
            (now - self._last_fetch).total_seconds() < self.cache_duration and
            len(self._cache) > 0
        )

        if cache_valid and not force_refresh:
            self.logger.debug("Using cached news items")
            items = self._cache
        else:
            # Prevent concurrent fetches
            if self._fetching:
                self.logger.debug("Fetch already in progress, using cache")
                items = self._cache
            else:
                self._fetching = True
                try:
                    items = await self._fetch_all_sources()
                    self._cache = items
                    self._last_fetch = now
                finally:
                    self._fetching = False

        # Filter by symbols if specified
        if symbols:
            symbols_upper = [s.upper() for s in symbols]
            items = [
                item for item in items
                if any(s in symbols_upper for s in item.symbols) or not item.symbols
            ]

        return items

    async def _fetch_all_sources(self) -> List[NewsItem]:
        """Fetch from all configured sources concurrently."""
        tasks = []

        # RSS feeds
        for source in self.sources:
            if source in RSS_FEEDS:
                tasks.append(self._fetch_rss_feed(source))

        # CryptoPanic
        if self.cryptopanic_api_key:
            tasks.append(self._fetch_cryptopanic())

        if not tasks:
            self.logger.warning("No news sources configured")
            return []

        # Fetch all concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Combine results
        all_items = []
        for result in results:
            if isinstance(result, list):
                all_items.extend(result)
            elif isinstance(result, Exception):
                self.logger.warning(f"Source fetch failed: {result}")

        # Sort by publication date (newest first)
        all_items.sort(key=lambda x: x.published, reverse=True)

        # Limit total items
        all_items = all_items[:self.max_headlines]

        self.logger.info(f"Fetched {len(all_items)} total news items")
        return all_items

    async def fetch_headlines(
        self,
        symbols: Optional[List[str]] = None,
        max_age_hours: int = 24,
        force_refresh: bool = False
    ) -> List[str]:
        """Fetch headlines as strings for sentiment analysis.

        This is the main method to use with SentimentScanner.

        Args:
            symbols: Optional list of symbols to filter
            max_age_hours: Only include headlines from the last N hours
            force_refresh: Force a refresh even if cache is valid

        Returns:
            List of headline strings
        """
        items = await self.fetch_news(symbols=symbols, force_refresh=force_refresh)

        # Filter by age
        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        recent_items = [item for item in items if item.published > cutoff]

        # Extract just the titles
        headlines = [item.title for item in recent_items if item.title]

        self.logger.debug(f"Returning {len(headlines)} headlines for analysis")
        return headlines

    def get_cached_headlines(self, symbols: Optional[List[str]] = None) -> List[str]:
        """Get cached headlines synchronously (no fetch).

        Args:
            symbols: Optional list of symbols to filter

        Returns:
            List of cached headline strings
        """
        items = self._cache

        if symbols:
            symbols_upper = [s.upper() for s in symbols]
            items = [
                item for item in items
                if any(s in symbols_upper for s in item.symbols) or not item.symbols
            ]

        return [item.title for item in items if item.title]

    def get_status(self) -> Dict[str, Any]:
        """Get collector status information.

        Returns:
            Dictionary with status info
        """
        return {
            "sources": [s.value for s in self.sources],
            "cryptopanic_enabled": self.cryptopanic_api_key is not None,
            "feedparser_available": self._feedparser_available,
            "cached_items": len(self._cache),
            "last_fetch": self._last_fetch.isoformat() if self._last_fetch else None,
            "cache_age_seconds": (
                (datetime.now() - self._last_fetch).total_seconds()
                if self._last_fetch else None
            ),
        }

    def clear_cache(self):
        """Clear the headline cache."""
        self._cache = []
        self._last_fetch = None
        self.logger.info("News cache cleared")


# Global instance
_news_collector: Optional[NewsCollector] = None


def get_news_collector(**kwargs) -> NewsCollector:
    """Get or create the global NewsCollector instance.

    Args:
        **kwargs: Arguments passed to NewsCollector constructor

    Returns:
        NewsCollector instance
    """
    global _news_collector
    if _news_collector is None:
        _news_collector = NewsCollector(**kwargs)
        logger.info("Created global NewsCollector instance")
    return _news_collector
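

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original commit: fetch recent
    # headlines for a couple of symbols and print them. Assumes the module's
    # RSS feeds are reachable and feedparser is installed.
    async def _demo():
        collector = get_news_collector(cache_duration=300)
        headlines = await collector.fetch_headlines(symbols=["BTC", "ETH"],
                                                    max_age_hours=12)
        for line in headlines[:5]:
            print(line)

    asyncio.run(_demo())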
406
src/data/pricing_service.py
Normal file
@@ -0,0 +1,406 @@
"""Unified pricing data service with multi-provider support and automatic failover."""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from .providers.base_provider import BasePricingProvider
|
||||
from .providers.ccxt_provider import CCXTProvider
|
||||
from .providers.coingecko_provider import CoinGeckoProvider
|
||||
from .cache_manager import CacheManager
|
||||
from .health_monitor import HealthMonitor, HealthStatus
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PricingService:
|
||||
"""Unified pricing data service with multi-provider support.
|
||||
|
||||
Manages multiple pricing providers with automatic failover, caching,
|
||||
and health monitoring. Provides a single consistent API for accessing
|
||||
market data regardless of the underlying provider.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize pricing service."""
|
||||
self.config = get_config()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Initialize components
|
||||
self.cache = CacheManager(
|
||||
default_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
|
||||
max_size=self.config.get("data_providers.caching.max_cache_size", 1000),
|
||||
ticker_ttl=self.config.get("data_providers.caching.ticker_ttl", 2),
|
||||
ohlcv_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
|
||||
)
|
||||
|
||||
self.health_monitor = HealthMonitor()
|
||||
|
||||
# Provider instances
|
||||
self._providers: Dict[str, BasePricingProvider] = {}
|
||||
self._active_provider: Optional[str] = None
|
||||
self._provider_priority: List[str] = []
|
||||
|
||||
# Subscriptions
|
||||
self._subscriptions: Dict[str, List[Callable]] = {}
|
||||
|
||||
# Initialize providers
|
||||
self._initialize_providers()
|
||||
|
||||
    def _initialize_providers(self):
        """Initialize providers from configuration."""
        # Get primary providers from config
        primary_config = self.config.get("data_providers.primary", [])
        if not primary_config:
            # Default configuration
            primary_config = [
                {'name': 'kraken', 'enabled': True, 'priority': 1},
                {'name': 'coinbase', 'enabled': True, 'priority': 2},
                {'name': 'binance', 'enabled': True, 'priority': 3},
            ]

        # Sort by priority
        primary_config = sorted(
            [p for p in primary_config if p.get('enabled', True)],
            key=lambda x: x.get('priority', 999)
        )

        # Create CCXT providers for each exchange
        for provider_config in primary_config:
            exchange_name = provider_config.get('name')
            try:
                provider = CCXTProvider(exchange_name=exchange_name)

                if provider.connect():
                    # Read the name only after connect(): it includes the
                    # exchange id once a connection is established, so each
                    # provider registers under a distinct key
                    provider_name = provider.name
                    self._providers[provider_name] = provider
                    self._provider_priority.append(provider_name)
                    self.logger.info(f"Initialized provider: {provider_name}")
                else:
                    self.logger.warning(f"Failed to connect provider: {exchange_name}")
            except Exception as e:
                self.logger.error(f"Error initializing provider {exchange_name}: {e}")

        # Add fallback provider (CoinGecko)
        fallback_config = self.config.get("data_providers.fallback", {})
        if fallback_config.get('enabled', True):
            try:
                coingecko = CoinGeckoProvider(api_key=fallback_config.get('api_key'))
                if coingecko.connect():
                    self._providers[coingecko.name] = coingecko
                    self._provider_priority.append(coingecko.name)
                    self.logger.info(f"Initialized fallback provider: {coingecko.name}")
                else:
                    self.logger.warning("Failed to connect CoinGecko fallback provider")
            except Exception as e:
                self.logger.error(f"Error initializing CoinGecko provider: {e}")

        # Select initial active provider
        self._select_active_provider()

    def _select_active_provider(self) -> Optional[str]:
        """Select the best available provider.

        Returns:
            Name of selected provider or None
        """
        # Filter to healthy providers
        healthy_providers = [
            name for name in self._provider_priority
            if name in self._providers
            and self.health_monitor.is_healthy(name)
        ]

        if not healthy_providers:
            # Fall back to any available provider if none are healthy
            healthy_providers = list(self._providers.keys())

        if not healthy_providers:
            self.logger.error("No providers available")
            self._active_provider = None
            return None

        # Select first healthy provider (already sorted by priority)
        self._active_provider = healthy_providers[0]
        self.logger.info(f"Selected active provider: {self._active_provider}")
        return self._active_provider

    def _get_provider(self, provider_name: Optional[str] = None) -> Optional[BasePricingProvider]:
        """Get a provider instance.

        Args:
            provider_name: Name of provider, or None to use active provider

        Returns:
            Provider instance or None
        """
        if provider_name:
            return self._providers.get(provider_name)

        # Use active provider, or select one if none active
        if not self._active_provider:
            self._select_active_provider()

        return self._providers.get(self._active_provider) if self._active_provider else None

    def _execute_with_failover(
        self,
        operation: Callable[[BasePricingProvider], Any],
        operation_name: str
    ) -> Any:
        """Execute an operation with automatic failover.

        Args:
            operation: Function that takes a provider and returns a result
            operation_name: Name of operation for logging

        Returns:
            Operation result or None if all providers fail
        """
        # Try active provider first
        providers_to_try = [self._active_provider] if self._active_provider else []

        # Add other providers in priority order
        for provider_name in self._provider_priority:
            if provider_name != self._active_provider and provider_name in self._providers:
                providers_to_try.append(provider_name)

        last_error = None

        for provider_name in providers_to_try:
            provider = self._providers.get(provider_name)
            if not provider:
                continue

            # Check health
            if not self.health_monitor.is_healthy(provider_name):
                self.logger.debug(f"Skipping unhealthy provider: {provider_name}")
                continue

            try:
                start_time = time.time()
                result = operation(provider)
                response_time = time.time() - start_time

                # Record success
                self.health_monitor.record_success(provider_name, response_time)

                # Update active provider if we used a different one
                if provider_name != self._active_provider:
                    self.logger.info(f"Switched to provider: {provider_name}")
                    self._active_provider = provider_name

                return result

            except Exception as e:
                last_error = e
                self.logger.warning(f"{operation_name} failed on {provider_name}: {e}")
                self.health_monitor.record_failure(provider_name)

                # Try next provider
                continue

        # All providers failed
        self.logger.error(f"{operation_name} failed on all providers")
        if last_error:
            raise last_error

        return None

    def get_ticker(self, symbol: str, use_cache: bool = True) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading pair symbol (e.g., 'BTC/USD')
            use_cache: Whether to use cache

        Returns:
            Ticker data dictionary
        """
        cache_key = f"ticker:{symbol}"

        # Check cache
        if use_cache:
            cached = self.cache.get(cache_key)
            if cached:
                return cached

        # Fetch from provider
        def fetch_ticker(provider: BasePricingProvider):
            return provider.get_ticker(symbol)

        ticker_data = self._execute_with_failover(fetch_ticker, f"get_ticker({symbol})")

        if ticker_data:
            # Cache the result
            if use_cache:
                self.cache.set(cache_key, ticker_data, cache_type='ticker')

            return ticker_data

        return {}

    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100,
        use_cache: bool = True
    ) -> List[List]:
        """Get OHLCV candlestick data.

        Args:
            symbol: Trading pair symbol
            timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
            since: Start datetime
            limit: Number of candles
            use_cache: Whether to use cache

        Returns:
            List of [timestamp_ms, open, high, low, close, volume]
        """
        cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"

        # Check cache (only if no 'since' parameter, as it changes the result)
        if use_cache and not since:
            cached = self.cache.get(cache_key)
            if cached:
                return cached

        # Fetch from provider
        def fetch_ohlcv(provider: BasePricingProvider):
            return provider.get_ohlcv(symbol, timeframe, since, limit)

        ohlcv_data = self._execute_with_failover(
            fetch_ohlcv,
            f"get_ohlcv({symbol}, {timeframe})"
        )

        if ohlcv_data:
            # Cache the result (only if no 'since' parameter)
            if use_cache and not since:
                self.cache.set(cache_key, ohlcv_data, cache_type='ohlcv')

            return ohlcv_data

        return []

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(data) called on price updates

        Returns:
            True if subscription successful
        """
        key = f"ticker:{symbol}"

        # Add callback
        if key not in self._subscriptions:
            self._subscriptions[key] = []
        if callback not in self._subscriptions[key]:
            self._subscriptions[key].append(callback)

        # Wrap the callback so every registered subscriber receives updates
        def wrapped_callback(data):
            for cb in self._subscriptions.get(key, []):
                try:
                    cb(data)
                except Exception as e:
                    self.logger.error(f"Callback error for {symbol}: {e}")

        # Subscribe via active provider
        provider = self._get_provider()
        if provider:
            try:
                success = provider.subscribe_ticker(symbol, wrapped_callback)
                if success:
                    self.logger.info(f"Subscribed to ticker updates for {symbol}")
                    return True
            except Exception as e:
                self.logger.error(f"Failed to subscribe to ticker for {symbol}: {e}")

        return False

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Specific callback to remove, or None to remove all
        """
        key = f"ticker:{symbol}"

        # Remove callback
        if key in self._subscriptions:
            if callback:
                if callback in self._subscriptions[key]:
                    self._subscriptions[key].remove(callback)
                if not self._subscriptions[key]:
                    del self._subscriptions[key]
            else:
                del self._subscriptions[key]

        # Unsubscribe from all providers
        for provider in self._providers.values():
            try:
                provider.unsubscribe_ticker(symbol, callback)
            except Exception:
                pass

        self.logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def get_active_provider(self) -> Optional[str]:
        """Get name of active provider.

        Returns:
            Provider name or None
        """
        return self._active_provider

    def get_provider_health(self, provider_name: Optional[str] = None) -> Dict[str, Any]:
        """Get health status for a provider or all providers.

        Args:
            provider_name: Provider name, or None for all providers

        Returns:
            Health status dictionary
        """
        if provider_name:
            metrics = self.health_monitor.get_metrics(provider_name)
            if metrics:
                return metrics.to_dict()
            return {}

        return self.health_monitor.get_all_metrics()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Cache statistics dictionary
        """
        return self.cache.get_stats()


# Global pricing service instance
_pricing_service: Optional[PricingService] = None


def get_pricing_service() -> PricingService:
    """Get global pricing service instance.

    Returns:
        PricingService instance
    """
    global _pricing_service
    if _pricing_service is None:
        _pricing_service = PricingService()
    return _pricing_service
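

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original commit: read a ticker and
    # a few hourly candles through whichever provider is currently healthy.
    service = get_pricing_service()
    print("Active provider:", service.get_active_provider())

    ticker = service.get_ticker("BTC/USD")
    if ticker:
        print("Last price:", ticker["last"])

    candles = service.get_ohlcv("BTC/USD", timeframe="1h", limit=24)
    print(f"Fetched {len(candles)} candles")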
7
src/data/providers/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""Pricing data providers package."""
|
||||
|
||||
from .base_provider import BasePricingProvider
|
||||
from .ccxt_provider import CCXTProvider
|
||||
from .coingecko_provider import CoinGeckoProvider
|
||||
|
||||
__all__ = ['BasePricingProvider', 'CCXTProvider', 'CoinGeckoProvider']
150
src/data/providers/base_provider.py
Normal file
@@ -0,0 +1,150 @@
"""Base pricing provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from datetime import datetime
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BasePricingProvider(ABC):
|
||||
"""Base class for pricing data providers.
|
||||
|
||||
Pricing providers are responsible for fetching market data (prices, OHLCV)
|
||||
without requiring API keys. They differ from exchange adapters which handle
|
||||
trading operations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize pricing provider."""
|
||||
self.logger = get_logger(f"provider.{self.__class__.__name__}")
|
||||
self._connected = False
|
||||
self._subscribers: Dict[str, List[Callable]] = {}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Provider name."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_websocket(self) -> bool:
|
||||
"""Whether this provider supports WebSocket connections."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def connect(self) -> bool:
|
||||
"""Connect to the provider.
|
||||
|
||||
Returns:
|
||||
True if connection successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disconnect(self):
|
||||
"""Disconnect from the provider."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_ticker(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get current ticker data for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., 'BTC/USD')
|
||||
|
||||
Returns:
|
||||
Dictionary with ticker data:
|
||||
- 'symbol': str
|
||||
- 'bid': Decimal
|
||||
- 'ask': Decimal
|
||||
- 'last': Decimal (last price)
|
||||
- 'high': Decimal (24h high)
|
||||
- 'low': Decimal (24h low)
|
||||
- 'volume': Decimal (24h volume)
|
||||
- 'timestamp': int (Unix timestamp in milliseconds)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_ohlcv(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframe: str = '1h',
|
||||
since: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
) -> List[List]:
|
||||
"""Get OHLCV (candlestick) data.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
|
||||
since: Start datetime
|
||||
limit: Number of candles
|
||||
|
||||
Returns:
|
||||
List of [timestamp_ms, open, high, low, close, volume]
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
|
||||
"""Subscribe to ticker updates.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
callback: Callback function(data) called on price updates
|
||||
|
||||
Returns:
|
||||
True if subscription successful
|
||||
"""
|
||||
pass
|
||||
|
||||
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
|
||||
"""Unsubscribe from ticker updates.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
callback: Specific callback to remove, or None to remove all
|
||||
"""
|
||||
key = f"ticker_{symbol}"
|
||||
if key in self._subscribers:
|
||||
if callback:
|
||||
if callback in self._subscribers[key]:
|
||||
self._subscribers[key].remove(callback)
|
||||
if not self._subscribers[key]:
|
||||
del self._subscribers[key]
|
||||
else:
|
||||
del self._subscribers[key]
|
||||
|
||||
def normalize_symbol(self, symbol: str) -> str:
|
||||
"""Normalize symbol format for this provider.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to normalize
|
||||
|
||||
Returns:
|
||||
Normalized symbol
|
||||
"""
|
||||
# Default: uppercase and replace dashes with slashes
|
||||
return symbol.upper().replace('-', '/')
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if provider is connected.
|
||||
|
||||
Returns:
|
||||
True if connected
|
||||
"""
|
||||
return self._connected
|
||||
|
||||
def get_supported_symbols(self) -> List[str]:
|
||||
"""Get list of supported trading symbols.
|
||||
|
||||
Returns:
|
||||
List of supported symbols, or empty list if not available
|
||||
"""
|
||||
# Default implementation - override in subclasses
|
||||
return []
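

if __name__ == "__main__":
    # Hedged illustration, not part of the original commit: a minimal concrete
    # provider serving a fixed price, showing the contract subclasses implement.
    class StaticProvider(BasePricingProvider):
        @property
        def name(self) -> str:
            return "Static"

        @property
        def supports_websocket(self) -> bool:
            return False

        def connect(self) -> bool:
            self._connected = True
            return True

        def disconnect(self):
            self._connected = False

        def get_ticker(self, symbol: str) -> Dict[str, Any]:
            price = Decimal("100")
            return {'symbol': symbol, 'bid': price, 'ask': price, 'last': price,
                    'high': price, 'low': price, 'volume': Decimal("0"),
                    'timestamp': 0}

        def get_ohlcv(self, symbol, timeframe='1h', since=None, limit=100):
            return []

        def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
            return False

    provider = StaticProvider()
    provider.connect()
    print(provider.name, provider.get_ticker("BTC/USD")['last'])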
333
src/data/providers/ccxt_provider.py
Normal file
@@ -0,0 +1,333 @@
"""CCXT-based pricing provider with multi-exchange support and WebSocket capabilities."""
|
||||
|
||||
import ccxt
|
||||
import threading
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from datetime import datetime
|
||||
from .base_provider import BasePricingProvider
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CCXTProvider(BasePricingProvider):
|
||||
"""CCXT-based pricing provider with multi-exchange fallback support.
|
||||
|
||||
This provider uses CCXT to connect to multiple exchanges (Kraken, Coinbase, Binance)
|
||||
as primary data sources. It supports WebSocket where available and falls back to
|
||||
polling if WebSocket is not available.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange_name: Optional[str] = None):
|
||||
"""Initialize CCXT provider.
|
||||
|
||||
Args:
|
||||
exchange_name: Specific exchange to use ('kraken', 'coinbase', 'binance'),
|
||||
or None to try multiple exchanges
|
||||
"""
|
||||
super().__init__()
|
||||
self.exchange_name = exchange_name
|
||||
self.exchange = None
|
||||
self._selected_exchange_id = None
|
||||
self._polling_threads: Dict[str, threading.Thread] = {}
|
||||
self._stop_polling: Dict[str, bool] = {}
|
||||
|
||||
# Exchange priority order (try first to last)
|
||||
if exchange_name:
|
||||
self._exchange_options = [(exchange_name.lower(), exchange_name.capitalize())]
|
||||
else:
|
||||
self._exchange_options = [
|
||||
('kraken', 'Kraken'),
|
||||
('coinbase', 'Coinbase'),
|
||||
('binance', 'Binance'),
|
||||
]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Provider name."""
|
||||
if self._selected_exchange_id:
|
||||
return f"CCXT-{self._selected_exchange_id.capitalize()}"
|
||||
return "CCXT Provider"
|
||||
|
||||
@property
|
||||
def supports_websocket(self) -> bool:
|
||||
"""Check if current exchange supports WebSocket."""
|
||||
if not self.exchange:
|
||||
return False
|
||||
|
||||
# Check if exchange has WebSocket support
|
||||
# Most modern CCXT exchanges support WebSocket, but we check explicitly
|
||||
exchange_id = getattr(self.exchange, 'id', '').lower()
|
||||
|
||||
# Known WebSocket-capable exchanges
|
||||
ws_capable = ['kraken', 'coinbase', 'binance', 'binanceus', 'okx']
|
||||
|
||||
return exchange_id in ws_capable
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""Connect to an exchange via CCXT.
|
||||
|
||||
Tries multiple exchanges in order until one succeeds.
|
||||
|
||||
Returns:
|
||||
True if connection successful
|
||||
"""
|
||||
for exchange_id, exchange_display_name in self._exchange_options:
|
||||
try:
|
||||
# Get exchange class from CCXT
|
||||
exchange_class = getattr(ccxt, exchange_id, None)
|
||||
if not exchange_class:
|
||||
logger.warning(f"Exchange {exchange_id} not found in CCXT")
|
||||
continue
|
||||
|
||||
# Create exchange instance
|
||||
self.exchange = exchange_class({
|
||||
'enableRateLimit': True,
|
||||
'options': {
|
||||
'defaultType': 'spot',
|
||||
}
|
||||
})
|
||||
|
||||
# Load markets to test connection
|
||||
if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
|
||||
try:
|
||||
self.exchange.load_markets()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load markets for {exchange_id}: {e}")
|
||||
continue
|
||||
|
||||
# Test connection with a common symbol
|
||||
test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
|
||||
ticker_result = None
|
||||
|
||||
for test_symbol in test_symbols:
|
||||
try:
|
||||
ticker_result = self.exchange.fetch_ticker(test_symbol)
|
||||
if ticker_result:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not ticker_result:
|
||||
logger.warning(f"Could not fetch ticker from {exchange_id}")
|
||||
continue
|
||||
|
||||
# Success!
|
||||
self._selected_exchange_id = exchange_id
|
||||
self._connected = True
|
||||
logger.info(f"Connected to {exchange_display_name} via CCXT")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to connect to {exchange_id}: {e}")
|
||||
continue
|
||||
|
||||
# All exchanges failed
|
||||
logger.error("Failed to connect to any CCXT exchange")
|
||||
self._connected = False
|
||||
self.exchange = None
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from exchange."""
|
||||
# Stop all polling threads
|
||||
for symbol in list(self._stop_polling.keys()):
|
||||
self._stop_polling[symbol] = True
|
||||
|
||||
# Wait a bit for threads to stop
|
||||
for symbol, thread in self._polling_threads.items():
|
||||
if thread.is_alive():
|
||||
thread.join(timeout=2.0)
|
||||
|
||||
self._polling_threads.clear()
|
||||
self._stop_polling.clear()
|
||||
self._subscribers.clear()
|
||||
self._connected = False
|
||||
self.exchange = None
|
||||
self._selected_exchange_id = None
|
||||
logger.info(f"Disconnected from {self.name}")
|
||||
|
||||
def get_ticker(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get current ticker data."""
|
||||
if not self._connected or not self.exchange:
|
||||
logger.error("Provider not connected")
|
||||
return {}
|
||||
|
||||
try:
|
||||
normalized_symbol = self.normalize_symbol(symbol)
|
||||
ticker = self.exchange.fetch_ticker(normalized_symbol)
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'bid': Decimal(str(ticker.get('bid', 0) or 0)),
|
||||
'ask': Decimal(str(ticker.get('ask', 0) or 0)),
|
||||
'last': Decimal(str(ticker.get('last', 0) or ticker.get('close', 0) or 0)),
|
||||
'high': Decimal(str(ticker.get('high', 0) or 0)),
|
||||
'low': Decimal(str(ticker.get('low', 0) or 0)),
|
||||
'volume': Decimal(str(ticker.get('quoteVolume', ticker.get('volume', 0)) or 0)),
|
||||
'timestamp': ticker.get('timestamp', int(time.time() * 1000)),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get ticker for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_ohlcv(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframe: str = '1h',
|
||||
since: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
) -> List[List]:
|
||||
"""Get OHLCV candlestick data."""
|
||||
if not self._connected or not self.exchange:
|
||||
logger.error("Provider not connected")
|
||||
return []
|
||||
|
||||
try:
|
||||
since_timestamp = int(since.timestamp() * 1000) if since else None
|
||||
normalized_symbol = self.normalize_symbol(symbol)
|
||||
|
||||
# Most exchanges support up to 1000 candles per request
|
||||
max_limit = min(limit, 1000)
|
||||
|
||||
ohlcv = self.exchange.fetch_ohlcv(
|
||||
normalized_symbol,
|
||||
timeframe,
|
||||
since_timestamp,
|
||||
max_limit
|
||||
)
|
||||
|
||||
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
|
||||
return ohlcv
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
|
||||
"""Subscribe to ticker updates.
|
||||
|
||||
Uses WebSocket if available, otherwise falls back to polling.
|
||||
"""
|
||||
if not self._connected or not self.exchange:
|
||||
logger.error("Provider not connected")
|
||||
return False
|
||||
|
||||
key = f"ticker_{symbol}"
|
||||
|
||||
# Add callback to subscribers
|
||||
if key not in self._subscribers:
|
||||
self._subscribers[key] = []
|
||||
if callback not in self._subscribers[key]:
|
||||
self._subscribers[key].append(callback)
|
||||
|
||||
# Try WebSocket first if supported
|
||||
if self.supports_websocket:
|
||||
try:
|
||||
# CCXT WebSocket support varies by exchange
|
||||
# For now, use polling as fallback since WebSocket implementation
|
||||
# in CCXT requires exchange-specific handling
|
||||
# TODO: Implement native WebSocket when CCXT adds better support
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket subscription failed, falling back to polling: {e}")
|
||||
|
||||
# Use polling as primary method (more reliable across exchanges)
|
||||
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
|
||||
self._stop_polling[key] = False
|
||||
|
||||
def poll_ticker():
|
||||
"""Poll ticker every 2 seconds."""
|
||||
while not self._stop_polling.get(key, False):
|
||||
try:
|
||||
ticker_data = self.get_ticker(symbol)
|
||||
if ticker_data and 'last' in ticker_data:
|
||||
# Call all callbacks for this symbol
|
||||
for cb in self._subscribers.get(key, []):
|
||||
try:
|
||||
cb({
|
||||
'symbol': symbol,
|
||||
'price': ticker_data['last'],
|
||||
'bid': ticker_data.get('bid', 0),
|
||||
'ask': ticker_data.get('ask', 0),
|
||||
'volume': ticker_data.get('volume', 0),
|
||||
'timestamp': ticker_data.get('timestamp'),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Callback error for {symbol}: {e}")
|
||||
|
||||
time.sleep(2) # Poll every 2 seconds
|
||||
except Exception as e:
|
||||
logger.error(f"Ticker polling error for {symbol}: {e}")
|
||||
time.sleep(5) # Wait longer on error
|
||||
|
||||
thread = threading.Thread(target=poll_ticker, daemon=True)
|
||||
thread.start()
|
||||
self._polling_threads[key] = thread
|
||||
logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"

        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)

        # Stop the polling thread only when no subscribers remain for this symbol
        if key not in self._subscribers and key in self._stop_polling:
            self._stop_polling[key] = True

            # Clean up thread reference
            if key in self._polling_threads:
                del self._polling_threads[key]

        logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol for the selected exchange."""
        if not self.exchange:
            return symbol.replace('-', '/').upper()

        # Basic normalization
        normalized = symbol.replace('-', '/').upper()

        # Try to use exchange's markets to find correct symbol
        try:
            if hasattr(self.exchange, 'markets') and self.exchange.markets:
                # Check if symbol exists
                if normalized in self.exchange.markets:
                    return normalized

                # Try alternative formats
                if '/USD' in normalized:
                    alt_symbol = normalized.replace('/USD', '/USDT')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol

                # For Kraken: try XBT instead of BTC
                if normalized.startswith('BTC/'):
                    alt_symbol = normalized.replace('BTC/', 'XBT/')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol
        except Exception:
            pass

        # Fallback: return normalized (let CCXT handle errors)
        return normalized

    def get_supported_symbols(self) -> List[str]:
        """Get list of supported trading symbols."""
        if not self._connected or not self.exchange:
            return []

        try:
            markets = self.exchange.load_markets()
            return list(markets.keys())
        except Exception as e:
            logger.error(f"Failed to get supported symbols: {e}")
            return []
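

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original commit: connect, print one
    # ticker, and stream updates for a few seconds via the polling fallback.
    # Requires network access to the exchange's public API.
    provider = CCXTProvider(exchange_name='kraken')
    if provider.connect():
        print(provider.name, provider.get_ticker("BTC/USD").get('last'))
        provider.subscribe_ticker("BTC/USD", lambda d: print("tick:", d['price']))
        time.sleep(6)
        provider.disconnect()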
376
src/data/providers/coingecko_provider.py
Normal file
@@ -0,0 +1,376 @@
"""CoinGecko pricing provider for fallback market data."""
|
||||
|
||||
import httpx
|
||||
import time
|
||||
import threading
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Callable, Tuple
|
||||
from datetime import datetime
|
||||
from .base_provider import BasePricingProvider
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CoinGeckoProvider(BasePricingProvider):
|
||||
"""CoinGecko API pricing provider.
|
||||
|
||||
This provider uses CoinGecko's free API tier as a fallback when CCXT
|
||||
providers are unavailable. It uses simple REST endpoints that don't
|
||||
require authentication.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://api.coingecko.com/api/v3"
|
||||
|
||||
# CoinGecko coin ID mapping for common symbols
|
||||
COIN_ID_MAP = {
|
||||
'BTC': 'bitcoin',
|
||||
'ETH': 'ethereum',
|
||||
'BNB': 'binancecoin',
|
||||
'SOL': 'solana',
|
||||
'ADA': 'cardano',
|
||||
'XRP': 'ripple',
|
||||
'DOGE': 'dogecoin',
|
||||
'DOT': 'polkadot',
|
||||
'MATIC': 'matic-network',
|
||||
'AVAX': 'avalanche-2',
|
||||
'LINK': 'chainlink',
|
||||
'USDT': 'tether',
|
||||
'USDC': 'usd-coin',
|
||||
'DAI': 'dai',
|
||||
}
|
||||
|
||||
# Currency mapping
|
||||
CURRENCY_MAP = {
|
||||
'USD': 'usd',
|
||||
'EUR': 'eur',
|
||||
'GBP': 'gbp',
|
||||
'JPY': 'jpy',
|
||||
'USDT': 'usd', # CoinGecko uses USD for stablecoins
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize CoinGecko provider.
|
||||
|
||||
Args:
|
||||
api_key: Optional API key for higher rate limits (free tier doesn't require it)
|
||||
"""
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self._client = None
|
||||
self._polling_threads: Dict[str, threading.Thread] = {}
|
||||
self._stop_polling: Dict[str, bool] = {}
|
||||
self._rate_limit_delay = 1.0 # Free tier: ~30-50 calls/minute
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Provider name."""
|
||||
return "CoinGecko"
|
||||
|
||||
@property
|
||||
def supports_websocket(self) -> bool:
|
||||
"""CoinGecko free tier doesn't support WebSocket."""
|
||||
return False
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""Connect to CoinGecko API.
|
||||
|
||||
Returns:
|
||||
True if connection successful (just validates API access)
|
||||
"""
|
||||
try:
|
||||
self._client = httpx.Client(timeout=10.0)
|
||||
|
||||
# Test connection by fetching Bitcoin price
|
||||
response = self._client.get(
|
||||
f"{self.BASE_URL}/simple/price",
|
||||
params={
|
||||
'ids': 'bitcoin',
|
||||
'vs_currencies': 'usd',
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if 'bitcoin' in data:
|
||||
self._connected = True
|
||||
logger.info("Connected to CoinGecko API")
|
||||
return True
|
||||
|
||||
logger.warning("CoinGecko API test failed")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to CoinGecko: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from CoinGecko API."""
|
||||
# Stop all polling threads
|
||||
for symbol in list(self._stop_polling.keys()):
|
||||
self._stop_polling[symbol] = True
|
||||
|
||||
# Wait for threads to stop
|
||||
for symbol, thread in self._polling_threads.items():
|
||||
if thread.is_alive():
|
||||
thread.join(timeout=2.0)
|
||||
|
||||
self._polling_threads.clear()
|
||||
self._stop_polling.clear()
|
||||
self._subscribers.clear()
|
||||
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info("Disconnected from CoinGecko API")
|
||||
|
||||
def _parse_symbol(self, symbol: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Parse symbol into coin_id and currency.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair like 'BTC/USD' or 'BTC/USDT'
|
||||
|
||||
Returns:
|
||||
Tuple of (coin_id, currency) or (None, None) if not found
|
||||
"""
|
||||
parts = symbol.upper().replace('-', '/').split('/')
|
||||
if len(parts) != 2:
|
||||
return None, None
|
||||
|
||||
base, quote = parts
|
||||
|
||||
# Get coin ID
|
||||
coin_id = self.COIN_ID_MAP.get(base)
|
||||
if not coin_id:
|
||||
# Try lowercase base as fallback
|
||||
coin_id = base.lower()
|
||||
|
||||
# Get currency
|
||||
currency = self.CURRENCY_MAP.get(quote, quote.lower())
|
||||
|
||||
return coin_id, currency
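
    # Examples (hedged): _parse_symbol("BTC/USD") -> ("bitcoin", "usd");
    # _parse_symbol("DOGE-USDT") -> ("dogecoin", "usd"); an unmapped base falls
    # back to its lowercased ticker, e.g. _parse_symbol("PEPE/USD") -> ("pepe", "usd").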

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data from CoinGecko."""
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return {}

        try:
            coin_id, currency = self._parse_symbol(symbol)
            if not coin_id:
                logger.error(f"Unknown symbol format: {symbol}")
                return {}

            # Fetch current price
            response = self._client.get(
                f"{self.BASE_URL}/simple/price",
                params={
                    'ids': coin_id,
                    'vs_currencies': currency,
                    'include_24hr_change': 'true',
                    'include_24hr_vol': 'true',
                }
            )

            if response.status_code != 200:
                logger.error(f"CoinGecko API error: {response.status_code}")
                return {}

            data = response.json()
            if coin_id not in data:
                logger.error(f"Coin {coin_id} not found in CoinGecko response")
                return {}

            coin_data = data[coin_id]
            price_key = currency.lower()

            if price_key not in coin_data:
                logger.error(f"Currency {currency} not found for {coin_id}")
                return {}

            price = Decimal(str(coin_data[price_key]))

            # Calculate high/low from 24h change if available
            change_key = f"{price_key}_24h_change"
            vol_key = f"{price_key}_24h_vol"

            change_24h = coin_data.get(change_key, 0) or 0
            volume_24h = coin_data.get(vol_key, 0) or 0

            # Estimate high/low from current price and 24h change
            # This is approximate since CoinGecko free tier doesn't provide exact high/low
            current_price = float(price)
            if change_24h:
                estimated_high = current_price * (1 + abs(change_24h / 100) / 2)
                estimated_low = current_price * (1 - abs(change_24h / 100) / 2)
            else:
                estimated_high = current_price
                estimated_low = current_price

            return {
                'symbol': symbol,
                'bid': price,  # CoinGecko doesn't provide bid/ask, use last price
                'ask': price,
                'last': price,
                'high': Decimal(str(estimated_high)),
                'low': Decimal(str(estimated_low)),
                'volume': Decimal(str(volume_24h)),
                'timestamp': int(time.time() * 1000),
            }

        except Exception as e:
            logger.error(f"Failed to get ticker for {symbol}: {e}")
            return {}

    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data from CoinGecko.

        Note: CoinGecko free tier has limited historical data access.
        This method may return empty data for some timeframes.
        """
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return []

        try:
            coin_id, currency = self._parse_symbol(symbol)
            if not coin_id:
                logger.error(f"Unknown symbol format: {symbol}")
                return []

            # CoinGecko uses a days parameter instead of a timeframe
            # Map timeframe to days
            timeframe_days_map = {
                '1m': 1,   # Last 24 hours, minute data (not available in free tier)
                '5m': 1,
                '15m': 1,
                '30m': 1,
                '1h': 1,   # Last 24 hours, hourly data
                '4h': 7,   # Last 7 days
                '1d': 30,  # Last 30 days
                '1w': 90,  # Last 90 days
            }

            days = timeframe_days_map.get(timeframe, 7)

            # Fetch OHLC data
            response = self._client.get(
                f"{self.BASE_URL}/coins/{coin_id}/ohlc",
                params={
                    'vs_currency': currency,
                    'days': days,
                }
            )

            if response.status_code != 200:
                logger.warning(f"CoinGecko OHLC API returned {response.status_code}")
                return []

            data = response.json()

            # CoinGecko returns: [timestamp_ms, open, high, low, close]
            # We need to convert to: [timestamp_ms, open, high, low, close, volume]
            # Note: CoinGecko OHLC endpoint doesn't include volume

            # Filter by since if provided
            if since:
                since_timestamp = int(since.timestamp() * 1000)
                data = [candle for candle in data if candle[0] >= since_timestamp]

            # Limit results
            data = data[-limit:] if limit else data

            # Add volume as 0 (CoinGecko doesn't provide it in OHLC endpoint)
            ohlcv = [candle + [0] for candle in data]

            logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} from CoinGecko")
            return ohlcv

        except Exception as e:
            logger.error(f"Failed to get OHLCV for {symbol}: {e}")
            return []

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates via polling."""
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return False

        key = f"ticker_{symbol}"

        # Add callback to subscribers
        if key not in self._subscribers:
            self._subscribers[key] = []
        if callback not in self._subscribers[key]:
            self._subscribers[key].append(callback)

        # Start polling thread if not already running
        if key not in self._polling_threads or not self._polling_threads[key].is_alive():
            self._stop_polling[key] = False

            def poll_ticker():
                """Poll ticker respecting rate limits."""
                while not self._stop_polling.get(key, False):
                    try:
                        ticker_data = self.get_ticker(symbol)
                        if ticker_data and 'last' in ticker_data:
                            # Call all callbacks
                            for cb in self._subscribers.get(key, []):
                                try:
                                    cb({
                                        'symbol': symbol,
                                        'price': ticker_data['last'],
                                        'bid': ticker_data.get('bid', 0),
                                        'ask': ticker_data.get('ask', 0),
                                        'volume': ticker_data.get('volume', 0),
                                        'timestamp': ticker_data.get('timestamp'),
                                    })
                                except Exception as e:
                                    logger.error(f"Callback error for {symbol}: {e}")

                        # Rate limit: wait between requests (free tier: ~30-50 calls/min)
                        time.sleep(self._rate_limit_delay * 2)  # Poll every 2 seconds
                    except Exception as e:
                        logger.error(f"Ticker polling error for {symbol}: {e}")
                        time.sleep(10)  # Wait longer on error

            thread = threading.Thread(target=poll_ticker, daemon=True)
            thread.start()
            self._polling_threads[key] = thread
            logger.info(f"Subscribed to ticker updates for {symbol} (CoinGecko polling)")
            return True

        return True

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"

        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)

        # Stop polling only when no subscribers remain for this symbol
        if key not in self._subscribers and key in self._stop_polling:
            self._stop_polling[key] = True

            # Clean up thread reference
            if key in self._polling_threads:
                del self._polling_threads[key]

        logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format."""
        # CoinGecko uses coin IDs, so we just normalize the format
        return symbol.replace('-', '/').upper()
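

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original commit: fetch a spot price
    # and daily candles from the free endpoint (no API key required).
    cg = CoinGeckoProvider()
    if cg.connect():
        print(cg.get_ticker("BTC/USD").get('last'))
        print(len(cg.get_ohlcv("BTC/USD", timeframe='1d', limit=30)), "candles")
        cg.disconnect()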
116
src/data/quality.py
Normal file
@@ -0,0 +1,116 @@
"""Data quality validation, gap filling, and retention policies."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, MarketData
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataQualityManager:
|
||||
"""Manages data quality and retention."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize data quality manager."""
|
||||
self.db = get_database()
|
||||
self.config = get_config()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def validate_data_quality(
|
||||
self,
|
||||
exchange: str,
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate data quality.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Returns:
|
||||
Quality report
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
data = session.query(MarketData).filter(
|
||||
MarketData.exchange == exchange,
|
||||
MarketData.symbol == symbol,
|
||||
MarketData.timeframe == timeframe,
|
||||
MarketData.timestamp >= start_date,
|
||||
MarketData.timestamp <= end_date
|
||||
).order_by(MarketData.timestamp).all()
|
||||
|
||||
if len(data) == 0:
|
||||
return {"valid": False, "reason": "No data"}
|
||||
|
||||
# Check for gaps
|
||||
gaps = self._detect_gaps(data, timeframe)
|
||||
|
||||
# Check for anomalies
|
||||
anomalies = self._detect_anomalies(data)
|
||||
|
||||
return {
|
||||
"valid": len(gaps) == 0 and len(anomalies) == 0,
|
||||
"total_records": len(data),
|
||||
"gaps": len(gaps),
|
||||
"anomalies": len(anomalies),
|
||||
}
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
    # Expected candle spacing per timeframe (used by _detect_gaps)
    _TIMEFRAME_DELTAS = {
        '1m': timedelta(minutes=1),
        '5m': timedelta(minutes=5),
        '15m': timedelta(minutes=15),
        '30m': timedelta(minutes=30),
        '1h': timedelta(hours=1),
        '4h': timedelta(hours=4),
        '1d': timedelta(days=1),
    }

    def _detect_gaps(self, data: List[MarketData], timeframe: str) -> List[datetime]:
        """Detect gaps in data.

        Args:
            data: List of market data
            timeframe: Timeframe

        Returns:
            List of gap timestamps
        """
        # Simplified gap detection: flag any spacing wider than the expected
        # candle interval for the timeframe (unknown timeframes are skipped)
        gaps = []
        expected = self._TIMEFRAME_DELTAS.get(timeframe)
        if expected is None or len(data) < 2:
            return gaps

        for prev, curr in zip(data, data[1:]):
            if curr.timestamp - prev.timestamp > expected:
                gaps.append(prev.timestamp)
        return gaps

    def _detect_anomalies(self, data: List[MarketData]) -> List[int]:
        """Detect data anomalies.

        Args:
            data: List of market data

        Returns:
            List of anomaly indices
        """
        # Simplified anomaly detection: basic sanity checks per candle
        # (assumes MarketData exposes close/high/low columns, matching
        # the fields DataStorage.store_ohlcv persists)
        anomalies = []
        for i, candle in enumerate(data):
            if candle.close is None or candle.close <= 0:
                anomalies.append(i)
            elif candle.high is not None and candle.low is not None and candle.high < candle.low:
                anomalies.append(i)
        return anomalies

    def cleanup_old_data(self, days_to_keep: int = 365):
        """Clean up old data based on retention policy.

        Args:
            days_to_keep: Days of data to keep
        """
        session = self.db.get_session()
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep)
            deleted = session.query(MarketData).filter(
                MarketData.timestamp < cutoff_date
            ).delete()
            session.commit()
            logger.info(f"Cleaned up {deleted} old data records")
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to cleanup old data: {e}")
        finally:
            session.close()
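

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original commit: validate a week of
    # hourly BTC/USD data and prune anything older than one year.
    manager = DataQualityManager()
    report = manager.validate_data_quality(
        exchange="kraken",
        symbol="BTC/USD",
        timeframe="1h",
        start_date=datetime.utcnow() - timedelta(days=7),
        end_date=datetime.utcnow(),
    )
    print(report)
    manager.cleanup_old_data(days_to_keep=365)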
225
src/data/redis_cache.py
Normal file
@@ -0,0 +1,225 @@
"""Redis-based caching for market data and API responses."""
|
||||
|
||||
from typing import Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
from src.core.redis import get_redis_client
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisCache:
|
||||
"""Redis-based cache for market data and API responses."""
|
||||
|
||||
# Default TTL values (seconds)
|
||||
TTL_TICKER = 5 # Ticker prices are very volatile
|
||||
TTL_OHLCV = 60 # OHLCV can be cached longer
|
||||
TTL_ORDERBOOK = 2 # Order books change rapidly
|
||||
    TTL_API_RESPONSE = 30  # General API response cache

    def __init__(self):
        """Initialize Redis cache."""
        self.redis = get_redis_client()

    async def get_ticker(self, symbol: str) -> Optional[dict]:
        """Get cached ticker data.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USD')

        Returns:
            Cached ticker data or None
        """
        key = f"cache:ticker:{symbol.replace('/', '_')}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for ticker:{symbol}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_ticker(self, symbol: str, data: dict, ttl: Optional[int] = None) -> bool:
        """Cache ticker data.

        Args:
            symbol: Trading symbol
            data: Ticker data
            ttl: Time-to-live in seconds (default: TTL_TICKER)

        Returns:
            True if cached successfully
        """
        key = f"cache:ticker:{symbol.replace('/', '_')}"
        ttl = ttl or self.TTL_TICKER
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached ticker:{symbol} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def get_ohlcv(self, symbol: str, timeframe: str, limit: int = 100) -> Optional[list]:
        """Get cached OHLCV data.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            limit: Number of candles

        Returns:
            Cached OHLCV data or None
        """
        key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for ohlcv:{symbol}:{timeframe}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_ohlcv(self, symbol: str, timeframe: str, data: list, limit: int = 100, ttl: Optional[int] = None) -> bool:
        """Cache OHLCV data.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            data: OHLCV data
            limit: Number of candles
            ttl: Time-to-live in seconds (default: TTL_OHLCV)

        Returns:
            True if cached successfully
        """
        key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
        ttl = ttl or self.TTL_OHLCV
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached ohlcv:{symbol}:{timeframe} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def get_api_response(self, cache_key: str) -> Optional[dict]:
        """Get cached API response.

        Args:
            cache_key: Unique cache key

        Returns:
            Cached response or None
        """
        key = f"cache:api:{cache_key}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for api:{cache_key}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_api_response(self, cache_key: str, data: dict, ttl: Optional[int] = None) -> bool:
        """Cache API response.

        Args:
            cache_key: Unique cache key
            data: Response data
            ttl: Time-to-live in seconds (default: TTL_API_RESPONSE)

        Returns:
            True if cached successfully
        """
        key = f"cache:api:{cache_key}"
        ttl = ttl or self.TTL_API_RESPONSE
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached api:{cache_key} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def invalidate(self, pattern: str) -> int:
        """Invalidate cache entries matching pattern.

        Args:
            pattern: Redis key pattern (e.g., 'cache:ticker:*')

        Returns:
            Number of keys deleted
        """
        try:
            client = self.redis.get_client()
            keys = []
            async for key in client.scan_iter(match=pattern):
                keys.append(key)

            if keys:
                deleted = await client.delete(*keys)
                logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
                return deleted
            return 0
        except Exception as e:
            logger.warning(f"Redis cache invalidation failed: {e}")
            return 0

    async def get_stats(self) -> dict:
        """Get cache statistics.

        Returns:
            Cache statistics
        """
        try:
            client = self.redis.get_client()
            info = await client.info('memory')

            # Count cached items by type
            ticker_count = 0
            ohlcv_count = 0
            api_count = 0

            async for _ in client.scan_iter(match='cache:ticker:*'):
                ticker_count += 1
            async for _ in client.scan_iter(match='cache:ohlcv:*'):
                ohlcv_count += 1
            async for _ in client.scan_iter(match='cache:api:*'):
                api_count += 1

            return {
                "memory_used": info.get('used_memory_human', 'N/A'),
                "ticker_entries": ticker_count,
                "ohlcv_entries": ohlcv_count,
                "api_entries": api_count,
                "total_entries": ticker_count + ohlcv_count + api_count,
            }
        except Exception as e:
            logger.warning(f"Failed to get cache stats: {e}")
            return {"error": str(e)}


# Global cache instance
_redis_cache: Optional[RedisCache] = None


def get_redis_cache() -> RedisCache:
    """Get global Redis cache instance."""
    global _redis_cache
    if _redis_cache is None:
        _redis_cache = RedisCache()
    return _redis_cache
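Usage note: a minimal sketch of the cache above, assuming a reachable Redis instance; the module path is an assumption (the file name is outside this hunk), everything else uses only the API shown.

import asyncio

from src.data.redis_cache import get_redis_cache  # assumed module path for the file above

async def main():
    cache = get_redis_cache()
    # Cache a ticker payload for the default TTL_TICKER window
    await cache.set_ticker('BTC/USD', {'last': 67000.5, 'bid': 66999.0, 'ask': 67001.0})
    # Reads within the TTL are served from Redis
    print(await cache.get_ticker('BTC/USD'))
    # Drop all ticker entries, e.g. after a reconnect
    await cache.invalidate('cache:ticker:*')

asyncio.run(main())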
75
src/data/storage.py
Normal file
@@ -0,0 +1,75 @@
"""Data persistence."""

from decimal import Decimal
from datetime import datetime
from typing import List, Optional
from sqlalchemy.orm import Session
from src.core.database import get_database, MarketData
from src.core.logger import get_logger

logger = get_logger(__name__)


class DataStorage:
    """Manages data storage and persistence."""

    def __init__(self):
        """Initialize data storage."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    def store_ohlcv(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        timestamp: datetime,
        open: Decimal,
        high: Decimal,
        low: Decimal,
        close: Decimal,
        volume: Decimal
    ):
        """Store OHLCV data.

        Args:
            exchange: Exchange name
            symbol: Trading symbol
            timeframe: Timeframe
            timestamp: Timestamp
            open: Open price
            high: High price
            low: Low price
            close: Close price
            volume: Volume
        """
        session = self.db.get_session()
        try:
            # Skip insert if this candle already exists (idempotent writes)
            existing = session.query(MarketData).filter_by(
                exchange=exchange,
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp
            ).first()

            if not existing:
                market_data = MarketData(
                    exchange=exchange,
                    symbol=symbol,
                    timeframe=timeframe,
                    timestamp=timestamp,
                    open=open,
                    high=high,
                    low=low,
                    close=close,
                    volume=volume
                )
                session.add(market_data)
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to store OHLCV data: {e}")
        finally:
            session.close()
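Usage note: a short sketch of storing one candle; the Decimal values are illustrative, and get_database() wiring comes from src.core.database as in the diff.

from datetime import datetime, timezone
from decimal import Decimal

from src.data.storage import DataStorage

storage = DataStorage()
# Writes are idempotent: re-storing the same (exchange, symbol, timeframe, timestamp) is a no-op
storage.store_ohlcv(
    exchange='coinbase',
    symbol='BTC/USD',
    timeframe='1h',
    timestamp=datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc),
    open=Decimal('42000.0'),
    high=Decimal('42150.5'),
    low=Decimal('41900.0'),
    close=Decimal('42100.0'),
    volume=Decimal('12.345'),
)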
14
src/exchanges/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""Exchange adapters package."""

from .base import BaseExchangeAdapter
from .factory import ExchangeFactory, get_exchange
from .coinbase import CoinbaseAdapter
from .public_data import PublicDataAdapter

# Register exchange adapters
ExchangeFactory.register("coinbase", CoinbaseAdapter)
ExchangeFactory.register("binance public", PublicDataAdapter)
ExchangeFactory.register("public data", PublicDataAdapter)  # Alias

__all__ = ['ExchangeFactory', 'get_exchange', 'CoinbaseAdapter', 'PublicDataAdapter', 'BaseExchangeAdapter']
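Usage note: registration makes adapters reachable by name; a hedged sketch of direct creation via the factory (the 'public data' adapter ignores credentials, so empty strings suffice).

import asyncio

from src.exchanges import ExchangeFactory

async def main():
    adapter = await ExchangeFactory.create_by_name('public data', api_key='', api_secret='')
    if adapter:
        ticker = await adapter.get_ticker('BTC/USDT')
        print(ticker.get('last'))
        await adapter.disconnect()

asyncio.run(main())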
309
src/exchanges/base.py
Normal file
@@ -0,0 +1,309 @@
"""Base exchange adapter interface."""

from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any
from datetime import datetime
from src.core.logger import get_logger

logger = get_logger(__name__)


class Order:
    """Order representation."""

    def __init__(
        self,
        symbol: str,
        side: str,  # 'buy' or 'sell'
        order_type: str,  # 'market', 'limit', etc.
        quantity: Decimal,
        price: Optional[Decimal] = None,
        **kwargs
    ):
        self.symbol = symbol
        self.side = side
        self.order_type = order_type
        self.quantity = quantity
        self.price = price
        self.extra = kwargs


class BaseExchangeAdapter(ABC):
    """Base class for exchange adapters."""

    def __init__(self, api_key: str, api_secret: str, sandbox: bool = False, read_only: bool = True):
        """Initialize exchange adapter.

        Args:
            api_key: API key
            api_secret: API secret
            sandbox: Use sandbox/testnet
            read_only: Read-only mode (no trading)
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.sandbox = sandbox
        self.read_only = read_only
        self.logger = get_logger(f"exchange.{self.__class__.__name__}")
        self._connected = False

    @property
    @abstractmethod
    def name(self) -> str:
        """Exchange name."""
        pass

    @abstractmethod
    async def connect(self) -> bool:
        """Connect to exchange.

        Returns:
            True if connection successful
        """
        pass

    @abstractmethod
    async def disconnect(self):
        """Disconnect from exchange."""
        pass

    @abstractmethod
    async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
        """Get account balance.

        Args:
            currency: Specific currency to get balance for, or None for all

        Returns:
            Dictionary of currency -> balance
        """
        pass

    @abstractmethod
    async def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker for symbol.

        Args:
            symbol: Trading pair symbol (e.g., 'BTC/USD')

        Returns:
            Ticker data with 'bid', 'ask', 'last', 'volume', etc.
        """
        pass

    @abstractmethod
    async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
        """Get order book for symbol.

        Args:
            symbol: Trading pair symbol
            limit: Number of orders per side

        Returns:
            Dictionary with 'bids' and 'asks' lists
        """
        pass

    @abstractmethod
    async def place_order(self, order: Order) -> Dict[str, Any]:
        """Place an order.

        Args:
            order: Order object

        Returns:
            Order response with order ID and status
        """
        pass

    @abstractmethod
    async def cancel_order(self, order_id: str, symbol: str) -> bool:
        """Cancel an order.

        Args:
            order_id: Order ID to cancel
            symbol: Trading pair symbol

        Returns:
            True if cancellation successful
        """
        pass

    @abstractmethod
    async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
        """Get order status.

        Args:
            order_id: Order ID
            symbol: Trading pair symbol

        Returns:
            Order status information
        """
        pass

    @abstractmethod
    async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open orders.

        Args:
            symbol: Optional symbol to filter by

        Returns:
            List of open orders
        """
        pass

    @abstractmethod
    async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open positions.

        Args:
            symbol: Optional symbol to filter by

        Returns:
            List of positions
        """
        pass

    @abstractmethod
    async def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV (candlestick) data.

        Args:
            symbol: Trading pair symbol
            timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
            since: Start datetime
            limit: Number of candles

        Returns:
            List of [timestamp, open, high, low, close, volume]
        """
        pass

    @abstractmethod
    async def subscribe_ticker(self, symbol: str, callback):
        """Subscribe to ticker updates via WebSocket.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(ticker_data)
        """
        pass

    @abstractmethod
    async def subscribe_orderbook(self, symbol: str, callback):
        """Subscribe to order book updates via WebSocket.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(orderbook_data)
        """
        pass

    @abstractmethod
    async def subscribe_trades(self, symbol: str, callback):
        """Subscribe to trade updates via WebSocket.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(trade_data)
        """
        pass

    def validate_order(self, order: Order) -> bool:
        """Validate order before placing.

        Args:
            order: Order to validate

        Returns:
            True if order is valid
        """
        if self.read_only:
            self.logger.warning("Exchange is in read-only mode")
            return False

        if not order.symbol:
            self.logger.error("Order symbol is required")
            return False

        if order.quantity <= 0:
            self.logger.error("Order quantity must be positive")
            return False

        if order.order_type == 'limit' and not order.price:
            self.logger.error("Limit orders require a price")
            return False

        return True

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format.

        Args:
            symbol: Symbol to normalize

        Returns:
            Normalized symbol
        """
        # Default implementation - override in subclasses if needed
        return symbol.upper().replace('-', '/')

    def get_fee_structure(self) -> Dict[str, float]:
        """Get fee structure (maker/taker fees).

        Returns:
            Dictionary with 'maker' and 'taker' fee percentages, and optional 'minimum'
        """
        # Default fees - override in subclasses
        return {
            'maker': 0.001,  # 0.1%
            'taker': 0.001,  # 0.1%
            'minimum': 0.0,  # Minimum fee amount
        }

    def extract_fee_from_order_response(self, order_response: Dict[str, Any]) -> Optional[Decimal]:
        """Extract actual fee from order response.

        Args:
            order_response: Order response from exchange API

        Returns:
            Fee amount as Decimal, or None if not available
        """
        # Default implementation - override in subclasses
        # Try common fee fields
        if 'fee' in order_response:
            try:
                return Decimal(str(order_response['fee']))
            except (ValueError, TypeError):
                pass

        if 'fees' in order_response:
            # Some exchanges return fees as a list
            fees = order_response['fees']
            if isinstance(fees, list) and len(fees) > 0:
                try:
                    # Sum all fees
                    total_fee = sum(Decimal(str(f.get('cost', 0) if isinstance(f, dict) else f)) for f in fees)
                    return total_fee
                except (ValueError, TypeError):
                    pass

        # Try 'cost' field (some exchanges use this)
        if 'cost' in order_response:
            try:
                return Decimal(str(order_response['cost']))
            except (ValueError, TypeError):
                pass

        return None
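Usage note: a quick sketch of the Order/validate_order contract. The SimpleNamespace stand-in is purely illustrative (it only carries the two attributes validate_order reads); real code would use a concrete adapter subclass.

from decimal import Decimal
from types import SimpleNamespace
import logging

from src.exchanges.base import BaseExchangeAdapter, Order

# Stand-in with just the attributes validate_order reads (not a real adapter)
stub = SimpleNamespace(read_only=False, logger=logging.getLogger('demo'))

good = Order(symbol='BTC/USD', side='buy', order_type='limit', quantity=Decimal('0.1'), price=Decimal('42000'))
bad = Order(symbol='BTC/USD', side='buy', order_type='limit', quantity=Decimal('0.1'))  # missing price

print(BaseExchangeAdapter.validate_order(stub, good))  # True
print(BaseExchangeAdapter.validate_order(stub, bad))   # False: limit orders require a price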
392
src/exchanges/coinbase.py
Normal file
@@ -0,0 +1,392 @@
"""Coinbase Advanced Trade API adapter with WebSocket support."""

import asyncio
import json
import websockets
import ccxt.async_support as ccxt_async  # async_support must be imported explicitly; `import ccxt` alone does not expose it
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base import BaseExchangeAdapter, Order
from src.core.logger import get_logger

logger = get_logger(__name__)


class CoinbaseAdapter(BaseExchangeAdapter):
    """Coinbase Advanced Trade API adapter."""

    @property
    def name(self) -> str:
        """Exchange name."""
        return "Coinbase"

    def __init__(self, api_key: str, api_secret: str, sandbox: bool = False, read_only: bool = True):
        """Initialize Coinbase adapter."""
        super().__init__(api_key, api_secret, sandbox, read_only)

        # Initialize ccxt exchange with async support
        exchange_class = ccxt_async.coinbase
        self.exchange = exchange_class({
            'apiKey': api_key,
            'secret': api_secret,
            'sandbox': sandbox,
            'enableRateLimit': True,
            'options': {
                'defaultType': 'spot',  # or 'future' for futures
            }
        })

        # WebSocket connections
        self._ws_connections = {}
        self._ws_callbacks = {}
        self._ws_loop = None

    async def connect(self) -> bool:
        """Connect to exchange."""
        try:
            # Verify credentials by fetching account info; works for both
            # read-only and trading keys
            await self.exchange.fetch_balance()

            self._connected = True
            logger.info(f"Connected to {self.name} (sandbox={self.sandbox}, read_only={self.read_only})")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to {self.name}: {e}")
            self._connected = False
            return False

    async def disconnect(self):
        """Disconnect from exchange."""
        # Close WebSocket connections
        for ws in self._ws_connections.values():
            try:
                # await ws.close()  # If using real websockets
                pass
            except Exception:
                pass

        await self.exchange.close()

        self._ws_connections.clear()
        self._ws_callbacks.clear()
        self._connected = False
        logger.info(f"Disconnected from {self.name}")

    async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
        """Get account balance."""
        try:
            balance = await self.exchange.fetch_balance()
            result = {}

            for curr, amounts in balance.items():
                if isinstance(amounts, dict) and 'free' in amounts:
                    free = Decimal(str(amounts['free']))
                    if free > 0 or currency == curr:
                        result[curr] = free

            if currency:
                return {currency: result.get(currency, Decimal(0))}

            return result
        except Exception as e:
            logger.error(f"Failed to get balance: {e}")
            return {}

    async def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker."""
        try:
            ticker = await self.exchange.fetch_ticker(self.normalize_symbol(symbol))
            return {
                'symbol': symbol,
                'bid': Decimal(str(ticker.get('bid', 0))),
                'ask': Decimal(str(ticker.get('ask', 0))),
                'last': Decimal(str(ticker.get('last', 0))),
                'high': Decimal(str(ticker.get('high', 0))),
                'low': Decimal(str(ticker.get('low', 0))),
                'volume': Decimal(str(ticker.get('quoteVolume', 0))),
                'timestamp': ticker.get('timestamp'),
            }
        except Exception as e:
            logger.error(f"Failed to get ticker for {symbol}: {e}")
            return {}

    async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
        """Get order book."""
        try:
            orderbook = await self.exchange.fetch_order_book(self.normalize_symbol(symbol), limit)
            return {
                'bids': [[Decimal(str(b[0])), Decimal(str(b[1]))] for b in orderbook['bids']],
                'asks': [[Decimal(str(a[0])), Decimal(str(a[1]))] for a in orderbook['asks']],
            }
        except Exception as e:
            logger.error(f"Failed to get orderbook for {symbol}: {e}")
            return {'bids': [], 'asks': []}

    async def place_order(self, order: Order) -> Dict[str, Any]:
        """Place an order."""
        if not self.validate_order(order):
            return {'error': 'Invalid order'}

        try:
            symbol = self.normalize_symbol(order.symbol)

            if order.order_type == 'market':
                result = await self.exchange.create_market_order(
                    symbol,
                    order.side,
                    float(order.quantity)
                )
            elif order.order_type == 'limit':
                result = await self.exchange.create_limit_order(
                    symbol,
                    order.side,
                    float(order.quantity),
                    float(order.price)
                )
            else:
                return {'error': f'Unsupported order type: {order.order_type}'}

            # Extract fee from result
            fee = self.extract_fee_from_order_response(result)

            return {
                'id': result.get('id'),
                'symbol': symbol,
                'status': result.get('status', 'open'),
                'side': result.get('side'),
                'type': result.get('type'),
                'amount': Decimal(str(result.get('amount', 0))),
                'price': Decimal(str(result.get('price', 0))) if result.get('price') else None,
                'fee': fee,  # Include fee in response
            }
        except Exception as e:
            logger.error(f"Failed to place order: {e}")
            return {'error': str(e)}

    async def cancel_order(self, order_id: str, symbol: str) -> bool:
        """Cancel an order."""
        try:
            await self.exchange.cancel_order(order_id, self.normalize_symbol(symbol))
            return True
        except Exception as e:
            logger.error(f"Failed to cancel order {order_id}: {e}")
            return False

    async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
        """Get order status."""
        try:
            order = await self.exchange.fetch_order(order_id, self.normalize_symbol(symbol))

            # Extract fee from order
            fee = self.extract_fee_from_order_response(order)

            return {
                'id': order.get('id'),
                'symbol': symbol,
                'status': order.get('status'),
                'side': order.get('side'),
                'type': order.get('type'),
                'amount': Decimal(str(order.get('amount', 0))),
                'filled': Decimal(str(order.get('filled', 0))),
                'remaining': Decimal(str(order.get('remaining', 0))),
                'price': Decimal(str(order.get('price', 0))) if order.get('price') else None,
                'average': Decimal(str(order.get('average', 0))) if order.get('average') else None,
                'fee': fee,  # Include fee in response
            }
        except Exception as e:
            logger.error(f"Failed to get order status {order_id}: {e}")
            return {}

    async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open orders."""
        try:
            if symbol:
                orders = await self.exchange.fetch_open_orders(self.normalize_symbol(symbol))
            else:
                orders = await self.exchange.fetch_open_orders()

            return [
                {
                    'id': o.get('id'),
                    'symbol': o.get('symbol'),
                    'status': o.get('status'),
                    'side': o.get('side'),
                    'type': o.get('type'),
                    'amount': Decimal(str(o.get('amount', 0))),
                    'filled': Decimal(str(o.get('filled', 0))),
                    'price': Decimal(str(o.get('price', 0))) if o.get('price') else None,
                }
                for o in orders
            ]
        except Exception as e:
            logger.error(f"Failed to get open orders: {e}")
            return []

    async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open positions."""
        try:
            # Coinbase Advanced Trade uses fetch_positions for futures
            # For spot, positions are derived from balances
            positions = await self.exchange.fetch_positions(symbols=[symbol] if symbol else None)

            return [
                {
                    'symbol': p.get('symbol'),
                    'side': p.get('side'),
                    'size': Decimal(str(p.get('size', 0))),
                    'entry_price': Decimal(str(p.get('entryPrice', 0))),
                    'mark_price': Decimal(str(p.get('markPrice', 0))),
                    'unrealized_pnl': Decimal(str(p.get('unrealizedPnl', 0))),
                }
                for p in positions if p.get('size', 0) != 0
            ]
        except Exception as e:
            logger.error(f"Failed to get positions: {e}")
            return []

    async def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data."""
        try:
            since_timestamp = int(since.timestamp() * 1000) if since else None
            ohlcv = await self.exchange.fetch_ohlcv(
                self.normalize_symbol(symbol),
                timeframe,
                since_timestamp,
                limit
            )

            return ohlcv
        except Exception as e:
            logger.error(f"Failed to get OHLCV for {symbol}: {e}")
            return []

    async def subscribe_ticker(self, symbol: str, callback: Callable):
        """Subscribe to ticker updates via WebSocket.

        Declared async to match the BaseExchangeAdapter interface.
        """
        try:
            # Store callback; the WebSocket loop dispatches to it
            self._ws_callbacks[f'ticker_{symbol}'] = callback

            # Start WebSocket connection if not already started
            if not getattr(self, '_ws_running', False):
                self._start_websocket_loop()

            logger.info(f"Subscribed to ticker updates for {symbol}")
        except Exception as e:
            logger.error(f"Failed to subscribe to ticker: {e}")
            self._ws_callbacks[f'ticker_{symbol}'] = callback

    async def subscribe_orderbook(self, symbol: str, callback: Callable):
        """Subscribe to order book updates via WebSocket."""
        try:
            self._ws_callbacks[f'orderbook_{symbol}'] = callback

            if not getattr(self, '_ws_running', False):
                self._start_websocket_loop()

            logger.info(f"Subscribed to orderbook updates for {symbol}")
        except Exception as e:
            logger.error(f"Failed to subscribe to orderbook: {e}")
            self._ws_callbacks[f'orderbook_{symbol}'] = callback

    async def subscribe_trades(self, symbol: str, callback: Callable):
        """Subscribe to trade updates via WebSocket."""
        try:
            self._ws_callbacks[f'trades_{symbol}'] = callback

            if not getattr(self, '_ws_running', False):
                self._start_websocket_loop()

            logger.info(f"Subscribed to trades updates for {symbol}")
        except Exception as e:
            logger.error(f"Failed to subscribe to trades: {e}")
            self._ws_callbacks[f'trades_{symbol}'] = callback

    def _start_websocket_loop(self):
        """Start WebSocket connection loop."""
        try:
            self._ws_running = True
            # Start WebSocket in background thread
            # Note: Full implementation would use an asyncio event loop
            logger.info("WebSocket connection started (basic implementation)")
        except Exception as e:
            logger.error(f"Failed to start WebSocket: {e}")
            self._ws_running = False

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol for Coinbase."""
        # Coinbase uses format like BTC-USD
        return symbol.replace('/', '-').upper()

    def get_fee_structure(self) -> Dict[str, float]:
        """Get Coinbase fee structure."""
        # Coinbase Advanced Trade fees (approximate)
        # Actual fees may vary based on trading volume and account type
        return {
            'maker': 0.004,  # 0.4%
            'taker': 0.006,  # 0.6%
            'minimum': 0.0,  # No minimum fee
        }

    def extract_fee_from_order_response(self, order_response: Dict[str, Any]) -> Optional[Decimal]:
        """Extract actual fee from Coinbase order response.

        Args:
            order_response: Order response from Coinbase API

        Returns:
            Fee amount as Decimal, or None if not available
        """
        # Coinbase/ccxt typically returns fees in 'fee' field or 'fees' list
        if 'fee' in order_response and order_response['fee']:
            try:
                fee_data = order_response['fee']
                if isinstance(fee_data, dict):
                    # Fee is often {'cost': amount, 'currency': 'USD'}
                    return Decimal(str(fee_data.get('cost', 0)))
                else:
                    return Decimal(str(fee_data))
            except (ValueError, TypeError):
                pass

        # Try 'fees' list
        if 'fees' in order_response:
            fees = order_response['fees']
            if isinstance(fees, list) and len(fees) > 0:
                try:
                    total_fee = Decimal(0)
                    for fee_item in fees:
                        if isinstance(fee_item, dict):
                            total_fee += Decimal(str(fee_item.get('cost', 0)))
                        else:
                            total_fee += Decimal(str(fee_item))
                    return total_fee if total_fee > 0 else None
                except (ValueError, TypeError):
                    pass

        return None
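Usage note: a hedged sketch of the fee-extraction helper on a ccxt-style unified order payload; the payload values are illustrative and exact fields vary by exchange. (A real adapter should also be disconnected when done.)

from decimal import Decimal

from src.exchanges.coinbase import CoinbaseAdapter

adapter = CoinbaseAdapter(api_key='demo', api_secret='demo')  # placeholder credentials; no network call yet
order_response = {
    'id': 'abc-123',
    'fee': {'cost': 0.35, 'currency': 'USD'},  # ccxt-style fee dict
}
fee = adapter.extract_fee_from_order_response(order_response)
assert fee == Decimal('0.35')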
165
src/exchanges/factory.py
Normal file
@@ -0,0 +1,165 @@
"""Exchange factory for creating exchange adapters."""

from typing import Optional, Dict, Type, List
from src.core.database import get_database, Exchange
from src.core.logger import get_logger
from src.security.key_manager import get_key_manager
from .base import BaseExchangeAdapter

logger = get_logger(__name__)


class ExchangeFactory:
    """Factory for creating exchange adapter instances."""

    _adapters: Dict[str, Type[BaseExchangeAdapter]] = {}

    @classmethod
    def register(cls, name: str, adapter_class: Type[BaseExchangeAdapter]):
        """Register an exchange adapter.

        Args:
            name: Exchange name
            adapter_class: Adapter class
        """
        cls._adapters[name.lower()] = adapter_class
        logger.info(f"Registered exchange adapter: {name}")

    @classmethod
    async def create(cls, exchange_id: int) -> Optional[BaseExchangeAdapter]:
        """Create exchange adapter from database.

        Args:
            exchange_id: Exchange ID from database

        Returns:
            Exchange adapter instance or None
        """
        from sqlalchemy import select
        db = get_database()
        key_manager = get_key_manager()

        async with db.get_session() as session:
            try:
                stmt = select(Exchange).where(Exchange.id == exchange_id)
                result = await session.execute(stmt)
                exchange = result.scalar_one_or_none()

                if not exchange:
                    logger.error(f"Exchange {exchange_id} not found")
                    return None

                if not exchange.enabled:
                    logger.warning(f"Exchange {exchange.name} is disabled")
                    return None

                # Get adapter class
                adapter_class = cls._adapters.get(exchange.name.lower())
                if not adapter_class:
                    logger.error(f"No adapter registered for exchange: {exchange.name}")
                    return None

                # Check if this is a public data adapter (doesn't need credentials)
                from .public_data import PublicDataAdapter
                is_public_data = adapter_class == PublicDataAdapter

                if is_public_data:
                    # Public data adapter doesn't need credentials
                    adapter = adapter_class(
                        api_key="",
                        api_secret="",
                        sandbox=False,
                        read_only=True
                    )
                else:
                    # Get credentials for regular exchanges
                    credentials = await key_manager.get_exchange_credentials(exchange_id)
                    if not credentials:
                        logger.error(f"No credentials found for exchange: {exchange.name}")
                        return None

                    # Create adapter instance
                    adapter = adapter_class(
                        api_key=credentials['api_key'],
                        api_secret=credentials['api_secret'],
                        sandbox=credentials['sandbox'],
                        read_only=credentials['read_only']
                    )

                # Connect
                if await adapter.connect():
                    logger.info(f"Connected to {exchange.name}")
                    return adapter
                else:
                    logger.error(f"Failed to connect to {exchange.name}")
                    return None

            except Exception as e:
                logger.error(f"Failed to create exchange adapter: {e}")
                return None

    @classmethod
    async def create_by_name(
        cls,
        name: str,
        api_key: str,
        api_secret: str,
        sandbox: bool = False,
        read_only: bool = True
    ) -> Optional[BaseExchangeAdapter]:
        """Create exchange adapter directly.

        Args:
            name: Exchange name
            api_key: API key
            api_secret: API secret
            sandbox: Use sandbox/testnet
            read_only: Read-only mode

        Returns:
            Exchange adapter instance or None
        """
        adapter_class = cls._adapters.get(name.lower())
        if not adapter_class:
            logger.error(f"No adapter registered for exchange: {name}")
            return None

        try:
            adapter = adapter_class(
                api_key=api_key,
                api_secret=api_secret,
                sandbox=sandbox,
                read_only=read_only
            )

            if await adapter.connect():
                return adapter
            else:
                logger.error(f"Failed to connect to {name}")
                return None
        except Exception as e:
            logger.error(f"Failed to create exchange adapter: {e}")
            return None

    @classmethod
    def list_available(cls) -> List[str]:
        """List available exchange adapters.

        Returns:
            List of exchange names
        """
        return list(cls._adapters.keys())


# Convenience function
async def get_exchange(exchange_id: int) -> Optional[BaseExchangeAdapter]:
    """Get exchange adapter by ID.

    Args:
        exchange_id: Exchange ID

    Returns:
        Exchange adapter instance or None
    """
    return await ExchangeFactory.create(exchange_id)
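Usage note: a minimal sketch of resolving an adapter from a stored Exchange row; the ID is hypothetical and assumes credentials were previously saved via the key manager.

import asyncio

from src.exchanges import get_exchange

async def main():
    adapter = await get_exchange(1)  # hypothetical database ID
    if adapter is None:
        print("exchange unavailable, disabled, or misconfigured")
        return
    print(await adapter.get_ticker('BTC/USD'))
    await adapter.disconnect()

asyncio.run(main())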
433
src/exchanges/public_data.py
Normal file
@@ -0,0 +1,433 @@
"""Public Data Exchange Adapter - Free market data without API keys.

Uses CCXT in public mode to fetch market data from public exchange endpoints
without requiring API keys. Useful for testing, backtesting, and paper trading.
"""

import ccxt.async_support as ccxt
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base import BaseExchangeAdapter, Order
from src.core.logger import get_logger

logger = get_logger(__name__)


class PublicDataAdapter(BaseExchangeAdapter):
    """Public market data adapter using CCXT in public mode (no API keys needed).

    This adapter uses public exchange APIs (Kraken, Coinbase, or Binance,
    tried in that order) to fetch:
    - Historical OHLCV data
    - Real-time ticker data
    - Order book data
    - Trade data

    No API keys required - suitable for testing and paper trading.
    """

    @property
    def name(self) -> str:
        """Exchange name."""
        return self._selected_exchange_name or "Public Data"

    def __init__(
        self,
        api_key: str = "",
        api_secret: str = "",
        sandbox: bool = False,
        read_only: bool = True
    ):
        """Initialize Public Data adapter.

        Args:
            api_key: Ignored (public data doesn't need API keys)
            api_secret: Ignored (public data doesn't need API keys)
            sandbox: Ignored (public APIs always serve live data)
            read_only: Always True (this adapter is read-only by design)
        """
        super().__init__("", "", False, True)  # Always read-only, no keys needed

        # List of exchanges to try as fallbacks (in order of preference)
        # These exchanges have good public API access without geographic restrictions
        self._exchange_options = [
            ('kraken', 'Kraken Public'),
            ('coinbase', 'Coinbase Public'),
            ('binance', 'Binance Public'),  # Try Binance last since it has geo restrictions
        ]

        self.exchange = None
        self._selected_exchange_name = None

        # Callbacks for polled "real-time" data
        self._ws_callbacks = {}
        self._polling_timer = None

    async def connect(self) -> bool:
        """Connect to exchange (test public API access).

        Tries multiple exchanges as fallbacks if one is blocked.

        Returns:
            True if any public exchange is reachable
        """
        # Try each exchange until one works
        for exchange_id, exchange_display_name in self._exchange_options:
            try:
                # Create exchange instance
                exchange_class = getattr(ccxt, exchange_id, None)
                if not exchange_class:
                    continue

                self.exchange = exchange_class({
                    'enableRateLimit': True,
                    'options': {
                        'defaultType': 'spot',
                    }
                })

                # Load markets if not already loaded
                if not getattr(self.exchange, 'markets', None):
                    await self.exchange.load_markets()

                # Test connection with a common symbol
                # Different exchanges use different symbol formats
                test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
                ticker_result = None

                for test_symbol in test_symbols:
                    try:
                        ticker_result = await self.exchange.fetch_ticker(test_symbol)
                        break
                    except Exception:
                        continue

                if not ticker_result:
                    raise Exception("Could not fetch ticker from any test symbol")

                # Success! Use this exchange
                self._selected_exchange_name = exchange_display_name
                self._connected = True
                logger.info(f"Connected to {self.name} (public data, no API keys needed)")

                return True

            except Exception as e:
                logger.warning(f"Failed to connect to {exchange_id}: {e}, trying next exchange...")
                continue

        # All exchanges failed
        logger.error("Failed to connect to any public exchange")
        self._connected = False
        return False

    async def disconnect(self):
        """Disconnect from exchange."""
        self._ws_callbacks.clear()
        if self._polling_timer:
            self._polling_timer.cancel()
            self._polling_timer = None

        if self.exchange:
            await self.exchange.close()

        self._connected = False
        logger.info(f"Disconnected from {self.name}")

    async def get_balance(self, currency: Optional[str] = None) -> Dict[str, Decimal]:
        """Get account balance.

        Note: Public data adapter cannot access account balances.
        Returns empty dict. Use paper trading balance instead.
        """
        logger.warning("Public data adapter cannot access account balances")
        return {}

    async def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker (public endpoint)."""
        try:
            ticker = await self.exchange.fetch_ticker(self.normalize_symbol(symbol))
            return {
                'symbol': symbol,
                'bid': Decimal(str(ticker.get('bid', 0))),
                'ask': Decimal(str(ticker.get('ask', 0))),
                'last': Decimal(str(ticker.get('last', 0))),
                'high': Decimal(str(ticker.get('high', 0))),
                'low': Decimal(str(ticker.get('low', 0))),
                'volume': Decimal(str(ticker.get('quoteVolume', 0))),
                'timestamp': ticker.get('timestamp'),
            }
        except Exception as e:
            logger.error(f"Failed to get ticker for {symbol}: {e}")
            return {}

    async def get_orderbook(self, symbol: str, limit: int = 20) -> Dict[str, List]:
        """Get order book (public endpoint)."""
        try:
            orderbook = await self.exchange.fetch_order_book(self.normalize_symbol(symbol), limit)
            return {
                'bids': [[Decimal(str(b[0])), Decimal(str(b[1]))] for b in orderbook['bids']],
                'asks': [[Decimal(str(a[0])), Decimal(str(a[1]))] for a in orderbook['asks']],
            }
        except Exception as e:
            logger.error(f"Failed to get orderbook for {symbol}: {e}")
            return {'bids': [], 'asks': []}

    async def place_order(self, order: Order) -> Dict[str, Any]:
        """Place an order.

        Note: Public data adapter cannot place orders.
        Returns error message. Use paper trading instead.
        """
        logger.warning("Public data adapter cannot place orders - use paper trading")
        return {'error': 'Public data adapter is read-only. Use paper trading for order execution.'}

    async def cancel_order(self, order_id: str, symbol: str) -> bool:
        """Cancel an order.

        Note: Public data adapter cannot cancel orders.
        """
        logger.warning("Public data adapter cannot cancel orders")
        return False

    async def get_order_status(self, order_id: str, symbol: str) -> Dict[str, Any]:
        """Get order status.

        Note: Public data adapter cannot access order status.
        """
        logger.warning("Public data adapter cannot access order status")
        return {}

    async def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open orders.

        Note: Public data adapter cannot access orders.
        """
        logger.warning("Public data adapter cannot access orders")
        return []

    async def get_positions(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get open positions.

        Note: Public data adapter cannot access positions.
        """
        logger.warning("Public data adapter cannot access positions")
        return []

    async def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data (public endpoint).

        This is the main method for fetching historical data.
        Most exchanges support up to 1000 candles per request.
        """
        try:
            since_timestamp = int(since.timestamp() * 1000) if since else None

            # Most exchanges cap requests at 1000 candles; larger limits
            # would require paginated requests
            max_limit = min(limit, 1000)

            normalized_symbol = self.normalize_symbol(symbol)

            ohlcv = await self.exchange.fetch_ohlcv(
                normalized_symbol,
                timeframe,
                since_timestamp,
                max_limit
            )

            logger.info(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
            return ohlcv
        except Exception as e:
            logger.error(f"Failed to get OHLCV for {symbol}: {e}")
            return []

    async def subscribe_ticker(self, symbol: str, callback: Callable):
        """Subscribe to ticker updates (polling-based for public data).

        Since we don't have WebSocket auth, we poll the ticker endpoint.
        Declared async to match the BaseExchangeAdapter interface.
        """
        try:
            import asyncio
            import threading
            import time

            self._ws_callbacks[f'ticker_{symbol}'] = callback

            def poll_ticker():
                """Poll ticker every 2 seconds."""
                # get_ticker() is a coroutine, so this thread drives it with its
                # own event loop. Sharing one async ccxt instance across loops is
                # fragile; a production version would poll via asyncio tasks on
                # the main loop instead of threads.
                loop = asyncio.new_event_loop()
                while f'ticker_{symbol}' in self._ws_callbacks:
                    try:
                        ticker = loop.run_until_complete(self.get_ticker(symbol))
                        if ticker and callback:
                            callback({
                                'price': ticker.get('last', 0),
                                'bid': ticker.get('bid', 0),
                                'ask': ticker.get('ask', 0),
                                'volume': ticker.get('volume', 0),
                                'timestamp': ticker.get('timestamp'),
                            })
                        time.sleep(2)  # Poll every 2 seconds
                    except Exception as e:
                        logger.error(f"Ticker polling error: {e}")
                        time.sleep(5)  # Wait longer on error
                loop.close()

            # Start polling thread
            thread = threading.Thread(target=poll_ticker, daemon=True)
            thread.start()

            logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
        except Exception as e:
            logger.error(f"Failed to subscribe to ticker: {e}")
            self._ws_callbacks[f'ticker_{symbol}'] = callback

    async def subscribe_orderbook(self, symbol: str, callback: Callable):
        """Subscribe to order book updates (polling-based)."""
        try:
            import asyncio
            import threading
            import time

            self._ws_callbacks[f'orderbook_{symbol}'] = callback

            def poll_orderbook():
                """Poll order book every 1 second."""
                loop = asyncio.new_event_loop()
                while f'orderbook_{symbol}' in self._ws_callbacks:
                    try:
                        # get_orderbook() is a coroutine; see note in poll_ticker
                        orderbook = loop.run_until_complete(self.get_orderbook(symbol, limit=20))
                        if orderbook and callback:
                            callback(orderbook)
                        time.sleep(1)  # Poll every second
                    except Exception as e:
                        logger.error(f"Orderbook polling error: {e}")
                        time.sleep(5)
                loop.close()

            thread = threading.Thread(target=poll_orderbook, daemon=True)
            thread.start()

            logger.info(f"Subscribed to orderbook updates for {symbol} (polling mode)")
        except Exception as e:
            logger.error(f"Failed to subscribe to orderbook: {e}")
            self._ws_callbacks[f'orderbook_{symbol}'] = callback

    async def subscribe_trades(self, symbol: str, callback: Callable):
        """Subscribe to trade updates (polling-based)."""
        try:
            import asyncio
            import threading
            import time

            normalized_symbol = self.normalize_symbol(symbol)
            self._ws_callbacks[f'trades_{symbol}'] = callback

            def poll_trades():
                """Poll recent trades every 1 second."""
                loop = asyncio.new_event_loop()
                while f'trades_{symbol}' in self._ws_callbacks:
                    try:
                        # fetch_trades() is a coroutine on the async ccxt client
                        trades = loop.run_until_complete(
                            self.exchange.fetch_trades(normalized_symbol, limit=10)
                        )
                        if trades and callback:
                            for trade in trades[-5:]:  # Last 5 trades
                                callback({
                                    'price': trade.get('price', 0),
                                    'amount': trade.get('amount', 0),
                                    'side': trade.get('side', 'buy'),
                                    'timestamp': trade.get('timestamp'),
                                })
                        time.sleep(1)
                    except Exception as e:
                        logger.error(f"Trades polling error: {e}")
                        time.sleep(5)
                loop.close()

            thread = threading.Thread(target=poll_trades, daemon=True)
            thread.start()

            logger.info(f"Subscribed to trades updates for {symbol} (polling mode)")
        except Exception as e:
            logger.error(f"Failed to subscribe to trades: {e}")
            self._ws_callbacks[f'trades_{symbol}'] = callback

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol for the selected exchange.

        Different exchanges use different symbol formats:
        - Binance/Kraken: BTC/USDT, BTC/USD
        - Some exchanges: BTC-USDT, BTC-USD
        - Kraken sometimes uses XBT/USD instead of BTC/USD
        """
        if not self.exchange:
            return symbol.replace('-', '/').upper()

        # Basic normalization: convert dashes to slashes, uppercase
        normalized = symbol.replace('-', '/').upper()

        # Use the exchange's loaded markets to pick a valid symbol
        try:
            if getattr(self.exchange, 'markets', None):
                # Check if normalized symbol exists in markets
                if normalized in self.exchange.markets:
                    return normalized

                # Try alternative formats
                # For USD pairs, try USDT (common on many exchanges)
                if '/USD' in normalized:
                    alt_symbol = normalized.replace('/USD', '/USDT')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol

                # For Kraken: try XBT instead of BTC
                if normalized.startswith('BTC/'):
                    alt_symbol = normalized.replace('BTC/', 'XBT/')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol

                # Try to find a similar symbol (fuzzy match)
                base = normalized.split('/')[0] if '/' in normalized else normalized
                quote = normalized.split('/')[1] if '/' in normalized else 'USD'

                # Search for matching symbols
                for market_symbol in self.exchange.markets.keys():
                    if market_symbol.startswith(base + '/') or market_symbol.endswith('/' + quote):
                        return market_symbol
        except Exception:
            pass

        # Fallback: return normalized symbol (let CCXT handle errors)
        return normalized

    def get_fee_structure(self) -> Dict[str, float]:
        """Get fee structure for the selected exchange."""
        # Default fees (approximate spot trading fees)
        if not self.exchange or not hasattr(self.exchange, 'id'):
            return {'maker': 0.001, 'taker': 0.001}

        exchange_id = self.exchange.id

        # Exchange-specific fees (approximate)
        fee_map = {
            'binance': {'maker': 0.001, 'taker': 0.001},   # 0.1%
            'kraken': {'maker': 0.0016, 'taker': 0.0026},  # 0.16% / 0.26%
            'coinbase': {'maker': 0.004, 'taker': 0.006},  # 0.4% / 0.6%
        }

        return fee_map.get(exchange_id, {'maker': 0.001, 'taker': 0.001})

    def get_available_symbols(self) -> List[str]:
        """Get list of available trading symbols.

        Returns:
            List of available symbol pairs (e.g., ['BTC/USDT', 'ETH/USDT', ...])
        """
        try:
            # Markets are loaded during connect(); load_markets() on the async
            # client is a coroutine, so rely on the cached mapping here
            if self.exchange and getattr(self.exchange, 'markets', None):
                return list(self.exchange.markets.keys())
            logger.warning("Markets not loaded yet - call connect() first")
            return []
        except Exception as e:
            logger.error(f"Failed to get available symbols: {e}")
            return []
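Usage note: a minimal sketch of fetching a day of hourly candles with no API keys; note that normalize_symbol may remap the pair (e.g., to BTC/USDT or XBT/USD) depending on which public exchange answered.

import asyncio

from src.exchanges.public_data import PublicDataAdapter

async def main():
    adapter = PublicDataAdapter()
    if await adapter.connect():
        candles = await adapter.get_ohlcv('BTC/USD', timeframe='1h', limit=24)
        for ts, o, h, l, c, v in candles[-3:]:
            print(ts, c)  # timestamp and close of the last three candles
        await adapter.disconnect()

asyncio.run(main())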
0
src/optimization/__init__.py
Normal file
76
src/optimization/bayesian.py
Normal file
@@ -0,0 +1,76 @@
"""Bayesian optimization."""

from typing import Dict, Any, Callable
from src.core.logger import get_logger

logger = get_logger(__name__)


class BayesianOptimizer:
    """Bayesian optimization using scikit-optimize."""

    def __init__(self):
        """Initialize Bayesian optimizer."""
        self.logger = get_logger(__name__)

    def optimize(
        self,
        param_space: Dict[str, Any],
        objective_function: Callable[[Dict[str, Any]], float],
        n_calls: int = 50,
        maximize: bool = True
    ) -> Dict[str, Any]:
        """Run Bayesian optimization.

        Args:
            param_space: Parameter space definition, e.g. {'window': (5, 50)}
            objective_function: Objective function
            n_calls: Number of optimization iterations
            maximize: True to maximize, False to minimize

        Returns:
            Best parameters and score
        """
        try:
            from skopt import gp_minimize
            from skopt.space import Real, Integer

            # Convert param_space to skopt format
            dimensions = []
            param_names = []
            for name, space in param_space.items():
                param_names.append(name)
                if isinstance(space, tuple):
                    if isinstance(space[0], int):
                        dimensions.append(Integer(space[0], space[1], name=name))
                    else:
                        dimensions.append(Real(space[0], space[1], name=name))

            # gp_minimize always minimizes, so negate the score when maximizing
            def objective(params):
                param_dict = dict(zip(param_names, params))
                score = objective_function(param_dict)
                return -score if maximize else score

            result = gp_minimize(
                objective,
                dimensions,
                n_calls=n_calls,
                random_state=42
            )

            best_params = dict(zip(param_names, result.x))
            best_score = -result.fun if maximize else result.fun

            return {
                "best_params": best_params,
                "best_score": best_score,
            }
        except ImportError:
            logger.warning("scikit-optimize not available, using grid search fallback")
            from .grid_search import GridSearchOptimizer
            optimizer = GridSearchOptimizer()
            # Convert each (min, max) range to a coarse three-point grid,
            # keeping integer parameters integral
            param_grid = {
                k: [v[0], v[1], (v[0] + v[1]) // 2 if isinstance(v[0], int) else (v[0] + v[1]) / 2]
                for k, v in param_space.items()
            }
            return optimizer.optimize(param_grid, objective_function, maximize)
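Usage note: a toy run of the optimizer against a quadratic objective (requires scikit-optimize; falls back to grid search otherwise). The parameter names are illustrative.

from src.optimization.bayesian import BayesianOptimizer

def objective(params):
    # Peak at x = 3.0, window = 20
    return -((params['x'] - 3.0) ** 2) - ((params['window'] - 20) ** 2) / 100.0

optimizer = BayesianOptimizer()
result = optimizer.optimize(
    param_space={'x': (0.0, 10.0), 'window': (5, 50)},
    objective_function=objective,
    n_calls=30,
)
print(result['best_params'], result['best_score'])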
111
src/optimization/genetic.py
Normal file
@@ -0,0 +1,111 @@
"""Genetic algorithm optimization."""

from typing import Dict, List, Any, Callable
import random
from src.core.logger import get_logger

logger = get_logger(__name__)


class GeneticOptimizer:
    """Genetic algorithm parameter optimization."""

    def __init__(self):
        """Initialize genetic optimizer."""
        self.logger = get_logger(__name__)

    def optimize(
        self,
        param_ranges: Dict[str, tuple],
        objective_function: Callable[[Dict[str, Any]], float],
        population_size: int = 50,
        generations: int = 100,
        maximize: bool = True
    ) -> Dict[str, Any]:
        """Run genetic algorithm optimization.

        Args:
            param_ranges: Dictionary of parameter -> (min, max) range
            objective_function: Objective function
            population_size: Population size
            generations: Number of generations
            maximize: True to maximize, False to minimize

        Returns:
            Best parameters and score
        """
        # Simplified genetic algorithm implementation
        best_score = float('-inf') if maximize else float('inf')
        best_params = None

        # Initialize population
        population = self._initialize_population(param_ranges, population_size)

        for generation in range(generations):
            # Evaluate fitness
            fitness_scores = []
            for individual in population:
                try:
                    score = objective_function(individual)
                    fitness_scores.append((individual, score))
                except Exception:
                    continue

            # If every evaluation failed, evolution cannot continue
            if not fitness_scores:
                logger.warning(f"No individual could be evaluated in generation {generation}")
                break

            # Sort by fitness
            fitness_scores.sort(key=lambda x: x[1], reverse=maximize)

            # Update best
            best_individual, best_gen_score = fitness_scores[0]
            if (maximize and best_gen_score > best_score) or (not maximize and best_gen_score < best_score):
                best_score = best_gen_score
                best_params = best_individual.copy()

            # Create next generation (simplified)
            population = self._evolve_population(fitness_scores, param_ranges, population_size)

        return {
            "best_params": best_params,
            "best_score": best_score,
        }

    def _initialize_population(self, param_ranges: Dict, size: int) -> List[Dict]:
        """Initialize random population."""
        population = []
        for _ in range(size):
            individual = {}
            for param, (min_val, max_val) in param_ranges.items():
                if isinstance(min_val, int):
                    individual[param] = random.randint(min_val, max_val)
                else:
                    individual[param] = random.uniform(min_val, max_val)
            population.append(individual)
        return population

    def _evolve_population(self, fitness_scores: List, param_ranges: Dict, size: int) -> List[Dict]:
        """Evolve population (elitism plus mutation)."""
        # Simplified evolution
        new_population = []
        elite_size = size // 10

        # Keep elite
        for i in range(min(elite_size, len(fitness_scores))):
            new_population.append(fitness_scores[i][0].copy())

        # Generate the rest by mutating parents drawn from the fitter half
        parent_pool = fitness_scores[:max(1, size // 2)]
        while len(new_population) < size:
            parent = random.choice(parent_pool)[0]
            child = parent.copy()

            # Mutate one randomly chosen parameter
            param = random.choice(list(param_ranges.keys()))
            min_val, max_val = param_ranges[param]
            if isinstance(min_val, int):
                child[param] = random.randint(min_val, max_val)
            else:
                child[param] = random.uniform(min_val, max_val)

            new_population.append(child)

        return new_population[:size]
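Usage note: a toy run of the genetic optimizer; the 'fast'/'slow' names are illustrative (think moving-average windows) and use only the API defined above.

from src.optimization.genetic import GeneticOptimizer

def objective(params):
    # Maximum at fast=10, slow=30
    return -abs(params['fast'] - 10) - abs(params['slow'] - 30)

optimizer = GeneticOptimizer()
result = optimizer.optimize(
    param_ranges={'fast': (2, 20), 'slow': (21, 60)},
    objective_function=objective,
    population_size=30,
    generations=20,
)
print(result['best_params'], result['best_score'])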
57
src/optimization/grid_search.py
Normal file
@@ -0,0 +1,57 @@
"""Grid search optimization."""

from typing import Dict, List, Any, Callable
from itertools import product
from src.core.logger import get_logger

logger = get_logger(__name__)


class GridSearchOptimizer:
    """Grid search parameter optimization."""

    def __init__(self):
        """Initialize grid search optimizer."""
        self.logger = get_logger(__name__)

    def optimize(
        self,
        param_grid: Dict[str, List[Any]],
        objective_function: Callable[[Dict[str, Any]], float],
        maximize: bool = True
    ) -> Dict[str, Any]:
        """Run grid search optimization.

        Args:
            param_grid: Dictionary of parameter -> list of values
            objective_function: Function that takes parameters and returns score
            maximize: True to maximize, False to minimize

        Returns:
            Best parameters and score
        """
        best_score = float('-inf') if maximize else float('inf')
        best_params = None

        # Generate all parameter combinations
        param_names = list(param_grid.keys())
        param_values = list(param_grid.values())

        for combination in product(*param_values):
            params = dict(zip(param_names, combination))

            try:
                score = objective_function(params)

                if (maximize and score > best_score) or (not maximize and score < best_score):
                    best_score = score
                    best_params = params
            except Exception as e:
                logger.warning(f"Failed to evaluate parameters {params}: {e}")
                continue

        return {
            "best_params": best_params,
            "best_score": best_score,
        }
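Usage note: an exhaustive search over a small illustrative grid; with 3x3 values the objective is called nine times.

from src.optimization.grid_search import GridSearchOptimizer

optimizer = GridSearchOptimizer()
result = optimizer.optimize(
    param_grid={'threshold': [0.01, 0.02, 0.05], 'window': [10, 20, 50]},
    objective_function=lambda p: -abs(p['threshold'] - 0.02) - abs(p['window'] - 20),
)
print(result)  # best_params should be threshold=0.02, window=20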
0
src/portfolio/__init__.py
Normal file
265
src/portfolio/analytics.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Advanced portfolio analytics (Sharpe ratio, Sortino, drawdown analysis, performance charts)."""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import select
|
||||
from src.core.database import get_database, PortfolioSnapshot, Trade
|
||||
from src.core.logger import get_logger
|
||||
from .tracker import get_portfolio_tracker
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PortfolioAnalytics:
|
||||
"""Advanced portfolio analytics."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize portfolio analytics."""
|
||||
self.db = get_database()
|
||||
self.tracker = get_portfolio_tracker()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def calculate_sharpe_ratio(
|
||||
self,
|
||||
returns: pd.Series,
|
||||
risk_free_rate: float = 0.0,
|
||||
periods_per_year: int = 252
|
||||
) -> float:
|
||||
"""Calculate Sharpe ratio.
|
||||
|
||||
Args:
|
||||
returns: Series of returns
|
||||
risk_free_rate: Risk-free rate (annual)
|
||||
periods_per_year: Trading periods per year
|
||||
|
||||
Returns:
|
||||
Sharpe ratio
|
||||
"""
|
||||
if len(returns) == 0 or returns.std() == 0:
|
||||
return 0.0
|
||||
|
||||
excess_returns = returns - (risk_free_rate / periods_per_year)
|
||||
return float(np.sqrt(periods_per_year) * excess_returns.mean() / returns.std())
|
||||
|
||||
def calculate_sortino_ratio(
|
||||
self,
|
||||
returns: pd.Series,
|
||||
risk_free_rate: float = 0.0,
|
||||
periods_per_year: int = 252
|
||||
) -> float:
|
||||
"""Calculate Sortino ratio.
|
||||
|
||||
Args:
|
||||
returns: Series of returns
|
||||
risk_free_rate: Risk-free rate (annual)
|
||||
periods_per_year: Trading periods per year
|
||||
|
||||
Returns:
|
||||
Sortino ratio
|
||||
"""
|
||||
if len(returns) == 0:
|
||||
return 0.0
|
||||
|
||||
excess_returns = returns - (risk_free_rate / periods_per_year)
|
||||
downside_returns = returns[returns < 0]
|
||||
|
||||
if len(downside_returns) == 0 or downside_returns.std() == 0:
|
||||
return 0.0
|
||||
|
||||
downside_std = downside_returns.std()
|
||||
return float(np.sqrt(periods_per_year) * excess_returns.mean() / downside_std)
|
||||
|
||||
def calculate_drawdown(self, values: pd.Series) -> Dict[str, Any]:
|
||||
"""Calculate drawdown metrics.
|
||||
|
||||
Args:
|
||||
values: Series of portfolio values
|
||||
|
||||
Returns:
|
||||
Dictionary with drawdown metrics
|
||||
"""
|
||||
if len(values) == 0:
|
||||
return {
|
||||
"max_drawdown": 0.0,
|
||||
"current_drawdown": 0.0,
|
||||
"drawdown_duration": 0,
|
||||
}
|
||||
|
||||
# Calculate running maximum
|
||||
running_max = values.expanding().max()
|
||||
drawdown = (values - running_max) / running_max
|
||||
|
||||
max_drawdown = float(drawdown.min())
|
||||
current_drawdown = float(drawdown.iloc[-1])
|
||||
|
||||
# Calculate drawdown duration
|
||||
in_drawdown = drawdown < 0
|
||||
drawdown_duration = 0
|
||||
if in_drawdown.iloc[-1]:
|
||||
# Count consecutive days in drawdown
|
||||
for i in range(len(in_drawdown) - 1, -1, -1):
|
||||
if in_drawdown.iloc[i]:
|
||||
drawdown_duration += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return {
|
||||
"max_drawdown": max_drawdown,
|
||||
"current_drawdown": current_drawdown,
|
||||
"drawdown_duration": drawdown_duration,
|
||||
"drawdown_series": drawdown.tolist(),
|
||||
}
|
||||
|
||||
async def get_performance_metrics(
|
||||
self,
|
||||
days: int = 30,
|
||||
paper_trading: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive performance metrics.
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
Dictionary of performance metrics
|
||||
"""
|
||||
history = await self.tracker.get_portfolio_history(days, paper_trading)
|
||||
|
||||
if len(history) < 2:
|
||||
return {
|
||||
"total_return": 0.0,
|
||||
"sharpe_ratio": 0.0,
|
||||
"sortino_ratio": 0.0,
|
||||
"max_drawdown": 0.0,
|
||||
"win_rate": 0.0,
|
||||
}
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(history)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
df = df.set_index('timestamp').sort_index()
|
||||
|
||||
# Calculate returns
|
||||
returns = df['total_value'].pct_change().dropna()
|
||||
|
||||
# Calculate metrics
|
||||
initial_value = df['total_value'].iloc[0]
|
||||
final_value = df['total_value'].iloc[-1]
|
||||
total_return = (final_value - initial_value) / initial_value
|
||||
|
||||
sharpe = self.calculate_sharpe_ratio(returns)
|
||||
sortino = self.calculate_sortino_ratio(returns)
|
||||
drawdown = self.calculate_drawdown(df['total_value'])
|
||||
|
||||
# Calculate win rate from trades
|
||||
win_rate = await self._calculate_win_rate(days, paper_trading)
|
||||
|
||||
# Calculate fee metrics
|
||||
fee_metrics = await self._calculate_fee_metrics(days, paper_trading, initial_value)
|
||||
|
||||
return {
|
||||
"total_return": float(total_return),
|
||||
"total_return_percent": float(total_return * 100),
|
||||
"sharpe_ratio": sharpe,
|
||||
"sortino_ratio": sortino,
|
||||
"max_drawdown": drawdown["max_drawdown"],
|
||||
"current_drawdown": drawdown["current_drawdown"],
|
||||
"win_rate": win_rate,
|
||||
"initial_value": float(initial_value),
|
||||
"final_value": float(final_value),
|
||||
**fee_metrics, # Include fee metrics
|
||||
}
|
||||
|
||||
async def _calculate_fee_metrics(
|
||||
self,
|
||||
days: int,
|
||||
paper_trading: bool,
|
||||
initial_value: float
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate fee-related metrics.
|
||||
|
||||
Args:
|
||||
days: Number of days
|
||||
paper_trading: Paper trading flag
|
||||
initial_value: Initial portfolio value
|
||||
|
||||
Returns:
|
||||
Dictionary of fee metrics
|
||||
"""
|
||||
try:
|
||||
async with self.db.get_session() as session:
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
stmt = select(Trade).where(
|
||||
Trade.paper_trading == paper_trading,
|
||||
Trade.timestamp >= since
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
trades = result.scalars().all()
|
||||
|
||||
total_fees = sum(float(trade.fee or 0) for trade in trades)
|
||||
total_trades = len(trades)
|
||||
avg_fee_per_trade = total_fees / total_trades if total_trades > 0 else 0.0
|
||||
|
||||
# Calculate fee percentage of initial value
|
||||
fee_percentage = (total_fees / initial_value * 100) if initial_value > 0 else 0.0
|
||||
|
||||
return {
|
||||
"total_fees": total_fees,
|
||||
"avg_fee_per_trade": avg_fee_per_trade,
|
||||
"fee_percentage": fee_percentage,
|
||||
"total_trades_with_fees": total_trades,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating fee metrics: {e}")
|
||||
return {
|
||||
"total_fees": 0.0,
|
||||
"avg_fee_per_trade": 0.0,
|
||||
"fee_percentage": 0.0,
|
||||
"total_trades_with_fees": 0,
|
||||
}
|
||||
|
||||
async def _calculate_win_rate(self, days: int, paper_trading: bool) -> float:
|
||||
"""Calculate win rate from trades.
|
||||
|
||||
Args:
|
||||
days: Number of days
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
Win rate (0.0 to 1.0)
|
||||
"""
|
||||
try:
|
||||
async with self.db.get_session() as session:
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
stmt = select(Trade).where(
|
||||
Trade.paper_trading == paper_trading,
|
||||
Trade.timestamp >= since
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
trades = result.scalars().all()
|
||||
|
||||
if len(trades) == 0:
|
||||
return 0.0
|
||||
|
||||
# Simplified win rate calculation
|
||||
# In practice, would need to match buy/sell pairs
|
||||
return 0.5 # Placeholder
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating win rate: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
# Global portfolio analytics
|
||||
_portfolio_analytics: Optional[PortfolioAnalytics] = None
|
||||
|
||||
|
||||
def get_portfolio_analytics() -> PortfolioAnalytics:
|
||||
"""Get global portfolio analytics instance."""
|
||||
global _portfolio_analytics
|
||||
if _portfolio_analytics is None:
|
||||
_portfolio_analytics = PortfolioAnalytics()
|
||||
return _portfolio_analytics
|
||||
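Note: a quick sanity check of the ratio math. This sketch assumes the database and tracker singletons are already wired up; the return values are illustrative:

import pandas as pd

returns = pd.Series([0.01, -0.005, 0.003, -0.002])   # four daily returns
analytics = get_portfolio_analytics()
print(analytics.calculate_sharpe_ratio(returns))     # sqrt(252) * mean(excess) / std
print(analytics.calculate_sortino_ratio(returns))    # same numerator, downside std only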
144
src/portfolio/tracker.py
Normal file
@@ -0,0 +1,144 @@
"""Portfolio tracking with real-time P&L calculation."""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from sqlalchemy import select
|
||||
from src.core.database import get_database, Position, PortfolioSnapshot, Trade
|
||||
from src.core.logger import get_logger
|
||||
from src.trading.paper_trading import get_paper_trading
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PortfolioTracker:
|
||||
"""Tracks portfolio with real-time P&L calculation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize portfolio tracker."""
|
||||
self.db = get_database()
|
||||
self.paper_trading = get_paper_trading()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def get_current_portfolio(self, paper_trading: bool = True) -> Dict[str, Any]:
|
||||
"""Get current portfolio state.
|
||||
|
||||
Args:
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
Portfolio dictionary
|
||||
"""
|
||||
if paper_trading:
|
||||
performance = self.paper_trading.get_performance()
|
||||
positions = self.paper_trading.get_positions()
|
||||
else:
|
||||
# Live trading - get from database
|
||||
async with self.db.get_session() as session:
|
||||
stmt = select(Position).where(Position.paper_trading == False)
|
||||
result = await session.execute(stmt)
|
||||
positions = result.scalars().all()
|
||||
|
||||
# Calculate performance from positions
|
||||
total_value = Decimal(0)
|
||||
unrealized_pnl = Decimal(0)
|
||||
for pos in positions:
|
||||
if pos.current_price:
|
||||
pos_value = pos.quantity * pos.current_price
|
||||
total_value += pos_value
|
||||
unrealized_pnl += (pos.current_price - pos.entry_price) * pos.quantity
|
||||
|
||||
performance = {
|
||||
"current_value": float(total_value),
|
||||
"unrealized_pnl": float(unrealized_pnl),
|
||||
"realized_pnl": float(sum(pos.realized_pnl for pos in positions)),
|
||||
}
|
||||
|
||||
return {
|
||||
"positions": [
|
||||
{
|
||||
"symbol": pos.symbol if hasattr(pos, 'symbol') else pos.symbol,
|
||||
"quantity": float(pos.quantity),
|
||||
"entry_price": float(pos.entry_price),
|
||||
"current_price": float(pos.current_price) if pos.current_price else float(pos.entry_price),
|
||||
"unrealized_pnl": float(pos.unrealized_pnl) if hasattr(pos, 'unrealized_pnl') else 0.0,
|
||||
}
|
||||
for pos in positions
|
||||
],
|
||||
"performance": performance,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def update_positions_prices(self, prices: Dict[str, Decimal], paper_trading: bool = True):
|
||||
"""Update current prices for positions.
|
||||
|
||||
Args:
|
||||
prices: Dictionary of symbol -> current_price
|
||||
paper_trading: Paper trading flag
|
||||
"""
|
||||
if paper_trading:
|
||||
await self.paper_trading.update_positions_prices(prices)
|
||||
else:
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
stmt = select(Position).where(Position.paper_trading == False)
|
||||
result = await session.execute(stmt)
|
||||
positions = result.scalars().all()
|
||||
|
||||
for pos in positions:
|
||||
if pos.symbol in prices:
|
||||
pos.current_price = prices[pos.symbol]
|
||||
pos.unrealized_pnl = (prices[pos.symbol] - pos.entry_price) * pos.quantity
|
||||
pos.updated_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to update positions prices: {e}")
|
||||
|
||||
async def get_portfolio_history(
|
||||
self,
|
||||
days: int = 30,
|
||||
paper_trading: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get portfolio history.
|
||||
|
||||
Args:
|
||||
days: Number of days
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
List of portfolio snapshots
|
||||
"""
|
||||
async with self.db.get_session() as session:
|
||||
since = datetime.utcnow() - timedelta(days=days)
|
||||
stmt = select(PortfolioSnapshot).where(
|
||||
PortfolioSnapshot.paper_trading == paper_trading,
|
||||
PortfolioSnapshot.timestamp >= since
|
||||
).order_by(PortfolioSnapshot.timestamp)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
snapshots = result.scalars().all()
|
||||
|
||||
return [
|
||||
{
|
||||
"timestamp": snapshot.timestamp.isoformat(),
|
||||
"total_value": float(snapshot.total_value),
|
||||
"cash": float(snapshot.cash),
|
||||
"positions_value": float(snapshot.positions_value),
|
||||
"unrealized_pnl": float(snapshot.unrealized_pnl),
|
||||
"realized_pnl": float(snapshot.realized_pnl),
|
||||
}
|
||||
for snapshot in snapshots
|
||||
]
|
||||
|
||||
|
||||
# Global portfolio tracker
|
||||
_portfolio_tracker: Optional[PortfolioTracker] = None
|
||||
|
||||
|
||||
def get_portfolio_tracker() -> PortfolioTracker:
|
||||
"""Get global portfolio tracker instance."""
|
||||
global _portfolio_tracker
|
||||
if _portfolio_tracker is None:
|
||||
_portfolio_tracker = PortfolioTracker()
|
||||
return _portfolio_tracker
|
||||
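Note: a minimal usage sketch, assuming an asyncio context and an initialized paper-trading engine; the symbol and price are illustrative:

import asyncio
from decimal import Decimal

async def main():
    tracker = get_portfolio_tracker()
    await tracker.update_positions_prices({"BTC/USDT": Decimal("65000")})
    portfolio = await tracker.get_current_portfolio(paper_trading=True)
    print(portfolio["performance"]["current_value"])

asyncio.run(main())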
0
src/rebalancing/__init__.py
Normal file
196
src/rebalancing/engine.py
Normal file
@@ -0,0 +1,196 @@
"""Portfolio rebalancing engine."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, RebalancingEvent
|
||||
from src.core.logger import get_logger
|
||||
from src.portfolio.tracker import get_portfolio_tracker
|
||||
from src.trading.engine import get_trading_engine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RebalancingEngine:
|
||||
"""Portfolio rebalancing engine."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize rebalancing engine."""
|
||||
self.db = get_database()
|
||||
self.tracker = get_portfolio_tracker()
|
||||
self.trading_engine = get_trading_engine()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def rebalance(
|
||||
self,
|
||||
target_allocations: Dict[str, float],
|
||||
exchange_id: int,
|
||||
paper_trading: bool = True
|
||||
) -> bool:
|
||||
"""Rebalance portfolio to target allocations.
|
||||
|
||||
Args:
|
||||
target_allocations: Dictionary of symbol -> target percentage
|
||||
exchange_id: Exchange ID
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
True if rebalancing successful
|
||||
"""
|
||||
try:
|
||||
# Get current portfolio
|
||||
portfolio = self.tracker.get_current_portfolio(paper_trading)
|
||||
total_value = portfolio['performance']['current_value']
|
||||
|
||||
# Calculate current allocations
|
||||
current_allocations = {}
|
||||
for pos in portfolio['positions']:
|
||||
pos_value = pos['quantity'] * pos['current_price']
|
||||
current_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0
|
||||
|
||||
# Get exchange adapter for fee calculations
|
||||
adapter = await self.trading_engine.get_exchange_adapter(exchange_id)
|
||||
|
||||
# Calculate required trades, factoring in fees
|
||||
orders = []
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Get fee threshold from config (default 0.5% to account for round-trip fees)
|
||||
fee_threshold = Decimal(str(self.tracker.db.get_session().query(
|
||||
# Get from config
|
||||
))) if False else Decimal("0.005") # 0.5% default threshold
|
||||
|
||||
for symbol, target_pct in target_allocations.items():
|
||||
current_pct = current_allocations.get(symbol, 0.0)
|
||||
deviation = target_pct - current_pct
|
||||
|
||||
# Only rebalance if deviation exceeds fee threshold
|
||||
# Default threshold is 1%, but we'll use a configurable fee-aware threshold
|
||||
min_deviation = max(Decimal("0.01"), fee_threshold) # At least 1% or fee threshold
|
||||
|
||||
if abs(deviation) > min_deviation:
|
||||
target_value = total_value * Decimal(str(target_pct))
|
||||
current_value = Decimal(str(current_allocations.get(symbol, 0.0))) * total_value
|
||||
trade_value = target_value - current_value
|
||||
|
||||
# Get current price
|
||||
if adapter:
|
||||
ticker = await adapter.get_ticker(symbol)
|
||||
price = ticker.get('last', Decimal(0))
|
||||
|
||||
if price > 0:
|
||||
# Estimate fee for this trade
|
||||
estimated_quantity = abs(trade_value / price)
|
||||
estimated_fee = fee_calculator.estimate_round_trip_fee(
|
||||
quantity=estimated_quantity,
|
||||
price=price,
|
||||
exchange_adapter=adapter
|
||||
)
|
||||
|
||||
# Adjust trade value to account for fees
|
||||
# For buy: reduce quantity to account for fee
|
||||
# For sell: fee comes from proceeds
|
||||
if trade_value > 0: # Buy
|
||||
# Reduce trade value by estimated fee
|
||||
adjusted_trade_value = trade_value - estimated_fee
|
||||
quantity = adjusted_trade_value / price if price > 0 else Decimal(0)
|
||||
else: # Sell
|
||||
# Fee comes from proceeds, so quantity stays the same
|
||||
quantity = abs(trade_value / price)
|
||||
|
||||
if quantity > 0:
|
||||
side = 'buy' if trade_value > 0 else 'sell'
|
||||
orders.append({
|
||||
'symbol': symbol,
|
||||
'side': side,
|
||||
'quantity': quantity,
|
||||
'price': price,
|
||||
})
|
||||
|
||||
# Execute rebalancing orders
|
||||
executed_orders = []
|
||||
for order in orders:
|
||||
from src.core.database import OrderSide, OrderType
|
||||
side = OrderSide.BUY if order['side'] == 'buy' else OrderSide.SELL
|
||||
|
||||
result = await self.trading_engine.execute_order(
|
||||
exchange_id=exchange_id,
|
||||
strategy_id=None,
|
||||
symbol=order['symbol'],
|
||||
side=side,
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=order['quantity'],
|
||||
paper_trading=paper_trading
|
||||
)
|
||||
|
||||
if result:
|
||||
executed_orders.append(result.id)
|
||||
|
||||
# Record rebalancing event
|
||||
self._record_rebalancing_event(
|
||||
'manual',
|
||||
target_allocations,
|
||||
current_allocations,
|
||||
executed_orders
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rebalance portfolio: {e}")
|
||||
return False
|
||||
|
||||
def _record_rebalancing_event(
|
||||
self,
|
||||
trigger_type: str,
|
||||
target_allocations: Dict[str, float],
|
||||
before_allocations: Dict[str, float],
|
||||
orders_placed: List[int]
|
||||
):
|
||||
"""Record rebalancing event in database.
|
||||
|
||||
Args:
|
||||
trigger_type: Trigger type
|
||||
target_allocations: Target allocations
|
||||
before_allocations: Allocations before rebalancing
|
||||
orders_placed: List of order IDs
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
# Get after allocations
|
||||
portfolio = self.tracker.get_current_portfolio()
|
||||
total_value = portfolio['performance']['current_value']
|
||||
after_allocations = {}
|
||||
for pos in portfolio['positions']:
|
||||
pos_value = pos['quantity'] * pos['current_price']
|
||||
after_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0
|
||||
|
||||
event = RebalancingEvent(
|
||||
trigger_type=trigger_type,
|
||||
target_allocations=target_allocations,
|
||||
before_allocations=before_allocations,
|
||||
after_allocations=after_allocations,
|
||||
orders_placed=orders_placed,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Failed to record rebalancing event: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# Global rebalancing engine
|
||||
_rebalancing_engine: Optional[RebalancingEngine] = None
|
||||
|
||||
|
||||
def get_rebalancing_engine() -> RebalancingEngine:
|
||||
"""Get global rebalancing engine instance."""
|
||||
global _rebalancing_engine
|
||||
if _rebalancing_engine is None:
|
||||
_rebalancing_engine = RebalancingEngine()
|
||||
return _rebalancing_engine
|
||||
|
||||
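Note: a usage sketch of the rebalance entry point. The exchange_id of 1 is a hypothetical database row id, and the allocations are illustrative; assumes an asyncio context:

import asyncio

engine = get_rebalancing_engine()
ok = asyncio.run(engine.rebalance(
    target_allocations={"BTC/USDT": 0.6, "ETH/USDT": 0.4},
    exchange_id=1,  # hypothetical exchange row id
    paper_trading=True,
))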
36
src/rebalancing/strategies.py
Normal file
@@ -0,0 +1,36 @@
"""Rebalancing strategies."""
|
||||
|
||||
from typing import Dict
|
||||
from decimal import Decimal
|
||||
|
||||
class RebalancingStrategy:
|
||||
"""Base rebalancing strategy."""
|
||||
|
||||
def calculate_trades(
|
||||
self,
|
||||
current_allocations: Dict[str, float],
|
||||
target_allocations: Dict[str, float],
|
||||
total_value: Decimal
|
||||
) -> Dict[str, Decimal]:
|
||||
"""Calculate required trades.
|
||||
|
||||
Args:
|
||||
current_allocations: Current allocations
|
||||
target_allocations: Target allocations
|
||||
total_value: Total portfolio value
|
||||
|
||||
Returns:
|
||||
Dictionary of symbol -> trade value (positive = buy, negative = sell)
|
||||
"""
|
||||
trades = {}
|
||||
for symbol, target_pct in target_allocations.items():
|
||||
current_pct = current_allocations.get(symbol, 0.0)
|
||||
deviation = target_pct - current_pct
|
||||
|
||||
if abs(deviation) > 0.01: # 1% threshold
|
||||
target_value = total_value * Decimal(str(target_pct))
|
||||
current_value = Decimal(str(current_pct)) * total_value
|
||||
trades[symbol] = target_value - current_value
|
||||
|
||||
return trades
|
||||
|
||||
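Note: a worked example of the threshold math above (allocations are fractions of total value; numbers illustrative):

from decimal import Decimal

strategy = RebalancingStrategy()
trades = strategy.calculate_trades(
    current_allocations={"BTC/USDT": 0.70, "ETH/USDT": 0.30},
    target_allocations={"BTC/USDT": 0.50, "ETH/USDT": 0.50},
    total_value=Decimal("10000"),
)
# Both deviations (0.20) exceed the 1% threshold:
#   BTC: 5000 - 7000 = -2000 (sell), ETH: 5000 - 3000 = +2000 (buy)
print(trades)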
0
src/reporting/__init__.py
Normal file
120
src/reporting/csv_exporter.py
Normal file
@@ -0,0 +1,120 @@
"""CSV export functionality."""
|
||||
|
||||
import csv
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, Trade, Order
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CSVExporter:
|
||||
"""CSV export functionality."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CSV exporter."""
|
||||
self.db = get_database()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def export_trades(
|
||||
self,
|
||||
filepath: Path,
|
||||
paper_trading: bool = True,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> bool:
|
||||
"""Export trades to CSV.
|
||||
|
||||
Args:
|
||||
filepath: Output file path
|
||||
paper_trading: Filter by paper trading
|
||||
start_date: Start date filter
|
||||
end_date: End date filter
|
||||
|
||||
Returns:
|
||||
True if export successful
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
query = session.query(Trade).filter_by(paper_trading=paper_trading)
|
||||
if start_date:
|
||||
query = query.filter(Trade.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(Trade.timestamp <= end_date)
|
||||
|
||||
trades = query.order_by(Trade.timestamp).all()
|
||||
|
||||
with open(filepath, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'Timestamp', 'Symbol', 'Side', 'Quantity', 'Price', 'Fee', 'Total'
|
||||
])
|
||||
|
||||
for trade in trades:
|
||||
writer.writerow([
|
||||
trade.timestamp.isoformat(),
|
||||
trade.symbol,
|
||||
trade.side.value,
|
||||
float(trade.quantity),
|
||||
float(trade.price),
|
||||
float(trade.fee),
|
||||
float(trade.total),
|
||||
])
|
||||
|
||||
logger.info(f"Exported {len(trades)} trades to {filepath}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export trades: {e}")
|
||||
return False
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def export_portfolio(self, filepath: Path) -> bool:
|
||||
"""Export portfolio snapshot to CSV.
|
||||
|
||||
Args:
|
||||
filepath: Output file path
|
||||
|
||||
Returns:
|
||||
True if export successful
|
||||
"""
|
||||
from src.portfolio.tracker import get_portfolio_tracker
|
||||
|
||||
tracker = get_portfolio_tracker()
|
||||
portfolio = tracker.get_current_portfolio()
|
||||
|
||||
try:
|
||||
with open(filepath, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['Symbol', 'Quantity', 'Entry Price', 'Current Price', 'Unrealized P&L'])
|
||||
|
||||
for pos in portfolio['positions']:
|
||||
writer.writerow([
|
||||
pos['symbol'],
|
||||
pos['quantity'],
|
||||
pos['entry_price'],
|
||||
pos['current_price'],
|
||||
pos['unrealized_pnl'],
|
||||
])
|
||||
|
||||
logger.info(f"Exported portfolio to {filepath}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export portfolio: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Global CSV exporter
|
||||
_csv_exporter: Optional[CSVExporter] = None
|
||||
|
||||
|
||||
def get_csv_exporter() -> CSVExporter:
|
||||
"""Get global CSV exporter instance."""
|
||||
global _csv_exporter
|
||||
if _csv_exporter is None:
|
||||
_csv_exporter = CSVExporter()
|
||||
return _csv_exporter
|
||||
|
||||
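Note: typical usage (a sketch; the file names and date range are illustrative):

from datetime import datetime
from pathlib import Path

exporter = get_csv_exporter()
exporter.export_trades(
    Path("trades_2024.csv"),
    paper_trading=True,
    start_date=datetime(2024, 1, 1),
)
exporter.export_portfolio(Path("portfolio_snapshot.csv"))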
111
src/reporting/pdf_generator.py
Normal file
@@ -0,0 +1,111 @@
"""PDF report generation."""
|
||||
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
|
||||
from reportlab.lib.styles import getSampleStyleSheet
|
||||
from reportlab.lib import colors
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PDFGenerator:
|
||||
"""PDF report generation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize PDF generator."""
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def generate_performance_report(
|
||||
self,
|
||||
filepath: Path,
|
||||
metrics: Dict[str, Any],
|
||||
title: str = "Portfolio Performance Report"
|
||||
) -> bool:
|
||||
"""Generate performance report PDF.
|
||||
|
||||
Args:
|
||||
filepath: Output file path
|
||||
metrics: Performance metrics dictionary
|
||||
title: Report title
|
||||
|
||||
Returns:
|
||||
True if generation successful
|
||||
"""
|
||||
try:
|
||||
doc = SimpleDocTemplate(str(filepath), pagesize=letter)
|
||||
story = []
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Title
|
||||
story.append(Paragraph(title, styles['Title']))
|
||||
story.append(Spacer(1, 12))
|
||||
|
||||
# Metrics table
|
||||
data = [
|
||||
['Metric', 'Value'],
|
||||
['Total Return', f"{metrics.get('total_return_percent', 0):.2f}%"],
|
||||
['Sharpe Ratio', f"{metrics.get('sharpe_ratio', 0):.2f}"],
|
||||
['Sortino Ratio', f"{metrics.get('sortino_ratio', 0):.2f}"],
|
||||
['Max Drawdown', f"{metrics.get('max_drawdown', 0):.2%}"],
|
||||
['Win Rate', f"{metrics.get('win_rate', 0):.2%}"],
|
||||
]
|
||||
|
||||
table = Table(data)
|
||||
table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 14),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.beige),
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
||||
]))
|
||||
|
||||
story.append(table)
|
||||
story.append(Spacer(1, 12))
|
||||
story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", styles['Normal']))
|
||||
|
||||
doc.build(story)
|
||||
logger.info(f"Generated PDF report: {filepath}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate PDF report: {e}")
|
||||
return False
|
||||
|
||||
def generate_backtest_report(
|
||||
self,
|
||||
results: Dict[str, Any],
|
||||
filepath: str
|
||||
) -> bool:
|
||||
"""Generate backtest report PDF.
|
||||
|
||||
Args:
|
||||
results: Backtest results dictionary
|
||||
filepath: Output file path
|
||||
|
||||
Returns:
|
||||
True if generation successful
|
||||
"""
|
||||
return self.generate_performance_report(
|
||||
Path(filepath),
|
||||
results,
|
||||
"Backtest Report"
|
||||
)
|
||||
|
||||
|
||||
# Global PDF generator
|
||||
_pdf_generator: Optional[PDFGenerator] = None
|
||||
|
||||
|
||||
def get_pdf_generator() -> PDFGenerator:
|
||||
"""Get global PDF generator instance."""
|
||||
global _pdf_generator
|
||||
if _pdf_generator is None:
|
||||
_pdf_generator = PDFGenerator()
|
||||
return _pdf_generator
|
||||
|
||||
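Note: an example invocation (metric values and file name are illustrative):

from pathlib import Path

generator = get_pdf_generator()
generator.generate_performance_report(
    Path("monthly_report.pdf"),
    metrics={
        "total_return_percent": 4.2,
        "sharpe_ratio": 1.31,
        "sortino_ratio": 1.85,
        "max_drawdown": -0.06,  # rendered as -6.00%
        "win_rate": 0.55,       # rendered as 55.00%
    },
)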
127
src/reporting/tax_reporter.py
Normal file
@@ -0,0 +1,127 @@
"""Tax reporting (FIFO, LIFO, specific identification)."""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, Trade
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaxReporter:
|
||||
"""Tax reporting with different cost basis methods."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize tax reporter."""
|
||||
self.db = get_database()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def generate_fifo_report(
|
||||
self,
|
||||
symbol: str,
|
||||
year: int,
|
||||
paper_trading: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate FIFO tax report.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
year: Tax year
|
||||
paper_trading: Paper trading flag
|
||||
|
||||
Returns:
|
||||
List of taxable events
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
start_date = datetime(year, 1, 1)
|
||||
end_date = datetime(year, 12, 31, 23, 59, 59)
|
||||
|
||||
trades = session.query(Trade).filter(
|
||||
Trade.symbol == symbol,
|
||||
Trade.paper_trading == paper_trading,
|
||||
Trade.timestamp >= start_date,
|
||||
Trade.timestamp <= end_date
|
||||
).order_by(Trade.timestamp).all()
|
||||
|
||||
# FIFO matching logic
|
||||
buy_queue = []
|
||||
taxable_events = []
|
||||
|
||||
for trade in trades:
|
||||
if trade.side.value == 'buy':
|
||||
buy_queue.append({
|
||||
'date': trade.timestamp,
|
||||
'quantity': trade.quantity,
|
||||
'price': trade.price,
|
||||
'fee': trade.fee,
|
||||
})
|
||||
else: # sell
|
||||
remaining = trade.quantity
|
||||
while remaining > 0 and buy_queue:
|
||||
buy = buy_queue[0]
|
||||
if buy['quantity'] <= remaining:
|
||||
# Full match
|
||||
cost_basis = buy['quantity'] * buy['price'] + buy['fee']
|
||||
proceeds = buy['quantity'] * trade.price - (buy['quantity'] / trade.quantity) * trade.fee
|
||||
gain_loss = proceeds - cost_basis
|
||||
|
||||
taxable_events.append({
|
||||
'date': trade.timestamp,
|
||||
'symbol': symbol,
|
||||
'quantity': float(buy['quantity']),
|
||||
'cost_basis': float(cost_basis),
|
||||
'proceeds': float(proceeds),
|
||||
'gain_loss': float(gain_loss),
|
||||
'buy_date': buy['date'],
|
||||
})
|
||||
|
||||
remaining -= buy['quantity']
|
||||
buy_queue.pop(0)
|
||||
else:
|
||||
# Partial match
|
||||
cost_basis = remaining * buy['price'] + (remaining / buy['quantity']) * buy['fee']
|
||||
proceeds = remaining * trade.price - (remaining / trade.quantity) * trade.fee
|
||||
gain_loss = proceeds - cost_basis
|
||||
|
||||
taxable_events.append({
|
||||
'date': trade.timestamp,
|
||||
'symbol': symbol,
|
||||
'quantity': float(remaining),
|
||||
'cost_basis': float(cost_basis),
|
||||
'proceeds': float(proceeds),
|
||||
'gain_loss': float(gain_loss),
|
||||
'buy_date': buy['date'],
|
||||
})
|
||||
|
||||
buy['quantity'] -= remaining
|
||||
remaining = 0
|
||||
|
||||
return taxable_events
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def generate_lifo_report(
|
||||
self,
|
||||
symbol: str,
|
||||
year: int,
|
||||
paper_trading: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate LIFO tax report (similar to FIFO but uses stack instead of queue)."""
|
||||
# Similar to FIFO but uses stack (LIFO)
|
||||
return self.generate_fifo_report(symbol, year, paper_trading) # Simplified
|
||||
|
||||
|
||||
# Global tax reporter
|
||||
_tax_reporter: Optional[TaxReporter] = None
|
||||
|
||||
|
||||
def get_tax_reporter() -> TaxReporter:
|
||||
"""Get global tax reporter instance."""
|
||||
global _tax_reporter
|
||||
if _tax_reporter is None:
|
||||
_tax_reporter = TaxReporter()
|
||||
return _tax_reporter
|
||||
|
||||
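Note: to make the FIFO arithmetic concrete, a hand-worked trace (numbers are illustrative; sell fees are allocated pro rata, as in the code above):

# Lots bought: 1.0 @ 100 (fee 1.0), then 1.0 @ 120 (fee 1.0).
# Sell 1.5 @ 150 (fee 3.0):
#   Lot 1 (full match, 1.0 units):
#     cost_basis = 1.0 * 100 + 1.0                = 101.0
#     proceeds   = 1.0 * 150 - (1.0 / 1.5) * 3.0  = 148.0
#     gain_loss  = 148.0 - 101.0                  = 47.0
#   Lot 2 (partial match, 0.5 units):
#     cost_basis = 0.5 * 120 + (0.5 / 1.0) * 1.0  = 60.5
#     proceeds   = 0.5 * 150 - (0.5 / 1.5) * 3.0  = 74.0
#     gain_loss  = 74.0 - 60.5                    = 13.5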
0
src/resilience/__init__.py
Normal file
100
src/resilience/health_monitor.py
Normal file
@@ -0,0 +1,100 @@
"""Health monitoring and self-healing."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HealthMonitor:
|
||||
"""Monitors system health."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize health monitor."""
|
||||
self.logger = get_logger(__name__)
|
||||
self.errors: List[Dict] = []
|
||||
self.connections: Dict[str, bool] = {}
|
||||
self.last_check: Optional[datetime] = None
|
||||
|
||||
def record_error(self, context: str, error: Exception):
|
||||
"""Record an error.
|
||||
|
||||
Args:
|
||||
context: Error context
|
||||
error: Exception object
|
||||
"""
|
||||
self.errors.append({
|
||||
"timestamp": datetime.utcnow(),
|
||||
"context": context,
|
||||
"error": str(error),
|
||||
})
|
||||
|
||||
# Keep only last 100 errors
|
||||
if len(self.errors) > 100:
|
||||
self.errors = self.errors[-100:]
|
||||
|
||||
def check_health(self) -> Dict[str, bool]:
|
||||
"""Check system health.
|
||||
|
||||
Returns:
|
||||
Dictionary of component -> healthy
|
||||
"""
|
||||
health = {
|
||||
"database": self._check_database(),
|
||||
"exchanges": self._check_exchanges(),
|
||||
}
|
||||
|
||||
self.last_check = datetime.utcnow()
|
||||
return health
|
||||
|
||||
def _check_database(self) -> bool:
|
||||
"""Check database connection.
|
||||
|
||||
Returns:
|
||||
True if healthy
|
||||
"""
|
||||
try:
|
||||
from src.core.database import get_database
|
||||
db = get_database()
|
||||
session = db.get_session()
|
||||
session.execute("SELECT 1")
|
||||
session.close()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _check_exchanges(self) -> bool:
|
||||
"""Check exchange connections.
|
||||
|
||||
Returns:
|
||||
True if at least one exchange is connected
|
||||
"""
|
||||
# Simplified - would check actual connections
|
||||
return len(self.connections) > 0
|
||||
|
||||
def get_error_rate(self, minutes: int = 60) -> float:
|
||||
"""Get error rate over time period.
|
||||
|
||||
Args:
|
||||
minutes: Time period in minutes
|
||||
|
||||
Returns:
|
||||
Errors per minute
|
||||
"""
|
||||
since = datetime.utcnow() - timedelta(minutes=minutes)
|
||||
recent_errors = [e for e in self.errors if e["timestamp"] >= since]
|
||||
return len(recent_errors) / minutes if minutes > 0 else 0.0
|
||||
|
||||
|
||||
# Global health monitor
|
||||
_health_monitor: Optional[HealthMonitor] = None
|
||||
|
||||
|
||||
def get_health_monitor() -> HealthMonitor:
|
||||
"""Get global health monitor instance."""
|
||||
global _health_monitor
|
||||
if _health_monitor is None:
|
||||
_health_monitor = HealthMonitor()
|
||||
return _health_monitor
|
||||
|
||||
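Note: a short usage sketch (the context string and exception are illustrative):

monitor = get_health_monitor()
monitor.record_error("exchange.binance", TimeoutError("order book fetch timed out"))
print(monitor.check_health())      # e.g. {"database": True, "exchanges": False}
print(monitor.get_error_rate(60))  # errors per minute over the last hour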
103
src/resilience/recovery.py
Normal file
@@ -0,0 +1,103 @@
"""Error recovery mechanisms."""
|
||||
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any
|
||||
from src.core.logger import get_logger
|
||||
from .state_manager import get_state_manager
|
||||
from .health_monitor import get_health_monitor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RecoveryManager:
|
||||
"""Manages error recovery and system resilience."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize recovery manager."""
|
||||
self.state_manager = get_state_manager()
|
||||
self.health_monitor = get_health_monitor()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def handle_error(self, error: Exception, context: str = ""):
|
||||
"""Handle errors with recovery.
|
||||
|
||||
Args:
|
||||
error: Exception object
|
||||
context: Error context
|
||||
"""
|
||||
error_msg = f"{context}: {str(error)}"
|
||||
logger.error(error_msg)
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
# Save error state
|
||||
self.state_manager.save_state("last_error", {
|
||||
"message": str(error),
|
||||
"context": context,
|
||||
"traceback": traceback.format_exc(),
|
||||
})
|
||||
|
||||
# Update health monitor
|
||||
self.health_monitor.record_error(context, error)
|
||||
|
||||
def recover_orders(self) -> bool:
|
||||
"""Recover order state after disconnection.
|
||||
|
||||
Returns:
|
||||
True if recovery successful
|
||||
"""
|
||||
try:
|
||||
from src.trading.order_manager import get_order_manager
|
||||
order_manager = get_order_manager()
|
||||
|
||||
# Get pending/open orders
|
||||
open_orders = order_manager.get_open_orders()
|
||||
|
||||
# Try to sync with exchange
|
||||
for order in open_orders:
|
||||
# Would sync order status with exchange
|
||||
pass
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to recover orders: {e}")
|
||||
return False
|
||||
|
||||
def recover_connections(self) -> bool:
|
||||
"""Recover exchange connections.
|
||||
|
||||
Returns:
|
||||
True if recovery successful
|
||||
"""
|
||||
try:
|
||||
from src.exchanges.factory import ExchangeFactory
|
||||
from src.core.database import get_database, Exchange
|
||||
|
||||
db = get_database()
|
||||
session = db.get_session()
|
||||
|
||||
try:
|
||||
exchanges = session.query(Exchange).filter_by(enabled=True).all()
|
||||
for exchange in exchanges:
|
||||
adapter = ExchangeFactory.create(exchange.id)
|
||||
if adapter:
|
||||
self.logger.info(f"Recovered connection to {exchange.name}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to recover connections: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Global recovery manager
|
||||
_recovery_manager: Optional[RecoveryManager] = None
|
||||
|
||||
|
||||
def get_recovery_manager() -> RecoveryManager:
|
||||
"""Get global recovery manager instance."""
|
||||
global _recovery_manager
|
||||
if _recovery_manager is None:
|
||||
_recovery_manager = RecoveryManager()
|
||||
return _recovery_manager
|
||||
|
||||
90
src/resilience/state_manager.py
Normal file
@@ -0,0 +1,90 @@
"""State persistence for recovery."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, AppState
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StateManager:
|
||||
"""Manages application state persistence."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize state manager."""
|
||||
self.db = get_database()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def save_state(self, key: str, value: Any):
|
||||
"""Save application state.
|
||||
|
||||
Args:
|
||||
key: State key
|
||||
value: State value (must be JSON serializable)
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
existing = session.query(AppState).filter_by(key=key).first()
|
||||
if existing:
|
||||
existing.value = value
|
||||
existing.updated_at = datetime.utcnow()
|
||||
else:
|
||||
state = AppState(key=key, value=value, updated_at=datetime.utcnow())
|
||||
session.add(state)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Failed to save state {key}: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def load_state(self, key: str, default: Any = None) -> Any:
|
||||
"""Load application state.
|
||||
|
||||
Args:
|
||||
key: State key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
State value or default
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
state = session.query(AppState).filter_by(key=key).first()
|
||||
return state.value if state else default
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def clear_state(self, key: str):
|
||||
"""Clear application state.
|
||||
|
||||
Args:
|
||||
key: State key
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
state = session.query(AppState).filter_by(key=key).first()
|
||||
if state:
|
||||
session.delete(state)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Failed to clear state {key}: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# Global state manager
|
||||
_state_manager: Optional[StateManager] = None
|
||||
|
||||
|
||||
def get_state_manager() -> StateManager:
|
||||
"""Get global state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
_state_manager = StateManager()
|
||||
return _state_manager
|
||||
|
||||
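Note: a usage sketch (the key and payload are illustrative):

from datetime import datetime

state = get_state_manager()
state.save_state("last_run", {"finished_at": datetime.utcnow().isoformat()})
last = state.load_state("last_run", default={})
state.clear_state("last_run")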
0
src/risk/__init__.py
Normal file
166
src/risk/limits.py
Normal file
@@ -0,0 +1,166 @@
"""Drawdown and loss limits."""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, PortfolioSnapshot, Trade
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RiskLimitsManager:
|
||||
"""Manages risk limits (drawdown, daily loss, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize risk limits manager."""
|
||||
self.db = get_database()
|
||||
self.config = get_config()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def check_daily_loss_limit(self) -> bool:
|
||||
"""Check if daily loss limit is exceeded."""
|
||||
daily_loss_limit = Decimal(str(
|
||||
self.config.get("risk.daily_loss_limit_percent", 5.0)
|
||||
)) / 100
|
||||
|
||||
daily_pnl = await self.get_daily_pnl()
|
||||
portfolio_value = await self.get_portfolio_value()
|
||||
|
||||
if portfolio_value == 0:
|
||||
return True
|
||||
|
||||
daily_loss_percent = abs(daily_pnl) / portfolio_value if daily_pnl < 0 else Decimal(0)
|
||||
|
||||
if daily_loss_percent > daily_loss_limit:
|
||||
self.logger.warning(f"Daily loss limit exceeded: {daily_loss_percent:.2%} > {daily_loss_limit:.2%}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def check_max_drawdown(self) -> bool:
|
||||
"""Check if maximum drawdown limit is exceeded."""
|
||||
max_drawdown_limit = Decimal(str(
|
||||
self.config.get("risk.max_drawdown_percent", 20.0)
|
||||
)) / 100
|
||||
|
||||
current_drawdown = await self.get_current_drawdown()
|
||||
|
||||
if current_drawdown > max_drawdown_limit:
|
||||
self.logger.warning(f"Max drawdown limit exceeded: {current_drawdown:.2%} > {max_drawdown_limit:.2%}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def check_portfolio_allocation(self, symbol: str, position_value: Decimal) -> bool:
|
||||
"""Check portfolio allocation limits."""
|
||||
# Default: max 20% per asset
|
||||
max_allocation_percent = Decimal("0.20")
|
||||
portfolio_value = await self.get_portfolio_value()
|
||||
|
||||
if portfolio_value == 0:
|
||||
return True
|
||||
|
||||
allocation = position_value / portfolio_value
|
||||
|
||||
if allocation > max_allocation_percent:
|
||||
self.logger.warning(f"Portfolio allocation exceeded for {symbol}: {allocation:.2%}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def get_daily_pnl(self) -> Decimal:
|
||||
"""Get today's P&L."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
today = datetime.utcnow().date()
|
||||
start_of_day = datetime.combine(today, datetime.min.time())
|
||||
|
||||
# Get trades from today
|
||||
stmt = select(Trade).where(
|
||||
Trade.timestamp >= start_of_day,
|
||||
Trade.paper_trading == True
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
trades = result.scalars().all()
|
||||
|
||||
# Calculate P&L from trades
|
||||
pnl = Decimal(0)
|
||||
for trade in trades:
|
||||
trade_value = trade.quantity * trade.price
|
||||
fee = trade.fee if trade.fee else Decimal(0)
|
||||
|
||||
if trade.side.value == "sell":
|
||||
pnl += trade_value - fee
|
||||
else:
|
||||
pnl -= trade_value + fee
|
||||
|
||||
return pnl
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error calculating daily P&L: {e}")
|
||||
return Decimal(0)
|
||||
|
||||
async def get_current_drawdown(self) -> Decimal:
|
||||
"""Get current drawdown percentage."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
# Get peak portfolio value
|
||||
stmt_peak = select(PortfolioSnapshot).order_by(
|
||||
PortfolioSnapshot.total_value.desc()
|
||||
).limit(1)
|
||||
result_peak = await session.execute(stmt_peak)
|
||||
peak = result_peak.scalar_one_or_none()
|
||||
|
||||
if not peak:
|
||||
return Decimal(0)
|
||||
|
||||
# Get current value
|
||||
stmt_current = select(PortfolioSnapshot).order_by(
|
||||
PortfolioSnapshot.timestamp.desc()
|
||||
).limit(1)
|
||||
result_current = await session.execute(stmt_current)
|
||||
current = result_current.scalar_one_or_none()
|
||||
|
||||
if not current or current.total_value >= peak.total_value:
|
||||
return Decimal(0)
|
||||
|
||||
drawdown = (peak.total_value - current.total_value) / peak.total_value
|
||||
return drawdown
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error calculating drawdown: {e}")
|
||||
return Decimal(0)
|
||||
|
||||
async def get_portfolio_value(self) -> Decimal:
|
||||
"""Get current portfolio value."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
stmt = select(PortfolioSnapshot).order_by(
|
||||
PortfolioSnapshot.timestamp.desc()
|
||||
).limit(1)
|
||||
result = await session.execute(stmt)
|
||||
latest = result.scalar_one_or_none()
|
||||
|
||||
if latest:
|
||||
return latest.total_value
|
||||
return Decimal(0)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting portfolio value: {e}")
|
||||
return Decimal(0)
|
||||
|
||||
def get_all_limits(self) -> Dict[str, Any]:
|
||||
"""Get all risk limits configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary of limit configurations
|
||||
"""
|
||||
return {
|
||||
'max_drawdown_percent': self.config.get("risk.max_drawdown_percent", 20.0),
|
||||
'daily_loss_limit_percent': self.config.get("risk.daily_loss_limit_percent", 5.0),
|
||||
'position_size_percent': self.config.get("risk.position_size_percent", 2.0),
|
||||
}
|
||||
|
||||
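Note: a worked example of the daily-loss check (numbers are illustrative):

# With risk.daily_loss_limit_percent = 5.0 and a portfolio value of 10,000:
#   daily_pnl = -600  ->  daily_loss_percent = 600 / 10000 = 6%
#   6% > 5%  ->  check_daily_loss_limit() returns False and new orders are blocked.
# A positive daily_pnl always passes (the loss percentage is treated as 0).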
91
src/risk/manager.py
Normal file
@@ -0,0 +1,91 @@
"""Risk management engine with stop-loss, position sizing, drawdown limits, and daily loss limits."""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, RiskLimit, Trade, PortfolioSnapshot
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
from .stop_loss import StopLossManager
|
||||
from .position_sizing import PositionSizingManager
|
||||
from .limits import RiskLimitsManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Comprehensive risk management engine."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize risk manager."""
|
||||
self.db = get_database()
|
||||
self.config = get_config()
|
||||
self.stop_loss = StopLossManager()
|
||||
self.position_sizing = PositionSizingManager()
|
||||
self.limits = RiskLimitsManager()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def check_order_risk(
|
||||
self,
|
||||
symbol: str,
|
||||
side: str,
|
||||
quantity: Decimal,
|
||||
price: Decimal,
|
||||
current_balance: Decimal,
|
||||
exchange_adapter=None
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Check if order passes risk checks, including fee validation."""
|
||||
# Check position sizing (includes fee validation)
|
||||
if not self.position_sizing.validate_position_size(
|
||||
symbol, quantity, price, current_balance, exchange_adapter
|
||||
):
|
||||
return False, "Position size exceeds limits (including fees)"
|
||||
|
||||
# Check daily loss limit
|
||||
if not await self.limits.check_daily_loss_limit():
|
||||
return False, "Daily loss limit reached"
|
||||
|
||||
# Check maximum drawdown
|
||||
if not await self.limits.check_max_drawdown():
|
||||
return False, "Maximum drawdown limit reached"
|
||||
|
||||
# Check portfolio allocation
|
||||
if not await self.limits.check_portfolio_allocation(symbol, quantity * price):
|
||||
return False, "Portfolio allocation limit exceeded"
|
||||
|
||||
return True, None
|
||||
|
||||
# ... calc position size omitted as it relies on sync position sizing ...
|
||||
# Wait, calculate_position_size uses position_sizing which is sync.
|
||||
|
||||
async def check_limits(self) -> Dict[str, bool]:
|
||||
"""Check all risk limits."""
|
||||
return {
|
||||
'daily_loss': await self.limits.check_daily_loss_limit(),
|
||||
'max_drawdown': await self.limits.check_max_drawdown(),
|
||||
'position_size': True, # Checked per order
|
||||
'portfolio_allocation': True, # Checked per order
|
||||
}
|
||||
|
||||
async def get_risk_metrics(self) -> Dict[str, Any]:
|
||||
"""Get current risk metrics."""
|
||||
return {
|
||||
'daily_pnl': await self.limits.get_daily_pnl(),
|
||||
'max_drawdown': await self.limits.get_current_drawdown(),
|
||||
'portfolio_value': await self.limits.get_portfolio_value(),
|
||||
'risk_limits': self.limits.get_all_limits(), # This call is sync
|
||||
}
|
||||
|
||||
|
||||
# Global risk manager
|
||||
_risk_manager: Optional[RiskManager] = None
|
||||
|
||||
|
||||
def get_risk_manager() -> RiskManager:
|
||||
"""Get global risk manager instance."""
|
||||
global _risk_manager
|
||||
if _risk_manager is None:
|
||||
_risk_manager = RiskManager()
|
||||
return _risk_manager
|
||||
|
||||
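Note: a sketch of the pre-trade gate (values illustrative; assumes an asyncio context and initialized config/database singletons):

import asyncio
from decimal import Decimal

risk = get_risk_manager()
allowed, reason = asyncio.run(risk.check_order_risk(
    symbol="BTC/USDT",
    side="buy",
    quantity=Decimal("0.01"),
    price=Decimal("65000"),
    current_balance=Decimal("50000"),
))
if not allowed:
    print(f"Order rejected: {reason}")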
144
src/risk/position_sizing.py
Normal file
@@ -0,0 +1,144 @@
"""Position sizing rules."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
from src.exchanges.base import BaseExchangeAdapter
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PositionSizingManager:
|
||||
"""Manages position sizing calculations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize position sizing manager."""
|
||||
self.config = get_config()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def calculate_size(
|
||||
self,
|
||||
symbol: str,
|
||||
price: Decimal,
|
||||
balance: Decimal,
|
||||
risk_percent: Optional[Decimal] = None,
|
||||
exchange_adapter: Optional[BaseExchangeAdapter] = None
|
||||
) -> Decimal:
|
||||
"""Calculate position size, accounting for fees.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Entry price
|
||||
balance: Available balance
|
||||
risk_percent: Risk percentage (uses config default if None)
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
Calculated position size
|
||||
"""
|
||||
if risk_percent is None:
|
||||
risk_percent = Decimal(str(
|
||||
self.config.get("risk.position_size_percent", 2.0)
|
||||
)) / 100
|
||||
|
||||
position_value = balance * risk_percent
|
||||
|
||||
# Account for fees by reserving fee amount
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Reserve ~0.4% for round-trip fees
|
||||
fee_reserve = fee_calculator.calculate_fee_reserve(
|
||||
position_value=position_value,
|
||||
exchange_adapter=exchange_adapter,
|
||||
reserve_percent=0.004 # 0.4% for round-trip
|
||||
)
|
||||
|
||||
# Adjust position value to account for fees
|
||||
adjusted_position_value = position_value - fee_reserve
|
||||
|
||||
if price > 0:
|
||||
quantity = adjusted_position_value / price
|
||||
return max(Decimal(0), quantity) # Ensure non-negative
|
||||
|
||||
return Decimal(0)
|
||||
|
||||
def calculate_kelly_criterion(
|
||||
self,
|
||||
win_rate: float,
|
||||
avg_win: float,
|
||||
avg_loss: float
|
||||
) -> Decimal:
|
||||
"""Calculate position size using Kelly Criterion.
|
||||
|
||||
Args:
|
||||
win_rate: Win rate (0.0 to 1.0)
|
||||
avg_win: Average win amount
|
||||
avg_loss: Average loss amount
|
||||
|
||||
Returns:
|
||||
Kelly percentage
|
||||
"""
|
||||
if avg_loss == 0:
|
||||
return Decimal(0)
|
||||
|
||||
kelly = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
|
||||
# Use fractional Kelly (half) for safety
|
||||
return Decimal(str(kelly / 2))
|
||||
|
||||
def validate_position_size(
|
||||
self,
|
||||
symbol: str,
|
||||
quantity: Decimal,
|
||||
price: Decimal,
|
||||
balance: Decimal,
|
||||
exchange_adapter: Optional[BaseExchangeAdapter] = None
|
||||
) -> bool:
|
||||
"""Validate position size against limits, accounting for fees.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
quantity: Position quantity
|
||||
price: Entry price
|
||||
balance: Available balance
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
True if position size is valid
|
||||
"""
|
||||
position_value = quantity * price
|
||||
|
||||
# Calculate estimated fee for this trade
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
from src.core.database import OrderType
|
||||
|
||||
fee_calculator = get_fee_calculator()
|
||||
estimated_fee = fee_calculator.calculate_fee(
|
||||
quantity=quantity,
|
||||
price=price,
|
||||
order_type=OrderType.MARKET, # Use market as worst case
|
||||
exchange_adapter=exchange_adapter
|
||||
)
|
||||
|
||||
total_cost = position_value + estimated_fee
|
||||
|
||||
# Check if exceeds available balance (including fees)
|
||||
if total_cost > balance:
|
||||
self.logger.warning(
|
||||
f"Position cost {total_cost} (value: {position_value}, fee: {estimated_fee}) "
|
||||
f"exceeds balance {balance}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Check against risk limits
|
||||
max_position_percent = Decimal(str(
|
||||
self.config.get("risk.position_size_percent", 2.0)
|
||||
)) / 100
|
||||
|
||||
if position_value > balance * max_position_percent:
|
||||
self.logger.warning(f"Position size exceeds risk limit")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
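Note: a worked Kelly example (numbers illustrative; assumes config is initialized, which the constructor requires):

sizer = PositionSizingManager()
# 55% win rate, average win 150, average loss 100:
#   full Kelly = (0.55 * 150 - 0.45 * 100) / 150 = 0.25
#   half Kelly (returned) = 0.125, i.e. risk 12.5% of capital
fraction = sizer.calculate_kelly_criterion(win_rate=0.55, avg_win=150.0, avg_loss=100.0)
print(fraction)  # Decimal('0.125')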
229
src/risk/stop_loss.py
Normal file
@@ -0,0 +1,229 @@
"""Stop-loss logic."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
from src.core.logger import get_logger
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StopLossManager:
|
||||
"""Manages stop-loss orders."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize stop-loss manager."""
|
||||
self.stop_losses: Dict[int, Dict[str, Any]] = {} # position_id -> stop config
|
||||
self.logger = get_logger(__name__)
|
||||
self.indicators = get_indicators()
|
||||
|
||||
def set_stop_loss(
|
||||
self,
|
||||
position_id: int,
|
||||
stop_price: Decimal,
|
||||
trailing: bool = False,
|
||||
trail_percent: Optional[Decimal] = None,
|
||||
use_atr: bool = False,
|
||||
atr_multiplier: Decimal = Decimal('2.0'),
|
||||
atr_period: int = 14,
|
||||
ohlcv_data: Optional[pd.DataFrame] = None
|
||||
):
|
||||
"""Set stop-loss for position.
|
||||
|
||||
Args:
|
||||
position_id: Position ID
|
||||
stop_price: Stop price (ignored if use_atr=True)
|
||||
trailing: Enable trailing stop
|
||||
trail_percent: Trail percentage (fallback when use_atr=True but ATR data is unavailable)
|
||||
use_atr: Use ATR-based stop loss
|
||||
atr_multiplier: ATR multiplier for stop distance (default 2.0)
|
||||
atr_period: ATR period (default 14)
|
||||
ohlcv_data: OHLCV DataFrame for ATR calculation
|
||||
"""
|
||||
if use_atr and ohlcv_data is not None and len(ohlcv_data) >= atr_period:
|
||||
# Calculate ATR-based stop
|
||||
try:
|
||||
high = ohlcv_data['high']
|
||||
low = ohlcv_data['low']
|
||||
close = ohlcv_data['close']
|
||||
|
||||
atr = self.indicators.atr(high, low, close, period=atr_period)
|
||||
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
|
||||
|
||||
if current_atr is not None:
|
||||
current_price = Decimal(str(close.iloc[-1]))
|
||||
atr_distance = current_atr * atr_multiplier
|
||||
|
||||
# For long positions, stop is below entry
|
||||
# For short positions, stop is above entry
|
||||
# We'll determine direction from stop_price vs current_price
|
||||
if stop_price < current_price:
|
||||
# Long position - stop below
|
||||
stop_price = current_price - atr_distance
|
||||
else:
|
||||
# Short position - stop above
|
||||
stop_price = current_price + atr_distance
|
||||
|
||||
self.logger.info(
|
||||
f"Set ATR-based stop-loss for position {position_id}: "
|
||||
f"ATR={current_atr}, multiplier={atr_multiplier}, "
|
||||
f"stop_price={stop_price}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to calculate ATR-based stop, using provided stop_price: {e}")
|
||||
|
||||
self.stop_losses[position_id] = {
|
||||
'stop_price': stop_price,
|
||||
'trailing': trailing,
|
||||
'trail_percent': trail_percent,
|
||||
'use_atr': use_atr,
|
||||
'atr_multiplier': atr_multiplier if use_atr else None,
|
||||
'atr_period': atr_period if use_atr else None,
|
||||
'highest_price': stop_price,  # Best price seen for trailing stops; despite the name it tracks the lowest price for shorts
|
||||
}
|
||||
if not use_atr:
|
||||
self.logger.info(f"Set stop-loss for position {position_id} at {stop_price}")
|
||||
|
||||
def check_stop_loss(
|
||||
self,
|
||||
position_id: int,
|
||||
current_price: Decimal,
|
||||
is_long: bool = True,
|
||||
ohlcv_data: Optional[pd.DataFrame] = None
|
||||
) -> bool:
|
||||
"""Check if stop-loss should trigger.
|
||||
|
||||
Args:
|
||||
position_id: Position ID
|
||||
current_price: Current market price
|
||||
is_long: True for long position
|
||||
ohlcv_data: OHLCV DataFrame for ATR-based trailing stops
|
||||
|
||||
Returns:
|
||||
True if stop-loss should trigger
|
||||
"""
|
||||
if position_id not in self.stop_losses:
|
||||
return False
|
||||
|
||||
stop_config = self.stop_losses[position_id]
|
||||
stop_price = stop_config['stop_price']
|
||||
use_atr = stop_config.get('use_atr', False)
|
||||
|
||||
if stop_config['trailing']:
|
||||
# Update trailing stop
|
||||
if use_atr and ohlcv_data is not None and stop_config.get('atr_period'):
|
||||
# ATR-based trailing stop
|
||||
try:
|
||||
atr_period = stop_config['atr_period']
|
||||
atr_multiplier = stop_config.get('atr_multiplier', Decimal('2.0'))
|
||||
|
||||
if len(ohlcv_data) >= atr_period:
|
||||
high = ohlcv_data['high']
|
||||
low = ohlcv_data['low']
|
||||
close = ohlcv_data['close']
|
||||
|
||||
atr = self.indicators.atr(high, low, close, period=atr_period)
|
||||
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
|
||||
|
||||
if current_atr is not None:
|
||||
atr_distance = current_atr * atr_multiplier
|
||||
|
||||
if is_long:
|
||||
# Long position: trailing stop moves up
|
||||
if current_price > stop_config['highest_price']:
|
||||
stop_config['highest_price'] = current_price
|
||||
stop_price = current_price - atr_distance
|
||||
stop_config['stop_price'] = stop_price
|
||||
else:
|
||||
# Short position: trailing stop moves down
|
||||
if current_price < stop_config['highest_price']:
|
||||
stop_config['highest_price'] = current_price
|
||||
stop_price = current_price + atr_distance
|
||||
stop_config['stop_price'] = stop_price
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error updating ATR-based trailing stop: {e}")
|
||||
else:
|
||||
# Percentage-based trailing stop
|
||||
if is_long:
|
||||
if current_price > stop_config['highest_price']:
|
||||
stop_config['highest_price'] = current_price
|
||||
if stop_config.get('trail_percent'):
|
||||
stop_price = current_price * (1 - stop_config['trail_percent'])
|
||||
stop_config['stop_price'] = stop_price
|
||||
else:
|
||||
if current_price < stop_config['highest_price']:
|
||||
stop_config['highest_price'] = current_price
|
||||
if stop_config.get('trail_percent'):
|
||||
stop_price = current_price * (1 + stop_config['trail_percent'])
|
||||
stop_config['stop_price'] = stop_price
|
||||
|
||||
# Check trigger
|
||||
if is_long:
|
||||
return current_price <= stop_price
|
||||
else:
|
||||
return current_price >= stop_price
|
||||
|
||||
def calculate_atr_stop(
|
||||
self,
|
||||
entry_price: Decimal,
|
||||
is_long: bool,
|
||||
ohlcv_data: pd.DataFrame,
|
||||
atr_multiplier: Decimal = Decimal('2.0'),
|
||||
atr_period: int = 14
|
||||
) -> Decimal:
|
||||
"""Calculate ATR-based stop loss price.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
is_long: True for long position
|
||||
ohlcv_data: OHLCV DataFrame
|
||||
atr_multiplier: ATR multiplier (default 2.0)
|
||||
atr_period: ATR period (default 14)
|
||||
|
||||
Returns:
|
||||
Stop loss price
|
||||
"""
|
||||
if len(ohlcv_data) < atr_period:
|
||||
# Fallback to 2% stop if insufficient data
|
||||
if is_long:
|
||||
return entry_price * Decimal('0.98')
|
||||
else:
|
||||
return entry_price * Decimal('1.02')
|
||||
|
||||
try:
|
||||
high = ohlcv_data['high']
|
||||
low = ohlcv_data['low']
|
||||
close = ohlcv_data['close']
|
||||
|
||||
atr = self.indicators.atr(high, low, close, period=atr_period)
|
||||
current_atr = Decimal(str(atr.iloc[-1])) if not pd.isna(atr.iloc[-1]) else None
|
||||
|
||||
if current_atr is not None:
|
||||
atr_distance = current_atr * atr_multiplier
|
||||
|
||||
if is_long:
|
||||
# Long position: stop below entry
|
||||
return entry_price - atr_distance
|
||||
else:
|
||||
# Short position: stop above entry
|
||||
return entry_price + atr_distance
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error calculating ATR stop, using fallback: {e}")
|
||||
|
||||
# Fallback to percentage-based stop
|
||||
if is_long:
|
||||
return entry_price * Decimal('0.98')
|
||||
else:
|
||||
return entry_price * Decimal('1.02')
|
||||
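Worked numbers for the ATR distance above (hypothetical values): with entry 100, a 14-period ATR of 2.5 and multiplier 2, the stop sits 5 away from entry on either side.

from decimal import Decimal

entry_price = Decimal('100')
current_atr = Decimal('2.5')         # hypothetical 14-period ATR
atr_distance = current_atr * Decimal('2.0')

print(entry_price - atr_distance)    # long stop:  95.0
print(entry_price + atr_distance)    # short stop: 105.0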
|
||||
def remove_stop_loss(self, position_id: int):
|
||||
"""Remove stop-loss for position.
|
||||
|
||||
Args:
|
||||
position_id: Position ID
|
||||
"""
|
||||
if position_id in self.stop_losses:
|
||||
del self.stop_losses[position_id]
|
||||
self.logger.info(f"Removed stop-loss for position {position_id}")
|
||||
|
||||
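A minimal standalone re-enactment of the percentage trailing logic in check_stop_loss for a long position; the tick prices are illustrative:

from decimal import Decimal

highest = Decimal('100')
trail = Decimal('0.02')              # 2% trail
stop = highest * (1 - trail)         # 98

for tick in (Decimal('101'), Decimal('103'), Decimal('100.9')):
    if tick > highest:               # long position: ratchet the stop upward only
        highest = tick
        stop = tick * (1 - trail)
    print(tick, stop, tick <= stop)  # last tick: 100.9 <= 100.94 -> stop triggers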
0
src/security/__init__.py
Normal file
94
src/security/audit.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Audit logging for security and actions."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, AuditLog
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""Audit logging for security and important actions."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize audit logger."""
|
||||
self.db = get_database()
|
||||
|
||||
def log(
|
||||
self,
|
||||
action: str,
|
||||
entity_type: Optional[str] = None,
|
||||
entity_id: Optional[int] = None,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log an audit event.
|
||||
|
||||
Args:
|
||||
action: Action performed (e.g., "api_key_added", "order_placed")
|
||||
entity_type: Type of entity (e.g., "exchange", "strategy", "order")
|
||||
entity_id: ID of the entity
|
||||
details: Additional details as dictionary
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
audit_entry = AuditLog(
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
details=details or {},
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
session.add(audit_entry)
|
||||
session.commit()
|
||||
|
||||
# Also log to application logger
|
||||
logger.info(f"Audit: {action} on {entity_type} {entity_id}")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_audit_log(
|
||||
self,
|
||||
entity_type: Optional[str] = None,
|
||||
entity_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> list[AuditLog]:
|
||||
"""Get audit log entries.
|
||||
|
||||
Args:
|
||||
entity_type: Filter by entity type
|
||||
entity_id: Filter by entity ID
|
||||
limit: Maximum number of entries to return
|
||||
|
||||
Returns:
|
||||
List of AuditLog entries
|
||||
"""
|
||||
session = self.db.get_session()
|
||||
try:
|
||||
query = session.query(AuditLog)
|
||||
|
||||
if entity_type:
|
||||
query = query.filter_by(entity_type=entity_type)
|
||||
if entity_id is not None:
|
||||
query = query.filter_by(entity_id=entity_id)
|
||||
|
||||
return query.order_by(AuditLog.timestamp.desc()).limit(limit).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# Global audit logger
|
||||
_audit_logger: Optional[AuditLogger] = None
|
||||
|
||||
|
||||
def get_audit_logger() -> AuditLogger:
|
||||
"""Get global audit logger instance."""
|
||||
global _audit_logger
|
||||
if _audit_logger is None:
|
||||
_audit_logger = AuditLogger()
|
||||
return _audit_logger
|
||||
|
||||
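A hedged usage sketch of the audit API above; it assumes the src.* package and its database are already set up, and the action/entity values are illustrative:

from src.security.audit import get_audit_logger

audit = get_audit_logger()
audit.log(
    action="api_key_added",
    entity_type="exchange",
    entity_id=1,
    details={"read_only": True},
)
for entry in audit.get_audit_log(entity_type="exchange", limit=10):
    print(entry.timestamp, entry.action)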
108
src/security/encryption.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Encryption utilities for API keys and sensitive data."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from typing import Optional
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EncryptionManager:
|
||||
"""Manages encryption and decryption of sensitive data."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize encryption manager."""
|
||||
self.config = get_config()
|
||||
self._key = self._get_or_create_key()
|
||||
self.cipher = Fernet(self._key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""Get or create encryption key."""
|
||||
# Try to get key from keyring first
|
||||
try:
|
||||
import keyring
|
||||
key = keyring.get_password("crypto_trader", "encryption_key")
|
||||
if key:
|
||||
return key.encode()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get key from keyring: {e}")
|
||||
|
||||
# Keyring unavailable or empty - fall back to a key file on disk
|
||||
# (created with owner-only permissions below)
|
||||
key_file = self.config.config_dir / ".encryption_key"
|
||||
|
||||
if key_file.exists():
|
||||
with open(key_file, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
# Generate new key
|
||||
key = Fernet.generate_key()
|
||||
try:
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Open with owner-only permissions from the start so the key is
|
||||
# never world-readable, even briefly
|
||||
fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, 'wb') as f:
|
||||
f.write(key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create encryption key file: {e}")
|
||||
raise
|
||||
|
||||
# Try to store in keyring
|
||||
try:
|
||||
import keyring
|
||||
keyring.set_password("crypto_trader", "encryption_key", key.decode())
|
||||
except Exception:
|
||||
pass # Fallback to file if keyring unavailable
|
||||
|
||||
return key
|
||||
|
||||
def encrypt(self, data: str) -> str:
|
||||
"""Encrypt sensitive data.
|
||||
|
||||
Args:
|
||||
data: Plain text data to encrypt
|
||||
|
||||
Returns:
|
||||
Encrypted data as base64 string
|
||||
"""
|
||||
if not data:
|
||||
return ""
|
||||
encrypted = self.cipher.encrypt(data.encode())
|
||||
return base64.b64encode(encrypted).decode()
|
||||
|
||||
def decrypt(self, encrypted_data: str) -> str:
|
||||
"""Decrypt sensitive data.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64 encoded encrypted data
|
||||
|
||||
Returns:
|
||||
Decrypted plain text
|
||||
"""
|
||||
if not encrypted_data:
|
||||
return ""
|
||||
try:
|
||||
decoded = base64.b64decode(encrypted_data.encode())
|
||||
decrypted = self.cipher.decrypt(decoded)
|
||||
return decrypted.decode()
|
||||
except Exception as e:
|
||||
logger.error(f"Decryption failed: {e}")
|
||||
raise ValueError("Failed to decrypt data") from e
|
||||
|
||||
|
||||
# Global encryption manager
|
||||
_encryption_manager: Optional[EncryptionManager] = None
|
||||
|
||||
|
||||
def get_encryption_manager() -> EncryptionManager:
|
||||
"""Get global encryption manager instance."""
|
||||
global _encryption_manager
|
||||
if _encryption_manager is None:
|
||||
_encryption_manager = EncryptionManager()
|
||||
return _encryption_manager
|
||||
|
||||
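A self-contained round trip using the same Fernet-plus-base64 scheme as EncryptionManager.encrypt/decrypt, with a throwaway key instead of the keyring/file-backed one:

import base64
from cryptography.fernet import Fernet

cipher = Fernet(Fernet.generate_key())
token = base64.b64encode(cipher.encrypt(b"my-api-secret")).decode()
plain = cipher.decrypt(base64.b64decode(token)).decode()
assert plain == "my-api-secret"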
173
src/security/key_manager.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""API key management with read-only/trading modes and encryption."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from src.core.database import get_database, Exchange
|
||||
from src.core.logger import get_logger
|
||||
from .encryption import get_encryption_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
"""Manages API keys with encryption and permission modes."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize API key manager."""
|
||||
self.db = get_database()
|
||||
self.encryption = get_encryption_manager()
|
||||
|
||||
async def add_exchange(
|
||||
self,
|
||||
name: str,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
read_only: bool = True,
|
||||
sandbox: bool = False
|
||||
) -> Exchange:
|
||||
"""Add exchange with encrypted API credentials."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
# Check if exchange already exists
|
||||
stmt = select(Exchange).where(Exchange.name == name)
|
||||
result = await session.execute(stmt)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
raise ValueError(f"Exchange {name} already exists")
|
||||
|
||||
# Encrypt credentials
|
||||
encrypted_key = self.encryption.encrypt(api_key)
|
||||
encrypted_secret = self.encryption.encrypt(api_secret)
|
||||
|
||||
# Create exchange record
|
||||
exchange = Exchange(
|
||||
name=name,
|
||||
api_key_encrypted=encrypted_key,
|
||||
api_secret_encrypted=encrypted_secret,
|
||||
read_only=read_only,
|
||||
sandbox=sandbox,
|
||||
enabled=True
|
||||
)
|
||||
|
||||
session.add(exchange)
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Added exchange {name} (read_only={read_only}, sandbox={sandbox})")
|
||||
return exchange
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to add exchange {name}: {e}")
|
||||
raise
|
||||
|
||||
async def get_exchange_credentials(self, exchange_id: int) -> Dict[str, Any]:
|
||||
"""Get decrypted exchange credentials."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
exchange = await session.get(Exchange, exchange_id)
|
||||
if not exchange:
|
||||
raise ValueError(f"Exchange {exchange_id} not found")
|
||||
|
||||
if not exchange.api_key_encrypted or not exchange.api_secret_encrypted:
|
||||
# Return empty credentials for public data exchanges
|
||||
return {
|
||||
"api_key": "",
|
||||
"api_secret": "",
|
||||
"read_only": getattr(exchange, 'read_only', True),
|
||||
"sandbox": getattr(exchange, 'sandbox', False),
|
||||
}
|
||||
|
||||
return {
|
||||
"api_key": self.encryption.decrypt(exchange.api_key_encrypted),
|
||||
"api_secret": self.encryption.decrypt(exchange.api_secret_encrypted),
|
||||
"read_only": exchange.read_only,
|
||||
"sandbox": exchange.sandbox,
|
||||
}
|
||||
|
||||
async def update_exchange(
|
||||
self,
|
||||
exchange_id: int,
|
||||
api_key: Optional[str] = None,
|
||||
api_secret: Optional[str] = None,
|
||||
read_only: Optional[bool] = None,
|
||||
sandbox: Optional[bool] = None,
|
||||
enabled: Optional[bool] = None
|
||||
) -> Exchange:
|
||||
"""Update exchange configuration."""
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
exchange = await session.get(Exchange, exchange_id)
|
||||
if not exchange:
|
||||
raise ValueError(f"Exchange {exchange_id} not found")
|
||||
|
||||
if api_key is not None:
|
||||
exchange.api_key_encrypted = self.encryption.encrypt(api_key)
|
||||
if api_secret is not None:
|
||||
exchange.api_secret_encrypted = self.encryption.encrypt(api_secret)
|
||||
if read_only is not None:
|
||||
exchange.read_only = read_only
|
||||
if sandbox is not None:
|
||||
exchange.sandbox = sandbox
|
||||
if enabled is not None:
|
||||
exchange.enabled = enabled
|
||||
|
||||
await session.commit()
|
||||
logger.info(f"Updated exchange {exchange.name}")
|
||||
return exchange
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to update exchange {exchange_id}: {e}")
|
||||
raise
|
||||
|
||||
async def delete_exchange(self, exchange_id: int):
|
||||
"""Delete exchange and its credentials."""
|
||||
async with self.db.get_session() as session:
|
||||
try:
|
||||
exchange = await session.get(Exchange, exchange_id)
|
||||
if not exchange:
|
||||
raise ValueError(f"Exchange {exchange_id} not found")
|
||||
|
||||
await session.delete(exchange)
|
||||
await session.commit()
|
||||
logger.info(f"Deleted exchange {exchange.name}")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to delete exchange {exchange_id}: {e}")
|
||||
raise
|
||||
|
||||
async def list_exchanges(self) -> list[Exchange]:
|
||||
"""List all exchanges."""
|
||||
from sqlalchemy import select
|
||||
async with self.db.get_session() as session:
|
||||
result = await session.execute(select(Exchange))
|
||||
return result.scalars().all()
|
||||
|
||||
async def validate_permissions(self, exchange_id: int, requires_trading: bool = False) -> bool:
|
||||
"""Validate if exchange has required permissions."""
|
||||
async with self.db.get_session() as session:
|
||||
exchange = await session.get(Exchange, exchange_id)
|
||||
if not exchange:
|
||||
return False
|
||||
|
||||
if not exchange.enabled:
|
||||
return False
|
||||
|
||||
if requires_trading and exchange.read_only:
|
||||
logger.warning(f"Exchange {exchange.name} is read-only but trading required")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Global API key manager
|
||||
_key_manager: Optional[APIKeyManager] = None
|
||||
|
||||
|
||||
def get_key_manager() -> APIKeyManager:
|
||||
"""Get global API key manager instance."""
|
||||
global _key_manager
|
||||
if _key_manager is None:
|
||||
_key_manager = APIKeyManager()
|
||||
return _key_manager
|
||||
|
||||
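A hedged async usage sketch of the key manager above; it assumes the async database layer is configured, and the exchange name and placeholder credentials are illustrative:

import asyncio
from src.security.key_manager import get_key_manager

async def main():
    km = get_key_manager()
    exchange = await km.add_exchange(
        name="binance", api_key="...", api_secret="...",
        read_only=True, sandbox=True,
    )
    creds = await km.get_exchange_credentials(exchange.id)
    print(creds["read_only"], creds["sandbox"])

asyncio.run(main())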
45
src/strategies/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Strategies package."""
|
||||
|
||||
from .base import BaseStrategy, StrategyRegistry, get_strategy_registry, SignalType, StrategySignal
|
||||
from .technical.rsi_strategy import RSIStrategy
|
||||
from .technical.macd_strategy import MACDStrategy
|
||||
from .technical.moving_avg_strategy import MovingAverageStrategy
|
||||
from .technical.confirmed_strategy import ConfirmedStrategy
|
||||
from .technical.divergence_strategy import DivergenceStrategy
|
||||
from .technical.bollinger_mean_reversion import BollingerMeanReversionStrategy
|
||||
from .dca.dca_strategy import DCAStrategy
|
||||
from .grid.grid_strategy import GridStrategy
|
||||
from .momentum.momentum_strategy import MomentumStrategy
|
||||
from .ensemble.consensus_strategy import ConsensusStrategy
|
||||
from .technical.pairs_trading import PairsTradingStrategy
|
||||
from .technical.volatility_breakout import VolatilityBreakoutStrategy
|
||||
from .sentiment.sentiment_strategy import SentimentStrategy
|
||||
from .market_making.market_making_strategy import MarketMakingStrategy
|
||||
|
||||
# Register strategies
|
||||
registry = get_strategy_registry()
|
||||
registry.register("rsi", RSIStrategy)
|
||||
registry.register("macd", MACDStrategy)
|
||||
registry.register("moving_average", MovingAverageStrategy)
|
||||
registry.register("confirmed", ConfirmedStrategy)
|
||||
registry.register("divergence", DivergenceStrategy)
|
||||
registry.register("bollinger_mean_reversion", BollingerMeanReversionStrategy)
|
||||
registry.register("dca", DCAStrategy)
|
||||
registry.register("grid", GridStrategy)
|
||||
registry.register("momentum", MomentumStrategy)
|
||||
registry.register("consensus", ConsensusStrategy)
|
||||
registry.register("pairs_trading", PairsTradingStrategy)
|
||||
registry.register("volatility_breakout", VolatilityBreakoutStrategy)
|
||||
registry.register("sentiment", SentimentStrategy)
|
||||
registry.register("market_making", MarketMakingStrategy)
|
||||
|
||||
__all__ = [
|
||||
'BaseStrategy', 'StrategyRegistry', 'get_strategy_registry',
|
||||
'SignalType', 'StrategySignal',
|
||||
'RSIStrategy', 'MACDStrategy', 'MovingAverageStrategy',
|
||||
'ConfirmedStrategy', 'DivergenceStrategy', 'BollingerMeanReversionStrategy',
|
||||
'DCAStrategy', 'GridStrategy', 'MomentumStrategy',
|
||||
'ConsensusStrategy', 'PairsTradingStrategy', 'VolatilityBreakoutStrategy',
|
||||
'SentimentStrategy', 'MarketMakingStrategy'
|
||||
]
|
||||
|
||||
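A short sketch of the registry after the registrations above; it assumes the strategies package is importable:

from src.strategies import get_strategy_registry

registry = get_strategy_registry()
print(registry.list_available())
# ['rsi', 'macd', 'moving_average', 'confirmed', 'divergence', ...]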
450
src/strategies/base.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Base strategy class and strategy registry system."""
|
||||
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Optional, List, Any, Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from src.core.logger import get_logger
|
||||
from src.core.database import OrderSide, OrderType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SignalType(str, Enum):
|
||||
"""Trading signal types."""
|
||||
BUY = "buy"
|
||||
SELL = "sell"
|
||||
HOLD = "hold"
|
||||
CLOSE = "close"
|
||||
|
||||
|
||||
class StrategySignal:
|
||||
"""Trading signal from strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal_type: SignalType,
|
||||
symbol: str,
|
||||
strength: float = 1.0,
|
||||
price: Optional[Decimal] = None,
|
||||
quantity: Optional[Decimal] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Initialize strategy signal.
|
||||
|
||||
Args:
|
||||
signal_type: Signal type (buy, sell, hold, close)
|
||||
symbol: Trading symbol
|
||||
strength: Signal strength (0.0 to 1.0)
|
||||
price: Suggested price
|
||||
quantity: Suggested quantity
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
self.signal_type = signal_type
|
||||
self.symbol = symbol
|
||||
self.strength = strength
|
||||
self.price = price
|
||||
self.quantity = quantity
|
||||
self.metadata = metadata or {}
|
||||
self.timestamp = datetime.utcnow()
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Base class for all trading strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
timeframes: Optional[List[str]] = None
|
||||
):
|
||||
"""Initialize strategy.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
parameters: Strategy parameters
|
||||
timeframes: List of timeframes (e.g., ['1h', '15m'])
|
||||
"""
|
||||
self.name = name
|
||||
self.parameters = parameters or {}
|
||||
self.timeframes = timeframes or ['1h']
|
||||
self.enabled = False
|
||||
self.logger = get_logger(f"strategy.{name}")
|
||||
self._data_cache: Dict[str, Any] = {}
|
||||
|
||||
@abstractmethod
|
||||
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
|
||||
"""Called on each price update.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Current price
|
||||
timeframe: Timeframe of the update
|
||||
data: Additional market data
|
||||
|
||||
Returns:
|
||||
StrategySignal or None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
|
||||
"""Process and potentially modify signal.
|
||||
|
||||
Args:
|
||||
signal: Generated signal
|
||||
|
||||
Returns:
|
||||
Modified signal or None to cancel
|
||||
"""
|
||||
pass
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
balance: Decimal,
|
||||
price: Decimal,
|
||||
exchange_adapter=None
|
||||
) -> Decimal:
|
||||
"""Calculate position size for signal, accounting for fees.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
balance: Available balance
|
||||
price: Current price
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
Position size
|
||||
"""
|
||||
# Default: use 2% of balance
|
||||
risk_percent = self.parameters.get('position_size_percent', 2.0) / 100.0
|
||||
position_value = balance * Decimal(str(risk_percent))
|
||||
|
||||
# Account for fees by reserving fee amount
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Reserve ~0.4% for round-trip fees (conservative estimate)
|
||||
fee_reserve = fee_calculator.calculate_fee_reserve(
|
||||
position_value=position_value,
|
||||
exchange_adapter=exchange_adapter,
|
||||
reserve_percent=0.004 # 0.4% for round-trip
|
||||
)
|
||||
|
||||
# Adjust position value to account for fees
|
||||
adjusted_position_value = position_value - fee_reserve
|
||||
|
||||
# Calculate quantity
|
||||
if price > 0:
|
||||
quantity = adjusted_position_value / price
|
||||
return max(Decimal(0), quantity) # Ensure non-negative
|
||||
|
||||
return Decimal(0)
|
||||
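Worked numbers for the fee-adjusted sizing above (illustrative balance and price; real fees come from the fee calculator):

from decimal import Decimal

balance = Decimal('10000')
position_value = balance * Decimal('0.02')         # 2% of balance = 200
fee_reserve = position_value * Decimal('0.004')    # 0.4% round-trip reserve = 0.8
price = Decimal('100')

quantity = (position_value - fee_reserve) / price
print(quantity)                                    # 1.992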
|
||||
def should_execute(self, signal: StrategySignal) -> bool:
|
||||
"""Check if signal should be executed.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
|
||||
Returns:
|
||||
True if should execute
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Check signal strength threshold
|
||||
min_strength = self.parameters.get('min_signal_strength', 0.5)
|
||||
if signal.strength < min_strength:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def should_execute_with_fees(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
balance: Decimal,
|
||||
price: Decimal,
|
||||
exchange_adapter=None
|
||||
) -> bool:
|
||||
"""Check if signal should be executed considering fees and minimum profit threshold.
|
||||
|
||||
Args:
|
||||
signal: Trading signal
|
||||
balance: Available balance
|
||||
price: Current price
|
||||
exchange_adapter: Exchange adapter for fee calculation (optional)
|
||||
|
||||
Returns:
|
||||
True if should execute after fee consideration
|
||||
"""
|
||||
# First check basic execution criteria
|
||||
if not self.should_execute(signal):
|
||||
return False
|
||||
|
||||
# Calculate position size
|
||||
quantity = signal.quantity or self.calculate_position_size(signal, balance, price, exchange_adapter)
|
||||
|
||||
if quantity <= 0:
|
||||
return False
|
||||
|
||||
# Check minimum profit threshold
|
||||
from src.trading.fee_calculator import get_fee_calculator
|
||||
fee_calculator = get_fee_calculator()
|
||||
|
||||
# Get minimum profit multiplier from strategy parameters (default 2.0)
|
||||
min_profit_multiplier = self.parameters.get('min_profit_multiplier', 2.0)
|
||||
|
||||
min_profit_threshold = fee_calculator.get_minimum_profit_threshold(
|
||||
quantity=quantity,
|
||||
price=price,
|
||||
exchange_adapter=exchange_adapter,
|
||||
multiplier=min_profit_multiplier
|
||||
)
|
||||
|
||||
# Estimate potential profit (simplified - strategies can override)
|
||||
# For buy signals, we'd need to estimate exit price
|
||||
# For now, we'll use a basic check: if signal strength is high enough
|
||||
# Strategies should override this method for more sophisticated checks
|
||||
|
||||
# If we have a target price in signal metadata, use it
|
||||
target_price = signal.metadata.get('target_price')
|
||||
if target_price:
|
||||
target_price = Decimal(str(target_price))  # metadata may hold a float; mixing float and Decimal raises TypeError
|
||||
if signal.signal_type.value == "buy":
|
||||
potential_profit = (target_price - price) * quantity
|
||||
else: # sell
|
||||
potential_profit = (price - target_price) * quantity
|
||||
|
||||
if potential_profit < min_profit_threshold:
|
||||
self.logger.debug(
|
||||
f"Signal filtered: potential profit {potential_profit} < "
|
||||
f"minimum threshold {min_profit_threshold}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply_trend_filter(
|
||||
self,
|
||||
signal: StrategySignal,
|
||||
ohlcv_data: Any,
|
||||
adx_period: int = 14,
|
||||
min_adx: float = 25.0
|
||||
) -> Optional[StrategySignal]:
|
||||
"""Apply ADX-based trend filter to signal.
|
||||
|
||||
Filters signals based on trend strength:
|
||||
- Only allow BUY signals when ADX > threshold (strong trend)
|
||||
- Only allow SELL signals in downtrends with ADX > threshold
|
||||
- Filters out choppy/ranging markets
|
||||
|
||||
Args:
|
||||
signal: Trading signal to filter
|
||||
ohlcv_data: OHLCV DataFrame with columns: high, low, close
|
||||
adx_period: ADX calculation period (default 14)
|
||||
min_adx: Minimum ADX value for signal (default 25.0)
|
||||
|
||||
Returns:
|
||||
Filtered signal or None if filtered out
|
||||
"""
|
||||
if not self.parameters.get('use_trend_filter', False):
|
||||
return signal
|
||||
|
||||
try:
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
if ohlcv_data is None or len(ohlcv_data) < adx_period:
|
||||
# Not enough data, allow signal
|
||||
return signal
|
||||
|
||||
# Ensure we have a DataFrame
|
||||
if not isinstance(ohlcv_data, pd.DataFrame):
|
||||
return signal
|
||||
|
||||
indicators = get_indicators()
|
||||
|
||||
# Calculate ADX
|
||||
high = ohlcv_data['high']
|
||||
low = ohlcv_data['low']
|
||||
close = ohlcv_data['close']
|
||||
|
||||
adx = indicators.adx(high, low, close, period=adx_period)
|
||||
current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0
|
||||
|
||||
# Check if trend is strong enough
|
||||
if current_adx < min_adx:
|
||||
# Weak trend - filter out signal
|
||||
self.logger.debug(
|
||||
f"Trend filter: ADX {current_adx:.2f} < {min_adx}, "
|
||||
f"filtering {signal.signal_type.value} signal"
|
||||
)
|
||||
return None
|
||||
|
||||
# Additional check: for BUY signals, ensure uptrend
|
||||
# For SELL signals, ensure downtrend
|
||||
# We can use price vs moving average to determine trend direction
|
||||
if len(close) >= 20:
|
||||
sma_20 = indicators.sma(close, period=20)
|
||||
current_price = close.iloc[-1]
|
||||
sma_value = sma_20.iloc[-1] if not pd.isna(sma_20.iloc[-1]) else current_price
|
||||
|
||||
if signal.signal_type == SignalType.BUY:
|
||||
# BUY only in uptrend (price above SMA)
|
||||
if current_price < sma_value:
|
||||
self.logger.debug(
|
||||
f"Trend filter: BUY signal filtered - price below SMA "
|
||||
f"(price: {current_price}, SMA: {sma_value})"
|
||||
)
|
||||
return None
|
||||
elif signal.signal_type == SignalType.SELL:
|
||||
# SELL only in downtrend (price below SMA)
|
||||
if current_price > sma_value:
|
||||
self.logger.debug(
|
||||
f"Trend filter: SELL signal filtered - price above SMA "
|
||||
f"(price: {current_price}, SMA: {sma_value})"
|
||||
)
|
||||
return None
|
||||
|
||||
return signal
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error applying trend filter: {e}, allowing signal")
|
||||
return signal
|
||||
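A minimal re-enactment of the two trend-filter gates above with hypothetical ADX and SMA values:

current_adx, min_adx = 18.4, 25.0
price, sma_20 = 101.5, 103.2

if current_adx < min_adx:
    print("filtered: ranging market")       # first gate: trend too weak
elif price < sma_20:
    print("BUY filtered: price below SMA")  # second gate, BUY case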
|
||||
def get_required_indicators(self) -> List[str]:
|
||||
"""Get list of required indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator names
|
||||
"""
|
||||
return []
|
||||
|
||||
def validate_parameters(self) -> bool:
|
||||
"""Validate strategy parameters.
|
||||
|
||||
Returns:
|
||||
True if parameters are valid
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Get strategy state for persistence.
|
||||
|
||||
Returns:
|
||||
State dictionary
|
||||
"""
|
||||
return {
|
||||
'name': self.name,
|
||||
'parameters': self.parameters,
|
||||
'timeframes': self.timeframes,
|
||||
'enabled': self.enabled,
|
||||
}
|
||||
|
||||
def set_state(self, state: Dict[str, Any]):
|
||||
"""Restore strategy state.
|
||||
|
||||
Args:
|
||||
state: State dictionary
|
||||
"""
|
||||
self.parameters = state.get('parameters', {})
|
||||
self.timeframes = state.get('timeframes', ['1h'])
|
||||
self.enabled = state.get('enabled', False)
|
||||
|
||||
|
||||
class StrategyRegistry:
|
||||
"""Registry for managing strategies."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize strategy registry."""
|
||||
self._strategies: Dict[str, type] = {}
|
||||
self._instances: Dict[int, BaseStrategy] = {}
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def register(self, name: str, strategy_class: type):
|
||||
"""Register a strategy class.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
strategy_class: Strategy class (subclass of BaseStrategy)
|
||||
"""
|
||||
if not issubclass(strategy_class, BaseStrategy):
|
||||
raise ValueError(f"Strategy class must inherit from BaseStrategy")
|
||||
|
||||
self._strategies[name.lower()] = strategy_class
|
||||
self.logger.info(f"Registered strategy: {name}")
|
||||
|
||||
def create_instance(
|
||||
self,
|
||||
strategy_id: int,
|
||||
name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
timeframes: Optional[List[str]] = None
|
||||
) -> Optional[BaseStrategy]:
|
||||
"""Create strategy instance.
|
||||
|
||||
Args:
|
||||
strategy_id: Strategy ID from src.database
|
||||
name: Strategy name
|
||||
parameters: Strategy parameters
|
||||
timeframes: List of timeframes
|
||||
|
||||
Returns:
|
||||
Strategy instance or None
|
||||
"""
|
||||
strategy_class = self._strategies.get(name.lower())
|
||||
if not strategy_class:
|
||||
self.logger.error(f"Strategy {name} not registered")
|
||||
return None
|
||||
|
||||
try:
|
||||
instance = strategy_class(name, parameters, timeframes)
|
||||
self._instances[strategy_id] = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create strategy instance: {e}")
|
||||
return None
|
||||
|
||||
def get_instance(self, strategy_id: int) -> Optional[BaseStrategy]:
|
||||
"""Get strategy instance by ID.
|
||||
|
||||
Args:
|
||||
strategy_id: Strategy ID
|
||||
|
||||
Returns:
|
||||
Strategy instance or None
|
||||
"""
|
||||
return self._instances.get(strategy_id)
|
||||
|
||||
def list_available(self) -> List[str]:
|
||||
"""List available strategy types.
|
||||
|
||||
Returns:
|
||||
List of strategy names
|
||||
"""
|
||||
return list(self._strategies.keys())
|
||||
|
||||
def unregister(self, name: str):
|
||||
"""Unregister a strategy.
|
||||
|
||||
Args:
|
||||
name: Strategy name
|
||||
"""
|
||||
if name.lower() in self._strategies:
|
||||
del self._strategies[name.lower()]
|
||||
self.logger.info(f"Unregistered strategy: {name}")
|
||||
|
||||
|
||||
# Global strategy registry
|
||||
_registry: Optional[StrategyRegistry] = None
|
||||
|
||||
|
||||
def get_strategy_registry() -> StrategyRegistry:
|
||||
"""Get global strategy registry instance."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = StrategyRegistry()
|
||||
return _registry
|
||||
|
||||
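A hedged sketch of registering a custom strategy against the base API above; NoopStrategy is a hypothetical do-nothing subclass used only to show the flow:

from src.strategies.base import BaseStrategy, get_strategy_registry

class NoopStrategy(BaseStrategy):
    async def on_tick(self, symbol, price, timeframe, data):
        return None                      # never signals

    def on_signal(self, signal):
        return signal

registry = get_strategy_registry()
registry.register("noop", NoopStrategy)
instance = registry.create_instance(strategy_id=1, name="noop")
print(registry.list_available())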
0
src/strategies/dca/__init__.py
Normal file
90
src/strategies/dca/dca_strategy.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Dollar Cost Averaging (DCA) strategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DCAStrategy(BaseStrategy):
|
||||
"""Dollar Cost Averaging strategy - fixed amount per interval."""
|
||||
|
||||
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
|
||||
"""Initialize DCA strategy.
|
||||
|
||||
Parameters:
|
||||
amount: Fixed amount to invest per interval (default: 10)
|
||||
interval: Interval type - 'daily', 'weekly', 'monthly' (default: 'daily')
|
||||
target_allocation: Target allocation percentage (default: 10%)
|
||||
"""
|
||||
super().__init__(name, parameters, timeframes)
|
||||
self.amount = Decimal(str(self.parameters.get('amount', 10)))
|
||||
self.interval = self.parameters.get('interval', 'daily')
|
||||
self.target_allocation = Decimal(str(self.parameters.get('target_allocation', 10))) / 100
|
||||
|
||||
# Track last purchase
|
||||
self.last_purchase_time: Optional[datetime] = None
|
||||
self._calculate_interval_delta()
|
||||
|
||||
def _calculate_interval_delta(self):
|
||||
"""Calculate timedelta for interval."""
|
||||
if self.interval == 'daily':
|
||||
self.interval_delta = timedelta(days=1)
|
||||
elif self.interval == 'weekly':
|
||||
self.interval_delta = timedelta(weeks=1)
|
||||
elif self.interval == 'monthly':
|
||||
self.interval_delta = timedelta(days=30)
|
||||
else:
|
||||
self.interval_delta = timedelta(days=1)
|
||||
|
||||
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
|
||||
"""Generate DCA signal based on interval."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Check if enough time has passed since last purchase
|
||||
if self.last_purchase_time:
|
||||
time_since_last = now - self.last_purchase_time
|
||||
if time_since_last < self.interval_delta:
|
||||
return None
|
||||
|
||||
# Generate buy signal for fixed amount
|
||||
quantity = self.amount / price
|
||||
|
||||
self.last_purchase_time = now
|
||||
|
||||
return StrategySignal(
|
||||
signal_type=SignalType.BUY,
|
||||
symbol=symbol,
|
||||
strength=1.0,
|
||||
price=price,
|
||||
quantity=quantity,
|
||||
metadata={
|
||||
'amount': float(self.amount),
|
||||
'interval': self.interval,
|
||||
'strategy': 'dca'
|
||||
}
|
||||
)
|
||||
|
||||
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
|
||||
"""Process signal."""
|
||||
return signal if self.should_execute(signal) else None
|
||||
|
||||
def should_rebalance(self, current_allocation: Decimal, portfolio_value: Decimal) -> bool:
|
||||
"""Check if rebalancing is needed.
|
||||
|
||||
Args:
|
||||
current_allocation: Current allocation percentage
|
||||
portfolio_value: Total portfolio value
|
||||
|
||||
Returns:
|
||||
True if rebalancing needed
|
||||
"""
|
||||
|
||||
# Rebalance if allocation deviates by more than 5%
|
||||
deviation = abs(current_allocation - self.target_allocation)
|
||||
return deviation > Decimal('0.05')
|
||||
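A standalone walkthrough of the interval gate and fixed-amount sizing in DCAStrategy.on_tick, with illustrative dates and price:

from decimal import Decimal
from datetime import datetime, timedelta

amount, interval_delta = Decimal('10'), timedelta(days=1)
last_purchase = datetime(2024, 1, 1, 12, 0)
now = datetime(2024, 1, 2, 12, 30)

if now - last_purchase >= interval_delta:
    price = Decimal('40000')
    print(amount / price)                # 0.00025 units bought this interval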
6
src/strategies/ensemble/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Ensemble strategy package."""
|
||||
|
||||
from .consensus_strategy import ConsensusStrategy
|
||||
|
||||
__all__ = ['ConsensusStrategy']
|
||||
|
||||
244
src/strategies/ensemble/consensus_strategy.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Ensemble/consensus strategy.
|
||||
|
||||
Combines signals from multiple strategies with voting mechanism.
|
||||
Only executes when multiple strategies agree, improving signal quality.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any, List
|
||||
import pandas as pd
|
||||
from src.strategies.base import BaseStrategy, StrategySignal, SignalType, get_strategy_registry
|
||||
from src.core.logger import get_logger
|
||||
from src.autopilot.performance_tracker import get_performance_tracker
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ConsensusStrategy(BaseStrategy):
|
||||
"""Ensemble strategy that combines signals from multiple strategies.
|
||||
|
||||
This strategy aggregates signals from multiple registered strategies
|
||||
and only generates signals when a minimum consensus is reached.
|
||||
Signals are weighted by strategy performance metrics.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
|
||||
"""Initialize consensus strategy.
|
||||
|
||||
Parameters:
|
||||
strategy_names: List of strategy names to include (None = all)
|
||||
min_consensus: Minimum number of strategies that must agree (default 2)
|
||||
use_weights: Weight signals by strategy performance (default True)
|
||||
min_weight: Minimum weight for a strategy to participate (default 0.3)
|
||||
exclude_strategies: List of strategy names to exclude
|
||||
"""
|
||||
super().__init__(name, parameters, timeframes)
|
||||
|
||||
self.strategy_names = self.parameters.get('strategy_names', None)
|
||||
self.min_consensus = self.parameters.get('min_consensus', 2)
|
||||
self.use_weights = self.parameters.get('use_weights', True)
|
||||
self.min_weight = self.parameters.get('min_weight', 0.3)
|
||||
self.exclude_strategies = self.parameters.get('exclude_strategies', [])
|
||||
|
||||
self.registry = get_strategy_registry()
|
||||
self.performance_tracker = get_performance_tracker()
|
||||
self._strategy_instances: Dict[str, BaseStrategy] = {}
|
||||
self._strategy_weights: Dict[str, float] = {}
|
||||
|
||||
# Initialize strategy instances
|
||||
self._initialize_strategies()
|
||||
|
||||
# Weights come from the async _calculate_weights, which __init__ cannot
|
||||
# await; the first on_tick performs the initial calculation instead
|
||||
|
||||
def _initialize_strategies(self):
|
||||
"""Initialize instances of strategies to monitor."""
|
||||
if self.strategy_names is None:
|
||||
available = self.registry.list_available()
|
||||
# Aggressive filtering to prevent recursion
|
||||
self.strategy_names = [
|
||||
s for s in available
|
||||
if s not in self.exclude_strategies
|
||||
and "consensus" not in s.lower()
|
||||
and s.lower() != self.name.lower()
|
||||
]
|
||||
self.logger.info(f"ConsensusStrategy({self.name}) automatically selected strategies: {self.strategy_names}")
|
||||
else:
|
||||
self.logger.info(f"ConsensusStrategy({self.name}) manually selected strategies: {self.strategy_names}")
|
||||
|
||||
for strategy_name in self.strategy_names:
|
||||
try:
|
||||
strategy_class = self.registry._strategies.get(strategy_name.lower())
|
||||
if strategy_class:
|
||||
instance = strategy_class(
|
||||
name=f"{strategy_name}_consensus",
|
||||
parameters={},
|
||||
timeframes=self.timeframes
|
||||
)
|
||||
instance.enabled = True
|
||||
self._strategy_instances[strategy_name] = instance
|
||||
self.logger.debug(f"Initialized strategy for consensus: {strategy_name}")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to initialize strategy {strategy_name} for consensus: {e}")
|
||||
|
||||
async def _calculate_weights(self):
|
||||
"""Calculate weights for strategies based on performance."""
|
||||
for strategy_name in self._strategy_instances.keys():
|
||||
try:
|
||||
metrics = await self.performance_tracker.calculate_metrics(strategy_name, period_days=30)
|
||||
|
||||
# Weight based on win rate and Sharpe ratio
|
||||
win_rate = metrics.get('win_rate', 0.5)
|
||||
sharpe_ratio = max(0.0, metrics.get('sharpe_ratio', 0.0))
|
||||
|
||||
# Normalize to 0-1 range
|
||||
weight = (win_rate * 0.6) + (min(sharpe_ratio / 2.0, 1.0) * 0.4)
|
||||
|
||||
# Apply minimum weight threshold
|
||||
if weight < self.min_weight:
|
||||
weight = self.min_weight
|
||||
|
||||
self._strategy_weights[strategy_name] = weight
|
||||
self.logger.debug(f"Strategy {strategy_name} weight: {weight:.2f}")
|
||||
except Exception as e:
|
||||
# Default weight if metrics unavailable
|
||||
self._strategy_weights[strategy_name] = 0.5
|
||||
self.logger.warning(f"Could not calculate weight for {strategy_name}: {e}")
|
||||
|
||||
async def _collect_signals(
|
||||
self,
|
||||
symbol: str,
|
||||
price: Decimal,
|
||||
timeframe: str,
|
||||
data: Dict[str, Any]
|
||||
) -> List[tuple[str, StrategySignal, float]]:
|
||||
"""Collect signals from all monitored strategies.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Current price
|
||||
timeframe: Timeframe
|
||||
data: Market data
|
||||
|
||||
Returns:
|
||||
List of (strategy_name, signal, weight) tuples
|
||||
"""
|
||||
signals: List[tuple[str, StrategySignal, float]] = []
|
||||
|
||||
for strategy_name, strategy_instance in self._strategy_instances.items():
|
||||
try:
|
||||
signal = await strategy_instance.on_tick(symbol, price, timeframe, data)
|
||||
if signal and signal.signal_type != SignalType.HOLD:
|
||||
# Process signal through strategy's on_signal
|
||||
signal = strategy_instance.on_signal(signal)
|
||||
if signal:
|
||||
weight = self._strategy_weights.get(strategy_name, 0.5) if self.use_weights else 1.0
|
||||
signals.append((strategy_name, signal, weight))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self.logger.warning(f"Error getting signal from {strategy_name}: {e}\n{traceback.format_exc()}")
|
||||
|
||||
return signals
|
||||
|
||||
def _aggregate_signals(
|
||||
self,
|
||||
signals: List[tuple[str, StrategySignal, float]]
|
||||
) -> Optional[StrategySignal]:
|
||||
"""Aggregate signals and determine consensus.
|
||||
|
||||
Args:
|
||||
signals: List of (strategy_name, signal, weight) tuples
|
||||
|
||||
Returns:
|
||||
Consensus signal or None
|
||||
"""
|
||||
if not signals:
|
||||
return None
|
||||
|
||||
# Group signals by type
|
||||
buy_signals: List[tuple[str, StrategySignal, float]] = []
|
||||
sell_signals: List[tuple[str, StrategySignal, float]] = []
|
||||
|
||||
for strategy_name, signal, weight in signals:
|
||||
if signal.signal_type == SignalType.BUY:
|
||||
buy_signals.append((strategy_name, signal, weight))
|
||||
elif signal.signal_type == SignalType.SELL:
|
||||
sell_signals.append((strategy_name, signal, weight))
|
||||
|
||||
# Calculate weighted consensus scores
|
||||
buy_score = sum(weight * signal.strength for _, signal, weight in buy_signals)
|
||||
sell_score = sum(weight * signal.strength for _, signal, weight in sell_signals)
|
||||
|
||||
buy_count = len(buy_signals)
|
||||
sell_count = len(sell_signals)
|
||||
|
||||
# Determine final signal
|
||||
final_signal_type = None
|
||||
consensus_count = 0
|
||||
consensus_score = 0.0
|
||||
participating_strategies = []
|
||||
|
||||
if buy_count >= self.min_consensus and buy_score > sell_score:
|
||||
final_signal_type = SignalType.BUY
|
||||
consensus_count = buy_count
|
||||
consensus_score = buy_score
|
||||
participating_strategies = [name for name, _, _ in buy_signals]
|
||||
elif sell_count >= self.min_consensus and sell_score > buy_score:
|
||||
final_signal_type = SignalType.SELL
|
||||
consensus_count = sell_count
|
||||
consensus_score = sell_score
|
||||
participating_strategies = [name for name, _, _ in sell_signals]
|
||||
|
||||
if final_signal_type is None:
|
||||
return None
|
||||
|
||||
# Calculate final signal strength (normalized)
|
||||
max_possible_score = len(self._strategy_instances) * 1.0 # Max weight * max strength
|
||||
strength = min(1.0, consensus_score / max_possible_score) if max_possible_score > 0 else 0.5
|
||||
|
||||
# Use price from strongest signal
|
||||
strongest_signal = max(
|
||||
buy_signals if final_signal_type == SignalType.BUY else sell_signals,
|
||||
key=lambda x: x[2] * x[1].strength # weight * strength
|
||||
)[1]
|
||||
|
||||
return StrategySignal(
|
||||
signal_type=final_signal_type,
|
||||
symbol=strongest_signal.symbol,
|
||||
strength=strength,
|
||||
price=strongest_signal.price,
|
||||
metadata={
|
||||
'consensus_count': consensus_count,
|
||||
'consensus_score': consensus_score,
|
||||
'participating_strategies': participating_strategies,
|
||||
'buy_count': buy_count,
|
||||
'sell_count': sell_count,
|
||||
'buy_score': buy_score,
|
||||
'sell_score': sell_score,
|
||||
'strategy': 'consensus'
|
||||
}
|
||||
)
|
||||
|
||||
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
|
||||
"""Generate consensus signal from multiple strategies."""
|
||||
# Compute weights on the first tick, then refresh every 100 ticks
|
||||
if not hasattr(self, '_tick_count'):
|
||||
self._tick_count = 0
|
||||
self._tick_count += 1
|
||||
|
||||
if self.use_weights and (not self._strategy_weights or self._tick_count % 100 == 0):
|
||||
await self._calculate_weights()
|
||||
|
||||
# Collect signals from all strategies
|
||||
signals = await self._collect_signals(symbol, price, timeframe, data)
|
||||
|
||||
# Aggregate signals
|
||||
consensus_signal = self._aggregate_signals(signals)
|
||||
|
||||
return consensus_signal
|
||||
|
||||
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
|
||||
"""Process signal."""
|
||||
return signal if self.should_execute(signal) else None
|
||||
|
||||
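Worked numbers for the weighted voting in _aggregate_signals, with hypothetical strengths and weights:

buy_signals = [("rsi", 0.9, 0.8), ("macd", 0.7, 0.6)]   # (name, strength, weight)
sell_signals = [("grid", 1.0, 0.5)]
min_consensus = 2

buy_score = sum(w * s for _, s, w in buy_signals)       # 0.72 + 0.42 = 1.14
sell_score = sum(w * s for _, s, w in sell_signals)     # 0.50

if len(buy_signals) >= min_consensus and buy_score > sell_score:
    print("consensus BUY", buy_score)                    # consensus BUY 1.14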
0
src/strategies/grid/__init__.py
Normal file
109
src/strategies/grid/grid_strategy.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Grid trading strategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GridStrategy(BaseStrategy):
|
||||
"""Grid trading strategy - buy at lower levels, sell at higher levels."""
|
||||
|
||||
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
|
||||
"""Initialize Grid strategy.
|
||||
|
||||
Parameters:
|
||||
grid_spacing: Percentage spacing between grid levels (default: 1%)
|
||||
num_levels: Number of grid levels above and below center (default: 10)
|
||||
center_price: Center price for grid (default: current price)
|
||||
profit_target: Profit target percentage (default: 2%)
|
||||
"""
|
||||
super().__init__(name, parameters, timeframes)
|
||||
self.grid_spacing = Decimal(str(self.parameters.get('grid_spacing', 1))) / 100
|
||||
self.num_levels = self.parameters.get('num_levels', 10)
|
||||
self.center_price = Decimal(str(self.parameters.get('center_price'))) if self.parameters.get('center_price') is not None else None  # normalize in case the parameter is a float
|
||||
self.profit_target = Decimal(str(self.parameters.get('profit_target', 2))) / 100
|
||||
|
||||
# Grid levels
|
||||
self.buy_levels: List[Decimal] = []
|
||||
self.sell_levels: List[Decimal] = []
|
||||
self.positions: Dict[Decimal, Decimal] = {} # entry_price -> quantity
|
||||
|
||||
def _update_grid_levels(self, current_price: Decimal):
|
||||
"""Update grid levels based on current price."""
|
||||
if not self.center_price:
|
||||
self.center_price = current_price
|
||||
|
||||
# Calculate grid levels
|
||||
self.buy_levels = []
|
||||
self.sell_levels = []
|
||||
|
||||
for i in range(1, self.num_levels + 1):
|
||||
# Buy levels below center
|
||||
buy_price = self.center_price * (1 - self.grid_spacing * Decimal(i))
|
||||
self.buy_levels.append(buy_price)
|
||||
|
||||
# Sell levels above center
|
||||
sell_price = self.center_price * (1 + self.grid_spacing * Decimal(i))
|
||||
self.sell_levels.append(sell_price)
|
||||
|
||||
# Sort levels
|
||||
self.buy_levels.sort(reverse=True) # Highest to lowest
|
||||
self.sell_levels.sort() # Lowest to highest
|
||||
|
||||
async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
|
||||
"""Generate grid trading signals."""
|
||||
self._update_grid_levels(price)
|
||||
|
||||
# Check buy levels
|
||||
for buy_level in self.buy_levels:
|
||||
if price <= buy_level and buy_level not in self.positions:
|
||||
# Buy signal
|
||||
quantity = Decimal(str(self.parameters.get('position_size', '0.1')))  # normalize in case the parameter is a float
|
||||
return StrategySignal(
|
||||
signal_type=SignalType.BUY,
|
||||
symbol=symbol,
|
||||
strength=0.8,
|
||||
price=buy_level,
|
||||
quantity=quantity,
|
||||
metadata={
|
||||
'grid_level': float(buy_level),
|
||||
'strategy': 'grid',
|
||||
'type': 'buy'
|
||||
}
|
||||
)
|
||||
|
||||
# Check sell levels (profit taking)
|
||||
for entry_price, quantity in list(self.positions.items()):
|
||||
profit_pct = (price - entry_price) / entry_price
|
||||
if profit_pct >= self.profit_target:
|
||||
# Sell signal for profit taking
|
||||
del self.positions[entry_price]
|
||||
return StrategySignal(
|
||||
signal_type=SignalType.SELL,
|
||||
symbol=symbol,
|
||||
strength=1.0,
|
||||
price=price,
|
||||
quantity=quantity,
|
||||
metadata={
|
||||
'entry_price': float(entry_price),
|
||||
'profit_pct': float(profit_pct * 100),
|
||||
'strategy': 'grid',
|
||||
'type': 'profit_take'
|
||||
}
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
|
||||
"""Process signal and track positions."""
|
||||
if signal.signal_type == SignalType.BUY:
|
||||
# Track the position optimistically at the quoted level (assumes the limit order fills)
|
||||
self.positions[signal.price] = signal.quantity
|
||||
elif signal.signal_type == SignalType.SELL:
|
||||
# Position already removed in on_tick
|
||||
pass
|
||||
|
||||
return signal if self.should_execute(signal) else None
|
||||
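Worked grid levels as computed in _update_grid_levels, with an illustrative center of 100, 1% spacing and three levels per side:

from decimal import Decimal

center, spacing = Decimal('100'), Decimal('0.01')
buys  = [center * (1 - spacing * i) for i in range(1, 4)]   # 99, 98, 97
sells = [center * (1 + spacing * i) for i in range(1, 4)]   # 101, 102, 103
print(buys, sells)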
5
src/strategies/market_making/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Market making strategy package."""
|
||||
|
||||
from .market_making_strategy import MarketMakingStrategy
|
||||
|
||||
__all__ = ['MarketMakingStrategy']
|
||||
206
src/strategies/market_making/market_making_strategy.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Market Making Strategy - Profits from the bid-ask spread in sideways markets."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..base import BaseStrategy, StrategySignal, SignalType
|
||||
from src.data.pricing_service import get_pricing_service
|
||||
from src.data.indicators import get_indicators
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MarketMakingStrategy(BaseStrategy):
|
||||
"""
|
||||
Market Making Strategy that profits from the bid-ask spread.
|
||||
|
||||
Logic:
|
||||
1. Calculate mid-price from order book or last trade.
|
||||
2. Place limit BUY order at (mid - spread%).
|
||||
3. Place limit SELL order at (mid + spread%).
|
||||
4. Re-quote orders when price moves beyond threshold.
|
||||
5. Use inventory skew to manage risk (if holding too much, lower sell price).
|
||||
|
||||
Best suited for:
|
||||
- Sideways/ranging markets with low volatility
|
||||
- High-liquidity pairs with tight spreads
|
||||
|
||||
Risks:
|
||||
- Adverse selection (getting filled on wrong side during trends)
|
||||
- Inventory accumulation
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, parameters: Dict[str, Any], timeframes: Optional[List[str]] = None):
|
||||
super().__init__(name, parameters, timeframes)
|
||||
|
||||
# Strategy parameters
|
||||
self.spread_percent = float(parameters.get('spread_percent', 0.2)) / 100 # Convert to decimal
|
||||
self.requote_threshold = float(parameters.get('requote_threshold', 0.5)) / 100
|
||||
self.max_inventory = Decimal(str(parameters.get('max_inventory', 1.0)))
|
||||
self.inventory_skew_factor = float(parameters.get('inventory_skew_factor', 0.5))
|
||||
self.min_adx = float(parameters.get('min_adx', 20))  # ADX ceiling: quote only while ADX stays below this value ('min_adx' key kept for config compatibility)
|
||||
self.order_size_percent = float(parameters.get('order_size_percent', 5)) / 100
|
||||
|
||||
self.pricing_service = get_pricing_service()
|
||||
self.indicators = get_indicators()
|
||||
|
||||
# Track active orders and inventory
|
||||
self._current_inventory = Decimal('0')
|
||||
self._last_mid_price: Optional[Decimal] = None
|
||||
self._last_quote_time: Optional[datetime] = None
|
||||
self._active_orders: Dict[str, Any] = {} # side -> order info
|
||||
|
||||
def _calculate_skewed_spread(self, inventory: Decimal) -> tuple[float, float]:
|
||||
"""Calculate bid/ask spread with inventory skew.
|
||||
|
||||
When inventory is positive (long), we want to sell more aggressively.
|
||||
When inventory is negative (short), we want to buy more aggressively.
|
||||
|
||||
Returns:
|
||||
Tuple of (bid_spread, ask_spread) as percentages
|
||||
"""
|
||||
base_spread = self.spread_percent
|
||||
|
||||
# Calculate skew based on inventory relative to max
|
||||
if self.max_inventory > 0:
|
||||
inventory_ratio = float(inventory / self.max_inventory)
|
||||
else:
|
||||
inventory_ratio = 0.0
|
||||
|
||||
# Clamp between -1 and 1
|
||||
inventory_ratio = max(-1.0, min(1.0, inventory_ratio))
|
||||
|
||||
# Apply skew
|
||||
skew = inventory_ratio * self.inventory_skew_factor * base_spread
|
||||
|
||||
# If positive inventory, tighten ask (lower sell price), widen bid
|
||||
# If negative inventory, tighten bid (higher buy price), widen ask
|
||||
bid_spread = base_spread + skew # Positive inventory -> wider bid
|
||||
ask_spread = base_spread - skew # Positive inventory -> tighter ask
|
||||
|
||||
# Ensure spreads are positive
|
||||
bid_spread = max(0.0001, bid_spread)
|
||||
ask_spread = max(0.0001, ask_spread)
|
||||
|
||||
return bid_spread, ask_spread
|
||||
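    # Worked example (editorial illustration, not part of the commit): with the
    # defaults spread_percent=0.2% and inventory_skew_factor=0.5, being half-long
    # (inventory_ratio = 0.5) gives skew = 0.5 * 0.5 * 0.002 = 0.0005, so
    # bid_spread = 0.25% and ask_spread = 0.15% -- the quotes lean toward
    # working off the long inventory.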
    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Check for market making opportunities."""
        if not self.enabled:
            return None

        try:
            current_price = price

            # Fetch OHLCV for ADX calculation
            ohlcv = self.pricing_service.get_ohlcv(
                symbol=symbol,
                timeframe=timeframe,
                limit=30
            )

            if ohlcv and len(ohlcv) >= 14:
                df_data = {
                    'high': [float(c[2]) for c in ohlcv],
                    'low': [float(c[3]) for c in ohlcv],
                    'close': [float(c[4]) for c in ohlcv]
                }
                import pandas as pd
                df = pd.DataFrame(df_data)

                adx = self.indicators.adx(df['high'], df['low'], df['close'], period=14)
                current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0

                # Only make markets in low-trend environments
                if current_adx > self.min_adx:
                    self.logger.debug(
                        f"Market Making skipped: ADX {current_adx:.1f} > {self.min_adx} (trending market)"
                    )
                    return None
            else:
                current_adx = 0.0

            # Check if we need to requote
            should_requote = False

            if self._last_mid_price is None:
                should_requote = True
            else:
                price_change = abs(float(current_price - self._last_mid_price) / float(self._last_mid_price))
                if price_change > self.requote_threshold:
                    should_requote = True
                    self.logger.debug(f"Requote triggered: price moved {price_change:.2%}")

            if not should_requote:
                return None

            # Calculate skewed spreads based on inventory
            bid_spread, ask_spread = self._calculate_skewed_spread(self._current_inventory)

            # Calculate quote prices
            bid_price = current_price * Decimal(str(1 - bid_spread))
            ask_price = current_price * Decimal(str(1 + ask_spread))

            # Round to 2 decimal places
            bid_price = bid_price.quantize(Decimal('0.01'))
            ask_price = ask_price.quantize(Decimal('0.01'))

            self._last_mid_price = current_price
            self._last_quote_time = datetime.now()

            self.logger.info(
                f"Market Making {symbol}: Mid={current_price}, "
                f"Bid={bid_price} ({bid_spread:.3%}), Ask={ask_price} ({ask_spread:.3%}), "
                f"Inventory={self._current_inventory}, ADX={current_adx:.1f}"
            )

            # Generate a signal for placing both limit orders.
            # The execution engine will need special handling for this.
            signal = StrategySignal(
                signal_type=SignalType.HOLD,  # Special: not BUY or SELL, but both
                symbol=symbol,
                strength=1.0 - (current_adx / 100),  # Higher strength when less trending
                price=current_price,
                metadata={
                    "strategy": "market_making",
                    "order_type": "limit_pair",  # Indicates we want both bid and ask
                    "bid_price": float(bid_price),
                    "ask_price": float(ask_price),
                    "bid_spread": bid_spread,
                    "ask_spread": ask_spread,
                    "inventory": float(self._current_inventory),
                    "adx": float(current_adx),
                    "description": f"Market Making: Bid={bid_price}, Ask={ask_price}"
                }
            )
            return signal

        except Exception as e:
            self.logger.error(f"Error in MarketMakingStrategy: {e}")
            return None

    def update_inventory(self, side: str, quantity: Decimal):
        """Update inventory after a fill.

        Args:
            side: 'buy' or 'sell'
            quantity: Filled quantity
        """
        if side == 'buy':
            self._current_inventory += quantity
        elif side == 'sell':
            self._current_inventory -= quantity

        self.logger.info(f"Inventory updated: {side} {quantity}, new inventory: {self._current_inventory}")

    def get_inventory(self) -> Decimal:
        """Get current inventory position."""
        return self._current_inventory

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal (pass-through)."""
        return signal if self.should_execute(signal) else None
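A minimal sketch (not part of the commit) of how the inventory hooks above might be wired to fills reported by the execution engine; the fill variables are illustrative:

from decimal import Decimal
from src.strategies.market_making import MarketMakingStrategy

mm = MarketMakingStrategy('mm-btc', {'spread_percent': 0.2, 'max_inventory': 2.0}, ['1m'])
fill_side, fill_qty = 'buy', Decimal('0.5')  # hypothetical fill report
mm.update_inventory(fill_side, fill_qty)
assert mm.get_inventory() == Decimal('0.5')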
0
src/strategies/momentum/__init__.py
Normal file
138
src/strategies/momentum/momentum_strategy.py
Normal file
@@ -0,0 +1,138 @@
"""Momentum trading strategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any, List
|
||||
import pandas as pd
|
||||
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
|
||||
from src.data.indicators import get_indicators
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MomentumStrategy(BaseStrategy):
|
||||
"""Momentum-based trading strategy."""
|
||||
|
||||
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
|
||||
"""Initialize Momentum strategy.
|
||||
|
||||
Parameters:
|
||||
lookback_period: Lookback period for momentum calculation (default: 20)
|
||||
momentum_threshold: Minimum momentum strength to enter (default: 0.05 = 5%)
|
||||
volume_threshold: Minimum volume increase for confirmation (default: 1.5x)
|
||||
exit_threshold: Momentum reversal threshold for exit (default: -0.02 = -2%)
|
||||
"""
|
||||
super().__init__(name, parameters, timeframes)
|
||||
self.lookback_period = self.parameters.get('lookback_period', 20)
|
||||
self.momentum_threshold = Decimal(str(self.parameters.get('momentum_threshold', 0.05)))
|
||||
self.volume_threshold = self.parameters.get('volume_threshold', 1.5)
|
||||
self.exit_threshold = Decimal(str(self.parameters.get('exit_threshold', -0.02)))
|
||||
|
||||
self.indicators = get_indicators()
|
||||
self._price_history: List[float] = []
|
||||
self._volume_history: List[float] = []
|
||||
self._in_position = False
|
||||
self._entry_price: Optional[Decimal] = None
|
||||
|
||||
def _calculate_momentum(self, prices: pd.Series) -> float:
|
||||
"""Calculate price momentum.
|
||||
|
||||
Args:
|
||||
prices: Series of prices
|
||||
|
||||
Returns:
|
||||
Momentum value (percentage change)
|
||||
"""
|
||||
if len(prices) < self.lookback_period:
|
||||
return 0.0
|
||||
|
||||
recent_prices = prices[-self.lookback_period:]
|
||||
old_price = recent_prices.iloc[0]
|
||||
new_price = recent_prices.iloc[-1]
|
||||
|
||||
if old_price == 0:
|
||||
return 0.0
|
||||
|
||||
return float((new_price - old_price) / old_price)
|
||||
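    # Worked example (editorial illustration, not part of the commit): with
    # lookback_period=20, a window opening at 100.0 and closing at 106.0 yields
    # momentum = (106.0 - 100.0) / 100.0 = 0.06, clearing the default 5% entry
    # threshold.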
    def _check_volume_confirmation(self, volumes: pd.Series) -> bool:
        """Check if volume confirms momentum.

        Args:
            volumes: Series of volumes

        Returns:
            True if volume confirms
        """
        if len(volumes) < self.lookback_period:
            return False

        recent_volumes = volumes[-self.lookback_period:]
        avg_volume = recent_volumes.mean()
        current_volume = recent_volumes.iloc[-1]

        return current_volume >= avg_volume * self.volume_threshold

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate momentum-based signals."""
        # Add to history
        self._price_history.append(float(price))
        self._volume_history.append(float(data.get('volume', 0)))

        if len(self._price_history) < self.lookback_period + 1:
            return None

        prices = pd.Series(self._price_history)
        volumes = pd.Series(self._volume_history)

        # Calculate momentum
        momentum = self._calculate_momentum(prices)

        if not self._in_position:
            # Entry logic
            if momentum >= float(self.momentum_threshold):
                # Check volume confirmation
                if self._check_volume_confirmation(volumes):
                    self._in_position = True
                    self._entry_price = price

                    return StrategySignal(
                        signal_type=SignalType.BUY,
                        symbol=symbol,
                        strength=min(momentum / float(self.momentum_threshold), 1.0),
                        price=price,
                        metadata={
                            'momentum': momentum,
                            'volume_confirmed': True,
                            'strategy': 'momentum',
                            'type': 'entry'
                        }
                    )
        else:
            # Exit logic
            if momentum <= float(self.exit_threshold):
                self._in_position = False
                entry_price = self._entry_price or price

                return StrategySignal(
                    signal_type=SignalType.SELL,
                    symbol=symbol,
                    strength=1.0,
                    price=price,
                    metadata={
                        'momentum': momentum,
                        'entry_price': float(entry_price),
                        'strategy': 'momentum',
                        'type': 'exit'
                    }
                )

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        if signal.signal_type == SignalType.SELL:
            self._in_position = False
            self._entry_price = None

        return signal if self.should_execute(signal) else None
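Because on_tick is a coroutine, a standalone caller (outside the scheduler below) has to drive it from an event loop. A minimal sketch with synthetic ticks; the symbol, prices, and volumes are assumed for illustration:

import asyncio
from decimal import Decimal
from src.strategies.momentum.momentum_strategy import MomentumStrategy

async def feed():
    strategy = MomentumStrategy('momo', {'lookback_period': 20})
    for i in range(21):  # 21 ticks fill the 20-period lookback
        vol = 5000 if i == 20 else 1000  # final-bar volume spike for confirmation
        sig = await strategy.on_tick('BTC/USD', Decimal(100 + i), '1m', {'volume': vol})
        if sig:
            print(sig.signal_type, sig.metadata)  # expect a BUY on the last tick

asyncio.run(feed())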
370
src/strategies/scheduler.py
Normal file
@@ -0,0 +1,370 @@
"""Strategy scheduling system with time and condition-based triggers."""

from datetime import datetime, time
from typing import Optional, Callable, Dict, Any, List
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from src.core.logger import get_logger
from src.core.database import get_database, Strategy
from src.strategies.base import get_strategy_registry, SignalType
from src.trading.engine import get_trading_engine
from src.data.pricing_service import get_pricing_service
from src.core.database import OrderSide, OrderType

logger = get_logger(__name__)


class StrategyScheduler:
    """Schedules strategy execution based on time and conditions."""

    def __init__(self):
        """Initialize strategy scheduler."""
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        self.jobs: Dict[int, str] = {}  # strategy_id -> job_id
        self._active_strategies: Dict[int, Dict[str, Any]] = {}  # strategy_id -> status info
        self.logger = get_logger(__name__)

    def schedule_time_based(
        self,
        strategy_id: int,
        schedule_config: Dict[str, Any],
        callback: Callable
    ) -> bool:
        """Schedule strategy based on time.

        Args:
            strategy_id: Strategy ID
            schedule_config: Schedule configuration
                - type: 'daily', 'weekly', 'cron', or 'interval'
                - time: Time string (HH:MM) for daily/weekly
                - days: List of days for weekly
                - cron: Cron expression
                - interval: Interval in seconds
            callback: Function to call

        Returns:
            True if scheduling successful
        """
        try:
            # Remove existing job if any
            if strategy_id in self.jobs:
                self.unschedule(strategy_id)

            schedule_type = schedule_config.get('type', 'daily')

            if schedule_type == 'daily':
                time_str = schedule_config.get('time', '09:00')
                hour, minute = map(int, time_str.split(':'))
                trigger = CronTrigger(hour=hour, minute=minute)
            elif schedule_type == 'weekly':
                days = schedule_config.get('days', [0, 1, 2, 3, 4])  # Monday-Friday
                time_str = schedule_config.get('time', '09:00')
                hour, minute = map(int, time_str.split(':'))
                trigger = CronTrigger(day_of_week=','.join(map(str, days)), hour=hour, minute=minute)
            elif schedule_type == 'cron':
                cron_expr = schedule_config.get('cron')
                trigger = CronTrigger.from_crontab(cron_expr)
            elif schedule_type == 'interval':
                interval = schedule_config.get('interval', 60)  # seconds
                trigger = IntervalTrigger(seconds=interval)
            else:
                logger.error(f"Unknown schedule type: {schedule_type}")
                return False

            job_id = f"strategy_{strategy_id}"
            self.scheduler.add_job(
                callback,
                trigger=trigger,
                id=job_id,
                replace_existing=True
            )

            self.jobs[strategy_id] = job_id
            logger.info(f"Scheduled strategy {strategy_id} with {schedule_type} trigger")
            return True

        except Exception as e:
            logger.error(f"Failed to schedule strategy {strategy_id}: {e}")
            return False
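    # Example schedule_config values accepted above (editorial illustration):
    #   {'type': 'daily', 'time': '09:30'}
    #   {'type': 'weekly', 'days': [0, 2, 4], 'time': '14:00'}
    #   {'type': 'cron', 'cron': '*/15 * * * *'}
    #   {'type': 'interval', 'interval': 30}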
    def schedule_condition_based(
        self,
        strategy_id: int,
        condition: Callable[[], bool],
        callback: Callable,
        check_interval: int = 60
    ) -> bool:
        """Schedule strategy based on condition.

        Args:
            strategy_id: Strategy ID
            condition: Function that returns True when condition is met
            callback: Function to call when condition is met
            check_interval: Interval to check condition (seconds)

        Returns:
            True if scheduling successful
        """
        def check_and_execute():
            if condition():
                callback()

        return self.schedule_time_based(
            strategy_id,
            {'type': 'interval', 'interval': check_interval},
            check_and_execute
        )

    def unschedule(self, strategy_id: int):
        """Unschedule a strategy.

        Args:
            strategy_id: Strategy ID
        """
        if strategy_id in self.jobs:
            job_id = self.jobs[strategy_id]
            try:
                self.scheduler.remove_job(job_id)
                del self.jobs[strategy_id]
                logger.info(f"Unscheduled strategy {strategy_id}")
            except Exception as e:
                logger.error(f"Failed to unschedule strategy {strategy_id}: {e}")

    def is_scheduled(self, strategy_id: int) -> bool:
        """Check if strategy is scheduled.

        Args:
            strategy_id: Strategy ID

        Returns:
            True if scheduled
        """
        return strategy_id in self.jobs

    def start_strategy(self, strategy_id: int):
        """Start a strategy execution loop."""
        from sqlalchemy.orm import Session
        from sqlalchemy import create_engine
        from src.core.config import get_config

        config = get_config()
        db_url = config.get("database.url", "postgresql://localhost/crypto_trader")

        # Use sync engine for scheduler context
        engine = create_engine(db_url)

        with Session(engine) as session:
            try:
                strategy_model = session.query(Strategy).filter_by(id=strategy_id).first()
                if not strategy_model:
                    logger.error(f"Cannot start strategy {strategy_id}: Not found")
                    return

                # Instantiate strategy
                registry = get_strategy_registry()
                strategy_instance = registry.create_instance(
                    strategy_id=strategy_id,
                    name=strategy_model.strategy_type,
                    parameters=strategy_model.parameters,
                    timeframes=strategy_model.timeframes
                )

                if not strategy_instance:
                    logger.error(f"Failed to create instance for strategy {strategy_id}")
                    return

                strategy_instance.enabled = True

                # Store strategy info for status tracking
                self._active_strategies[strategy_id] = {
                    'instance': strategy_instance,
                    'name': strategy_model.name,
                    'type': strategy_model.strategy_type,
                    'symbol': strategy_model.parameters.get('symbol'),
                    'started_at': datetime.now(),
                    'last_tick': None,
                    'last_signal': None,
                    'signal_count': 0,
                    'error_count': 0,
                }

                # Use 'interval' from parameters, default 60s
                interval = strategy_model.parameters.get('interval', 60)

                def execute_wrapper():
                    self._execute_strategy_sync(strategy_id)

                self.schedule_time_based(
                    strategy_id,
                    {'type': 'interval', 'interval': interval},
                    execute_wrapper
                )
                logger.info(f"Started strategy {strategy_id} ({strategy_model.name})")

            except Exception as e:
                logger.error(f"Error initiating strategy {strategy_id}: {e}")

    def stop_strategy(self, strategy_id: int):
        """Stop a strategy."""
        self.unschedule(strategy_id)

        # Remove from active strategies
        if strategy_id in self._active_strategies:
            del self._active_strategies[strategy_id]

        logger.info(f"Stopped strategy {strategy_id}")

    def get_strategy_status(self, strategy_id: int) -> Optional[Dict[str, Any]]:
        """Get status of a running strategy."""
        if strategy_id not in self._active_strategies:
            return None

        info = self._active_strategies[strategy_id]
        return {
            'strategy_id': strategy_id,
            'name': info['name'],
            'type': info['type'],
            'symbol': info['symbol'],
            'running': True,
            'started_at': info['started_at'].isoformat() if info['started_at'] else None,
            'last_tick': info['last_tick'].isoformat() if info['last_tick'] else None,
            'last_signal': info['last_signal'],
            'signal_count': info['signal_count'],
            'error_count': info['error_count'],
        }

    def get_all_active_strategies(self) -> List[Dict[str, Any]]:
        """Get status of all active strategies."""
        return [self.get_strategy_status(sid) for sid in self._active_strategies.keys()]

    def _execute_strategy_sync(self, strategy_id: int):
        """Execute a single strategy cycle (synchronous wrapper)."""
        if strategy_id not in self._active_strategies:
            logger.warning(f"Strategy {strategy_id} not in active list, skipping execution")
            return

        info = self._active_strategies[strategy_id]
        strategy_instance = info['instance']

        try:
            # 1. Fetch data
            symbol = strategy_instance.parameters.get('symbol')
            timeframe = strategy_instance.timeframes[0] if strategy_instance.timeframes else '1h'

            pricing_service = get_pricing_service()
            ticker = pricing_service.get_ticker(symbol)  # Sync call
            current_price = ticker.get('last')

            if not current_price:
                logger.debug(f"No price for {symbol}, skipping")
                return

            # Update last tick
            info['last_tick'] = datetime.now()

            # 2. Run async on_tick in sync context
            import asyncio
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            signal = loop.run_until_complete(
                strategy_instance.on_tick(symbol, current_price, timeframe, ticker)
            )

            if not signal:
                logger.debug(f"Strategy {strategy_id}: No signal at price {current_price}")
                return

            # Update signal tracking
            info['last_signal'] = {
                'type': signal.signal_type.value,
                'strength': signal.strength,
                'price': float(signal.price),
                'timestamp': datetime.now().isoformat(),
                'metadata': signal.metadata,
            }
            info['signal_count'] += 1

            logger.info(f"Strategy {strategy_id} generated signal: {signal.signal_type.value} @ {current_price}")

            # 3. Handle signal (execution)
            trading_engine = get_trading_engine()

            # Check for pairs trading metadata
            secondary_symbol = signal.metadata.get('secondary_symbol')
            secondary_action = signal.metadata.get('secondary_action')

            if secondary_symbol and secondary_action:
                # Multi-leg execution for pairs trading
                logger.info(f"Executing Pairs Trade: {symbol} ({signal.signal_type.value}) & {secondary_symbol} ({secondary_action})")

                # Execute primary leg
                loop.run_until_complete(trading_engine.execute_order(
                    exchange_id=1,
                    strategy_id=strategy_id,
                    symbol=symbol,
                    side=OrderSide(signal.signal_type.value),
                    order_type=OrderType.MARKET,
                    quantity=strategy_instance.calculate_position_size(
                        signal, trading_engine.paper_trading.get_balance(), current_price
                    ),
                    paper_trading=True
                ))

                # Execute secondary leg
                sec_ticker = pricing_service.get_ticker(secondary_symbol)
                sec_price = sec_ticker.get('last')
                if sec_price:
                    loop.run_until_complete(trading_engine.execute_order(
                        exchange_id=1,
                        strategy_id=strategy_id,
                        symbol=secondary_symbol,
                        side=OrderSide(secondary_action),
                        order_type=OrderType.MARKET,
                        quantity=strategy_instance.calculate_position_size(
                            signal, trading_engine.paper_trading.get_balance(), sec_price
                        ),
                        paper_trading=True,
                    ))

            else:
                # Standard single-leg execution
                loop.run_until_complete(trading_engine.execute_order(
                    exchange_id=1,
                    strategy_id=strategy_id,
                    symbol=symbol,
                    side=OrderSide(signal.signal_type.value),
                    order_type=OrderType.MARKET,
                    quantity=signal.quantity or strategy_instance.calculate_position_size(
                        signal, trading_engine.paper_trading.get_balance(), current_price
                    ),
                    paper_trading=True
                ))

        except Exception as e:
            logger.error(f"Strategy {strategy_id} execution error: {e}")
            info['error_count'] += 1

    def shutdown(self):
        """Shutdown scheduler."""
        self.scheduler.shutdown()


# Global scheduler
_scheduler: Optional[StrategyScheduler] = None


def get_scheduler() -> StrategyScheduler:
    """Get global strategy scheduler instance."""
    global _scheduler
    if _scheduler is None:
        _scheduler = StrategyScheduler()
    return _scheduler
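A short usage sketch for the scheduler singleton; strategy id 1 is illustrative and must already exist in the strategies table:

from src.strategies.scheduler import get_scheduler

scheduler = get_scheduler()
scheduler.start_strategy(1)              # begins the interval-driven execution loop
print(scheduler.get_strategy_status(1))  # running flag, tick/signal counters, ...
scheduler.stop_strategy(1)
scheduler.shutdown()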
5
src/strategies/sentiment/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Sentiment strategy package."""

from .sentiment_strategy import SentimentStrategy

__all__ = ['SentimentStrategy']
208
src/strategies/sentiment/sentiment_strategy.py
Normal file
@@ -0,0 +1,208 @@
"""Sentiment-Driven Strategy - Trades based on news sentiment and market fear/greed."""

from decimal import Decimal
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timedelta

from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.news_collector import get_news_collector
from src.core.logger import get_logger
from src.core.config import get_config

logger = get_logger(__name__)


class SentimentStrategy(BaseStrategy):
    """
    Sentiment-Driven Strategy that trades based on news sentiment and market fear.

    Modes:
    1. Contrarian: Buy during extreme fear, sell during extreme greed
    2. Momentum: Follow positive news cycles
    3. Combo: Buy on positive news + fear (best opportunity)

    Signal Logic:
    - Aggregates sentiment from recent news headlines
    - Combines with Fear & Greed Index when available
    - Generates signals based on configured mode
    """

    def __init__(self, name: str, parameters: Dict[str, Any], timeframes: List[str] = None):
        super().__init__(name, parameters, timeframes)

        # Strategy parameters
        self.mode = parameters.get('mode', 'contrarian')  # 'contrarian', 'momentum', 'combo'
        self.min_sentiment_score = float(parameters.get('min_sentiment_score', 0.5))
        self.fear_threshold = int(parameters.get('fear_threshold', 25))  # 0-100
        self.greed_threshold = int(parameters.get('greed_threshold', 75))  # 0-100
        self.news_lookback_hours = int(parameters.get('news_lookback_hours', 24))
        self.min_headlines = int(parameters.get('min_headlines', 5))

        self.news_collector = get_news_collector()
        self.config = get_config()

        # Simple sentiment keywords
        self.positive_keywords = [
            'surge', 'rally', 'soar', 'jump', 'gain', 'bullish', 'adoption',
            'partnership', 'etf', 'approval', 'institutional', 'buy', 'long',
            'breakout', 'record', 'high', 'upgrade', 'positive', 'growth'
        ]
        self.negative_keywords = [
            'crash', 'drop', 'plunge', 'fall', 'dump', 'bearish', 'hack',
            'scandal', 'lawsuit', 'ban', 'regulation', 'sell', 'short',
            'breakdown', 'low', 'downgrade', 'negative', 'fear', 'panic'
        ]

        # Cache for fear & greed
        self._fear_greed_cache: Optional[Dict[str, Any]] = None
        self._fear_greed_timestamp: Optional[datetime] = None

    def _analyze_sentiment(self, headlines: List[str]) -> float:
        """Calculate aggregate sentiment score from headlines.

        Returns:
            Score from -1.0 (very negative) to +1.0 (very positive)
        """
        if not headlines:
            return 0.0

        total_score = 0.0

        for headline in headlines:
            headline_lower = headline.lower()
            pos_count = sum(1 for kw in self.positive_keywords if kw in headline_lower)
            neg_count = sum(1 for kw in self.negative_keywords if kw in headline_lower)

            if pos_count + neg_count > 0:
                # Normalize to -1 to 1
                total_score += (pos_count - neg_count) / (pos_count + neg_count)

        # Average across all headlines
        return total_score / len(headlines)
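    # Worked example (editorial illustration, not part of the commit): the
    # headline "Bitcoin ETF approval sparks record rally" matches 'etf',
    # 'approval', 'record', and 'rally' (4 positive, 0 negative) and scores
    # +1.0; averaged with one neutral headline scoring 0.0, the aggregate
    # sentiment is +0.5.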
    async def _fetch_fear_greed_index(self) -> Optional[int]:
        """Fetch the Fear & Greed Index from the alternative.me API.

        Returns:
            Fear & Greed value (0-100) or None if unavailable
        """
        # Check cache (valid for 1 hour)
        if (self._fear_greed_cache and self._fear_greed_timestamp and
                (datetime.now() - self._fear_greed_timestamp).total_seconds() < 3600):
            return self._fear_greed_cache.get('value')

        try:
            import aiohttp

            url = "https://api.alternative.me/fng/"
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=10) as response:
                    if response.status == 200:
                        data = await response.json()
                        if data.get('data'):
                            value = int(data['data'][0]['value'])
                            self._fear_greed_cache = {'value': value}
                            self._fear_greed_timestamp = datetime.now()
                            return value
        except Exception as e:
            self.logger.debug(f"Failed to fetch Fear & Greed Index: {e}")

        return None

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Check for sentiment-based signals."""
        if not self.enabled:
            return None

        try:
            # Extract base symbol (e.g., "BTC" from "BTC/USD")
            base_symbol = symbol.split('/')[0] if '/' in symbol else symbol

            # Fetch news headlines
            headlines = await self.news_collector.fetch_headlines(
                symbols=[base_symbol],
                max_age_hours=self.news_lookback_hours
            )

            if len(headlines) < self.min_headlines:
                self.logger.debug(f"Not enough headlines: {len(headlines)} < {self.min_headlines}")
                return None

            # Calculate sentiment score
            sentiment = self._analyze_sentiment(headlines)

            # Fetch Fear & Greed Index
            fear_greed = await self._fetch_fear_greed_index()

            self.logger.info(
                f"Sentiment Analysis {symbol}: Score={sentiment:.2f}, "
                f"Fear&Greed={fear_greed}, Headlines={len(headlines)}"
            )

            signal_type = None
            description = ""

            if self.mode == 'contrarian':
                # Contrarian: Buy on extreme fear, sell on extreme greed
                if fear_greed is not None:
                    if fear_greed < self.fear_threshold:
                        signal_type = SignalType.BUY
                        description = f"Contrarian BUY: Extreme Fear ({fear_greed})"
                    elif fear_greed > self.greed_threshold:
                        signal_type = SignalType.SELL
                        description = f"Contrarian SELL: Extreme Greed ({fear_greed})"

            elif self.mode == 'momentum':
                # Momentum: Follow positive/negative sentiment
                if sentiment > self.min_sentiment_score:
                    signal_type = SignalType.BUY
                    description = f"Momentum BUY: Positive Sentiment ({sentiment:.2f})"
                elif sentiment < -self.min_sentiment_score:
                    signal_type = SignalType.SELL
                    description = f"Momentum SELL: Negative Sentiment ({sentiment:.2f})"

            elif self.mode == 'combo':
                # Combo: Positive news + fearful market = best buying opportunity
                if fear_greed is not None:
                    if sentiment > self.min_sentiment_score and fear_greed < 40:
                        signal_type = SignalType.BUY
                        description = f"Combo BUY: Positive News + Fear ({sentiment:.2f}, FG={fear_greed})"
                    elif sentiment < -self.min_sentiment_score and fear_greed > 60:
                        signal_type = SignalType.SELL
                        description = f"Combo SELL: Negative News + Greed ({sentiment:.2f}, FG={fear_greed})"

            if signal_type:
                # Calculate strength based on sentiment and fear/greed extremity
                strength = abs(sentiment)
                if fear_greed is not None:
                    fg_extremity = abs(50 - fear_greed) / 50  # How far from neutral
                    strength = (strength + fg_extremity) / 2

                signal = StrategySignal(
                    signal_type=signal_type,
                    symbol=symbol,
                    strength=min(1.0, strength),
                    price=price,
                    metadata={
                        "strategy": "sentiment",
                        "mode": self.mode,
                        "sentiment_score": float(sentiment),
                        "fear_greed": fear_greed,
                        "headline_count": len(headlines),
                        "sample_headlines": headlines[:3],  # Include first 3 for context
                        "description": description
                    }
                )
                self.logger.info(f"Sentiment Signal: {description}")
                return signal

        except Exception as e:
            self.logger.error(f"Error in SentimentStrategy: {e}")
            return None

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal (pass-through)."""
        return signal
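The three modes are selected purely through the parameters dict; a configuration sketch with illustrative instance names:

from src.strategies.sentiment import SentimentStrategy

contrarian = SentimentStrategy('fng-contrarian', {'mode': 'contrarian', 'fear_threshold': 20})
momentum = SentimentStrategy('news-momentum', {'mode': 'momentum', 'min_sentiment_score': 0.6})
combo = SentimentStrategy('news-plus-fear', {'mode': 'combo'})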
0
src/strategies/technical/__init__.py
Normal file
227
src/strategies/technical/bollinger_mean_reversion.py
Normal file
@@ -0,0 +1,227 @@
"""Bollinger Bands mean reversion strategy.

Buys when price touches the lower band in an uptrend, sells when price
touches the upper band in a downtrend. Works well in ranging markets.
"""

from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger

logger = get_logger(__name__)


class BollingerMeanReversionStrategy(BaseStrategy):
    """Bollinger Bands mean reversion strategy.

    This strategy trades mean reversion using Bollinger Bands:
    - Buy when price touches the lower band in an uptrend
    - Sell when price touches the upper band in a downtrend
    - Uses a trend filter to avoid counter-trend trades
    """

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize Bollinger Bands mean reversion strategy.

        Parameters:
            period: Moving average period for Bollinger Bands (default 20)
            std_dev: Standard deviation multiplier (default 2.0)
            trend_filter: Enable trend filter (default True)
            trend_ma_period: Moving average period for trend detection (default 50)
            entry_threshold: How close price must be to a band (0.0-1.0, default 0.95)
            exit_threshold: Exit when price reaches the middle band (default 0.5)
        """
        super().__init__(name, parameters, timeframes)

        self.period = self.parameters.get('period', 20)
        self.std_dev = self.parameters.get('std_dev', 2.0)
        self.trend_filter = self.parameters.get('trend_filter', True)
        self.trend_ma_period = self.parameters.get('trend_ma_period', 50)
        self.entry_threshold = self.parameters.get('entry_threshold', 0.95)
        self.exit_threshold = self.parameters.get('exit_threshold', 0.5)

        self.indicators = get_indicators()
        self._price_history: List[float] = []
        self._in_position = False
        self._entry_price: Optional[Decimal] = None

    def _get_trend_direction(self, prices: pd.Series) -> Optional[str]:
        """Determine trend direction.

        Args:
            prices: Price series

        Returns:
            'up', 'down', or None
        """
        if len(prices) < self.trend_ma_period:
            return None

        sma = self.indicators.sma(prices, self.trend_ma_period)
        current_price = prices.iloc[-1]
        sma_value = sma.iloc[-1]

        # Use percentage difference to determine trend
        price_diff = (current_price - sma_value) / sma_value if sma_value > 0 else 0.0

        if price_diff > 0.01:  # 1% above SMA = uptrend
            return 'up'
        elif price_diff < -0.01:  # 1% below SMA = downtrend
            return 'down'

        return None

    def _check_band_touch(self, price: float, upper: float, middle: float, lower: float) -> Optional[str]:
        """Check if price is touching a band.

        Args:
            price: Current price
            upper: Upper Bollinger Band
            middle: Middle Bollinger Band
            lower: Lower Bollinger Band

        Returns:
            'upper' if touching upper band, 'lower' if touching lower band, None otherwise
        """
        band_width = upper - lower
        if band_width == 0:
            return None

        # Calculate position within bands (0 = lower, 1 = upper)
        position = (price - lower) / band_width

        # Check if touching lower band
        if position <= (1.0 - self.entry_threshold):
            return 'lower'

        # Check if touching upper band
        if position >= self.entry_threshold:
            return 'upper'

        return None
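    # Worked example (editorial illustration, not part of the commit): with
    # lower=95.0, upper=105.0 and entry_threshold=0.95, a price of 95.4 gives
    # position = (95.4 - 95.0) / 10.0 = 0.04 <= 0.05, i.e. a lower-band touch.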
    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on Bollinger Bands mean reversion."""
        # Add price to history
        self._price_history.append(float(price))

        # Need enough data for Bollinger Bands and trend detection
        max_period = max(self.period, self.trend_ma_period)

        if len(self._price_history) < max_period + 1:
            return None

        prices = pd.Series(self._price_history[-max_period-1:])
        current_price = float(price)

        # Calculate Bollinger Bands
        bb_data = self.indicators.bollinger_bands(prices, period=self.period, std_dev=self.std_dev)

        if len(bb_data['upper']) == 0 or pd.isna(bb_data['upper'].iloc[-1]):
            return None

        upper = bb_data['upper'].iloc[-1]
        middle = bb_data['middle'].iloc[-1]
        lower = bb_data['lower'].iloc[-1]

        # Determine trend direction if using trend filter
        trend = None
        if self.trend_filter:
            trend = self._get_trend_direction(prices)

        # Check for entry signals
        if not self._in_position:
            band_touch = self._check_band_touch(current_price, upper, middle, lower)

            if band_touch == 'lower':
                # Price touching lower band - potential buy signal
                if not self.trend_filter or trend == 'up':
                    # Only buy in an uptrend (mean reversion back up)
                    self._in_position = True
                    self._entry_price = price

                    # Calculate signal strength based on how far below the lower band
                    distance_from_lower = (lower - current_price) / lower if lower > 0 else 0.0
                    strength = min(1.0, distance_from_lower * 10)  # Scale for strength

                    return StrategySignal(
                        signal_type=SignalType.BUY,
                        symbol=symbol,
                        strength=max(0.5, strength),
                        price=price,
                        metadata={
                            'band_touch': 'lower',
                            'upper': float(upper),
                            'middle': float(middle),
                            'lower': float(lower),
                            'price_position': (current_price - lower) / (upper - lower) if (upper - lower) > 0 else 0.5,
                            'trend': trend,
                            'strategy': 'bollinger_mean_reversion',
                            'type': 'entry'
                        }
                    )

            elif band_touch == 'upper':
                # Price touching upper band - potential sell signal (for short)
                if not self.trend_filter or trend == 'down':
                    # Only sell in a downtrend (mean reversion back down).
                    # This strategy is long-only, so no short entry is taken;
                    # a short entry could be implemented here if desired.
                    pass

        else:
            # In position - check for exit
            entry_price = self._entry_price or price
            entry_float = float(entry_price)

            # Exit when price reaches the middle band (mean reversion complete)
            band_width = upper - lower
            position_in_bands = (current_price - lower) / band_width if band_width > 0 else 0.5

            # Exit conditions
            should_exit = False
            exit_reason = None

            if position_in_bands >= self.exit_threshold:
                # Price has moved back toward the middle - take profit
                should_exit = True
                exit_reason = 'target_reached'
            elif band_touch := self._check_band_touch(current_price, upper, middle, lower):
                # Price touched the opposite band - stop loss
                if band_touch == 'upper' and entry_float < middle:
                    should_exit = True
                    exit_reason = 'stop_loss'

            if should_exit:
                self._in_position = False
                profit_pct = (current_price - entry_float) / entry_float if entry_float > 0 else 0.0

                return StrategySignal(
                    signal_type=SignalType.SELL,
                    symbol=symbol,
                    strength=1.0,
                    price=price,
                    metadata={
                        'entry_price': float(entry_price),
                        'exit_price': current_price,
                        'profit_pct': profit_pct * 100,
                        'exit_reason': exit_reason,
                        'strategy': 'bollinger_mean_reversion',
                        'type': 'exit'
                    }
                )

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        if signal.signal_type == SignalType.SELL:
            self._in_position = False
            self._entry_price = None

        return signal if self.should_execute(signal) else None
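The strategy buffers prices internally, so it cannot signal before max(period, trend_ma_period) + 1 ticks have arrived. A warm-up sketch with synthetic prices (assumed values, not part of the commit):

import asyncio
from decimal import Decimal
from src.strategies.technical.bollinger_mean_reversion import BollingerMeanReversionStrategy

async def warm_up():
    strategy = BollingerMeanReversionStrategy('bb-mr', {'period': 20, 'trend_ma_period': 50})
    for p in range(100, 151):  # 51 ticks = trend_ma_period + 1
        await strategy.on_tick('ETH/USD', Decimal(p), '1h', {})

asyncio.run(warm_up())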
246
src/strategies/technical/confirmed_strategy.py
Normal file
@@ -0,0 +1,246 @@
"""Multi-indicator confirmation strategy.

This strategy requires multiple indicators (RSI, MACD, Moving Average) to align
before generating a signal, reducing false signals significantly.
"""

from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger

logger = get_logger(__name__)


class ConfirmedStrategy(BaseStrategy):
    """Multi-indicator confirmation strategy.

    Combines RSI, MACD, and Moving Average signals and only generates
    signals when a configurable number of indicators agree.
    """

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize confirmed strategy.

        Parameters:
            rsi_period: RSI period (default 14)
            rsi_oversold: RSI oversold threshold (default 30)
            rsi_overbought: RSI overbought threshold (default 70)
            macd_fast: MACD fast period (default 12)
            macd_slow: MACD slow period (default 26)
            macd_signal: MACD signal period (default 9)
            ma_fast: Fast MA period (default 10)
            ma_slow: Slow MA period (default 30)
            ma_type: MA type - 'sma' or 'ema' (default 'ema')
            min_confirmations: Minimum number of indicators that must agree (default 2)
            require_rsi: Require RSI confirmation (default True)
            require_macd: Require MACD confirmation (default True)
            require_ma: Require MA confirmation (default True)
        """
        super().__init__(name, parameters, timeframes)

        # RSI parameters
        self.rsi_period = self.parameters.get('rsi_period', 14)
        self.rsi_oversold = self.parameters.get('rsi_oversold', 30)
        self.rsi_overbought = self.parameters.get('rsi_overbought', 70)

        # MACD parameters
        self.macd_fast = self.parameters.get('macd_fast', 12)
        self.macd_slow = self.parameters.get('macd_slow', 26)
        self.macd_signal = self.parameters.get('macd_signal', 9)

        # MA parameters
        self.ma_fast = self.parameters.get('ma_fast', 10)
        self.ma_slow = self.parameters.get('ma_slow', 30)
        self.ma_type = self.parameters.get('ma_type', 'ema')

        # Confirmation parameters
        self.min_confirmations = self.parameters.get('min_confirmations', 2)
        self.require_rsi = self.parameters.get('require_rsi', True)
        self.require_macd = self.parameters.get('require_macd', True)
        self.require_ma = self.parameters.get('require_ma', True)

        self.indicators = get_indicators()
        self._price_history: List[float] = []

    def _check_rsi_signal(self, prices: pd.Series) -> Optional[SignalType]:
        """Check RSI for a signal.

        Args:
            prices: Price series

        Returns:
            SignalType if RSI indicates a signal, None otherwise
        """
        if len(prices) < self.rsi_period + 1:
            return None

        rsi = self.indicators.rsi(prices, self.rsi_period)
        if len(rsi) == 0:
            return None

        current_rsi = rsi.iloc[-1]

        if current_rsi < self.rsi_oversold:
            return SignalType.BUY
        elif current_rsi > self.rsi_overbought:
            return SignalType.SELL

        return None

    def _check_macd_signal(self, prices: pd.Series) -> Optional[SignalType]:
        """Check MACD for a signal.

        Args:
            prices: Price series

        Returns:
            SignalType if MACD indicates a signal, None otherwise
        """
        if len(prices) < self.macd_slow + self.macd_signal:
            return None

        macd_data = self.indicators.macd(prices, self.macd_fast, self.macd_slow, self.macd_signal)

        if len(macd_data['macd']) < 2:
            return None

        macd = macd_data['macd'].iloc[-1]
        signal_line = macd_data['signal'].iloc[-1]
        prev_macd = macd_data['macd'].iloc[-2]
        prev_signal = macd_data['signal'].iloc[-2]

        # Bullish crossover
        if prev_macd <= prev_signal and macd > signal_line:
            return SignalType.BUY
        # Bearish crossover
        elif prev_macd >= prev_signal and macd < signal_line:
            return SignalType.SELL

        return None

    def _check_ma_signal(self, prices: pd.Series) -> Optional[SignalType]:
        """Check Moving Average for a signal.

        Args:
            prices: Price series

        Returns:
            SignalType if MA indicates a signal, None otherwise
        """
        if len(prices) < self.ma_slow + 1:
            return None

        if self.ma_type == 'sma':
            fast_ma = self.indicators.sma(prices, self.ma_fast)
            slow_ma = self.indicators.sma(prices, self.ma_slow)
        else:
            fast_ma = self.indicators.ema(prices, self.ma_fast)
            slow_ma = self.indicators.ema(prices, self.ma_slow)

        if len(fast_ma) < 2 or len(slow_ma) < 2:
            return None

        fast_current = fast_ma.iloc[-1]
        fast_prev = fast_ma.iloc[-2]
        slow_current = slow_ma.iloc[-1]
        slow_prev = slow_ma.iloc[-2]

        # Bullish crossover
        if fast_prev <= slow_prev and fast_current > slow_current:
            return SignalType.BUY
        # Bearish crossover
        elif fast_prev >= slow_prev and fast_current < slow_current:
            return SignalType.SELL

        return None

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on multi-indicator confirmation."""
        # Add price to history
        self._price_history.append(float(price))

        # Determine minimum required data points
        max_period = max(
            self.rsi_period + 1,
            self.macd_slow + self.macd_signal,
            self.ma_slow + 1
        )

        if len(self._price_history) < max_period:
            return None

        prices = pd.Series(self._price_history[-max_period:])

        # Collect signals from each indicator
        signals: List[SignalType] = []
        signal_metadata: Dict[str, Any] = {}

        # Check RSI
        if self.require_rsi:
            rsi_signal = self._check_rsi_signal(prices)
            if rsi_signal:
                signals.append(rsi_signal)
                signal_metadata['rsi'] = rsi_signal.value

        # Check MACD
        if self.require_macd:
            macd_signal = self._check_macd_signal(prices)
            if macd_signal:
                signals.append(macd_signal)
                signal_metadata['macd'] = macd_signal.value

        # Check MA
        if self.require_ma:
            ma_signal = self._check_ma_signal(prices)
            if ma_signal:
                signals.append(ma_signal)
                signal_metadata['ma'] = ma_signal.value

        # Count confirmations for each signal type
        buy_count = signals.count(SignalType.BUY)
        sell_count = signals.count(SignalType.SELL)

        # Determine final signal based on confirmations
        final_signal = None
        confirmation_count = 0

        if buy_count >= self.min_confirmations:
            final_signal = SignalType.BUY
            confirmation_count = buy_count
        elif sell_count >= self.min_confirmations:
            final_signal = SignalType.SELL
            confirmation_count = sell_count

        if final_signal is None:
            return None

        # Calculate signal strength based on number of confirmations
        max_possible = sum([self.require_rsi, self.require_macd, self.require_ma])
        strength = min(1.0, confirmation_count / max_possible) if max_possible > 0 else 0.5

        return StrategySignal(
            signal_type=final_signal,
            symbol=symbol,
            strength=strength,
            price=price,
            metadata={
                'confirmation_count': confirmation_count,
                'max_confirmations': max_possible,
                'signals': signal_metadata,
                'strategy': 'confirmed',
                'rsi_period': self.rsi_period,
                'macd_fast': self.macd_fast,
                'macd_slow': self.macd_slow,
                'ma_type': self.ma_type,
                'ma_fast': self.ma_fast,
                'ma_slow': self.ma_slow
            }
        )

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        return signal if self.should_execute(signal) else None
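How the confirmation count plays out (illustrative): with the defaults, an oversold RSI reading (BUY) plus a fresh bullish MACD crossover gives buy_count = 2, which meets min_confirmations = 2 even if the MAs cast no vote, and strength comes out at 2/3. Requiring unanimity is just a parameter change:

from src.strategies.technical.confirmed_strategy import ConfirmedStrategy

strict = ConfirmedStrategy('confirmed-strict', {'min_confirmations': 3})  # all three indicators must agree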
154
src/strategies/technical/divergence_strategy.py
Normal file
@@ -0,0 +1,154 @@
"""Divergence detection strategy.

Detects price vs. indicator divergences (RSI/MACD), which are powerful
reversal signals with high success rates in ranging markets.
"""

from decimal import Decimal
from typing import Optional, Dict, Any, List
import pandas as pd
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators
from src.core.logger import get_logger

logger = get_logger(__name__)


class DivergenceStrategy(BaseStrategy):
    """Divergence detection strategy.

    Detects bullish and bearish divergences between price and indicators
    (RSI or MACD) to identify potential reversals.

    - Bullish divergence: Price makes a lower low, indicator makes a higher low → BUY
    - Bearish divergence: Price makes a higher high, indicator makes a lower high → SELL
    """

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize divergence strategy.

        Parameters:
            indicator_type: Type of indicator - 'rsi' or 'macd' (default 'rsi')
            rsi_period: RSI period if using RSI (default 14)
            macd_fast: MACD fast period if using MACD (default 12)
            macd_slow: MACD slow period if using MACD (default 26)
            macd_signal: MACD signal period if using MACD (default 9)
            lookback: Lookback period for finding swings (default 20)
            min_swings: Minimum number of swings to detect divergence (default 2)
            min_confidence: Minimum confidence threshold for signal (default 0.5)
        """
        super().__init__(name, parameters, timeframes)

        self.indicator_type = self.parameters.get('indicator_type', 'rsi').lower()
        self.rsi_period = self.parameters.get('rsi_period', 14)
        self.macd_fast = self.parameters.get('macd_fast', 12)
        self.macd_slow = self.parameters.get('macd_slow', 26)
        self.macd_signal = self.parameters.get('macd_signal', 9)
        self.lookback = self.parameters.get('lookback', 20)
        self.min_swings = self.parameters.get('min_swings', 2)
        self.min_confidence = self.parameters.get('min_confidence', 0.5)

        self.indicators = get_indicators()
        self._price_history: List[float] = []
        self._last_divergence_type: Optional[str] = None

    def _calculate_indicator(self, prices: pd.Series) -> Optional[pd.Series]:
        """Calculate the selected indicator.

        Args:
            prices: Price series

        Returns:
            Indicator series or None
        """
        if self.indicator_type == 'rsi':
            if len(prices) < self.rsi_period + 1:
                return None
            return self.indicators.rsi(prices, self.rsi_period)

        elif self.indicator_type == 'macd':
            if len(prices) < self.macd_slow + self.macd_signal:
                return None
            macd_data = self.indicators.macd(prices, self.macd_fast, self.macd_slow, self.macd_signal)
            return macd_data['macd']

        return None

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on divergence detection."""
        # Add price to history
        self._price_history.append(float(price))

        # Determine minimum required data points
        if self.indicator_type == 'rsi':
            min_period = self.rsi_period + self.lookback * 3
        else:
            min_period = self.macd_slow + self.macd_signal + self.lookback * 3

        if len(self._price_history) < min_period:
            return None

        prices = pd.Series(self._price_history[-min_period:])

        # Calculate indicator
        indicator = self._calculate_indicator(prices)
        if indicator is None or len(indicator) < self.lookback * 2:
            return None

        # Align prices and indicator (indicator may be shorter due to calculation)
        if len(prices) != len(indicator):
            # Take the last len(indicator) prices to match indicator length
            prices = prices.iloc[-len(indicator):]
            indicator = indicator.reset_index(drop=True) if hasattr(indicator, 'reset_index') else indicator

        # Detect divergence
        divergence_result = self.indicators.detect_divergence(
            prices=prices,
            indicator=indicator,
            lookback=self.lookback,
            min_swings=self.min_swings
        )

        divergence_type = divergence_result.get('type')
        confidence = divergence_result.get('confidence', 0.0)

        # Check if we have a valid divergence with sufficient confidence
        if divergence_type is None or confidence < self.min_confidence:
            return None

        # Only generate a signal if the divergence type changed (avoid repeated signals)
        if divergence_type == self._last_divergence_type:
            return None

        self._last_divergence_type = divergence_type

        # Map divergence type to signal type
        if divergence_type == 'bullish':
            signal_type = SignalType.BUY
        elif divergence_type == 'bearish':
            signal_type = SignalType.SELL
        else:
            return None

        return StrategySignal(
            signal_type=signal_type,
            symbol=symbol,
            strength=confidence,
            price=price,
            metadata={
                'divergence_type': divergence_type,
                'confidence': confidence,
                'indicator_type': self.indicator_type,
                'lookback': self.lookback,
                'price_swing_high': divergence_result.get('price_swing_high'),
                'price_swing_low': divergence_result.get('price_swing_low'),
                'indicator_swing_high': divergence_result.get('indicator_swing_high'),
                'indicator_swing_low': divergence_result.get('indicator_swing_low'),
                'strategy': 'divergence'
            }
        )

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        return signal if self.should_execute(signal) else None
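A concrete bullish reading (illustrative numbers): price prints swing lows of 100.0 then 97.0 (a lower low) while RSI prints swing lows of 28.0 then 34.0 (a higher low); selling momentum is fading even as price falls, so detect_divergence should report type 'bullish' and the strategy emits a BUY once confidence clears min_confidence (0.5 by default).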
72
src/strategies/technical/macd_strategy.py
Normal file
72
src/strategies/technical/macd_strategy.py
Normal file
@@ -0,0 +1,72 @@
"""MACD (Moving Average Convergence Divergence) strategy."""

import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators


class MACDStrategy(BaseStrategy):
    """MACD-based trading strategy."""

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize MACD strategy.

        Parameters:
            fast: Fast EMA period (default 12)
            slow: Slow EMA period (default 26)
            signal: Signal line period (default 9)
        """
        super().__init__(name, parameters, timeframes)
        self.fast = self.parameters.get('fast', 12)
        self.slow = self.parameters.get('slow', 26)
        self.signal = self.parameters.get('signal', 9)
        self.indicators = get_indicators()
        self._price_history = []

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on MACD crossovers."""
        # Add price to history
        self._price_history.append(float(price))
        if len(self._price_history) < self.slow + self.signal:
            return None

        # Calculate MACD over the most recent window
        prices = pd.Series(self._price_history[-self.slow - self.signal:])
        macd_data = self.indicators.macd(prices, self.fast, self.slow, self.signal)

        if len(macd_data['macd']) < 2:
            return None

        macd = macd_data['macd'].iloc[-1]
        signal_line = macd_data['signal'].iloc[-1]
        prev_macd = macd_data['macd'].iloc[-2]
        prev_signal = macd_data['signal'].iloc[-2]

        # Strength: how decisively the lines separated relative to the prior gap
        if prev_macd != prev_signal:
            strength = min(1.0, abs(macd - signal_line) / abs(prev_macd - prev_signal))
        else:
            strength = 1.0

        # Bullish crossover: MACD crosses above the signal line
        if prev_macd <= prev_signal and macd > signal_line:
            return StrategySignal(
                signal_type=SignalType.BUY,
                symbol=symbol,
                strength=strength,
                price=price,
                metadata={'macd': float(macd), 'signal': float(signal_line)}
            )
        # Bearish crossover: MACD crosses below the signal line
        elif prev_macd >= prev_signal and macd < signal_line:
            return StrategySignal(
                signal_type=SignalType.SELL,
                symbol=symbol,
                strength=strength,
                price=price,
                metadata={'macd': float(macd), 'signal': float(signal_line)}
            )

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        return signal if self.should_execute(signal) else None
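A minimal usage sketch for the class above, assuming the project imports resolve and that BaseStrategy accepts the (name, parameters, timeframes) constructor shown in this diff. The symbol and prices are synthetic:

import asyncio
from decimal import Decimal

async def demo():
    strategy = MACDStrategy("macd-demo", parameters={'fast': 12, 'slow': 26, 'signal': 9})
    prices = [100 + (i % 10) * 0.5 for i in range(60)]  # gently cycling series
    signals = []
    for px in prices:
        # on_tick returns None until slow + signal = 35 prices have accumulated
        sig = await strategy.on_tick("BTC/USDT", Decimal(str(px)), "1h", data={})
        if sig:
            signals.append(sig.signal_type)
    print(signals)

asyncio.run(demo())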
79
src/strategies/technical/moving_avg_strategy.py
Normal file
@@ -0,0 +1,79 @@
"""Moving Average Crossover strategy."""

import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators


class MovingAverageStrategy(BaseStrategy):
    """Moving average crossover strategy."""

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize moving average strategy.

        Parameters:
            fast_period: Fast MA period (default 10)
            slow_period: Slow MA period (default 30)
            ma_type: MA type - 'sma' or 'ema' (default 'ema')
        """
        super().__init__(name, parameters, timeframes)
        self.fast_period = self.parameters.get('fast_period', 10)
        self.slow_period = self.parameters.get('slow_period', 30)
        self.ma_type = self.parameters.get('ma_type', 'ema')
        self.indicators = get_indicators()
        self._price_history = []

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on MA crossover."""
        # Add price to history
        self._price_history.append(float(price))
        if len(self._price_history) < self.slow_period + 1:
            return None

        # Calculate MAs over the most recent window
        prices = pd.Series(self._price_history[-self.slow_period - 1:])

        if self.ma_type == 'sma':
            fast_ma = self.indicators.sma(prices, self.fast_period)
            slow_ma = self.indicators.sma(prices, self.slow_period)
        else:
            fast_ma = self.indicators.ema(prices, self.fast_period)
            slow_ma = self.indicators.ema(prices, self.slow_period)

        if len(fast_ma) < 2 or len(slow_ma) < 2:
            return None

        # Check for crossover
        fast_current = fast_ma.iloc[-1]
        fast_prev = fast_ma.iloc[-2]
        slow_current = slow_ma.iloc[-1]
        slow_prev = slow_ma.iloc[-2]

        # Bullish crossover: fast MA crosses above slow MA
        if fast_prev <= slow_prev and fast_current > slow_current:
            return StrategySignal(
                signal_type=SignalType.BUY,
                symbol=symbol,
                strength=min(1.0, (fast_current - slow_current) / slow_current),
                price=price,
                metadata={'fast_ma': float(fast_current), 'slow_ma': float(slow_current)}
            )
        # Bearish crossover: fast MA crosses below slow MA
        elif fast_prev >= slow_prev and fast_current < slow_current:
            return StrategySignal(
                signal_type=SignalType.SELL,
                symbol=symbol,
                strength=min(1.0, (slow_current - fast_current) / slow_current),
                price=price,
                metadata={'fast_ma': float(fast_current), 'slow_ma': float(slow_current)}
            )

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        return signal if self.should_execute(signal) else None
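The crossover test above depends only on four numbers: the previous and current values of each MA. Extracted as a pure function for illustration:

def crossed_over(fast_prev: float, slow_prev: float, fast_cur: float, slow_cur: float) -> str:
    """Return 'bullish', 'bearish', or 'none' for a fast/slow MA pair."""
    if fast_prev <= slow_prev and fast_cur > slow_cur:
        return 'bullish'
    if fast_prev >= slow_prev and fast_cur < slow_cur:
        return 'bearish'
    return 'none'

assert crossed_over(9.8, 10.0, 10.2, 10.1) == 'bullish'
assert crossed_over(10.2, 10.0, 9.9, 10.0) == 'bearish'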
146
src/strategies/technical/pairs_trading.py
Normal file
@@ -0,0 +1,146 @@
"""Statistical Arbitrage (Pairs Trading) Strategy."""

from decimal import Decimal
from typing import Dict, Any, Optional, List
import pandas as pd

from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.pricing_service import get_pricing_service
from src.core.logger import get_logger

logger = get_logger(__name__)


class PairsTradingStrategy(BaseStrategy):
    """
    Statistical arbitrage strategy that trades the spread between two correlated assets.

    Logic:
    1. Calculate Spread = Price(A) / Price(B).
    2. Calculate the Z-score of the spread over a rolling window.
    3. Mean-reversion signals:
       - Z-score > threshold: short the spread (sell A, buy B)
       - Z-score < -threshold: long the spread (buy A, sell B)
       - Z-score near 0: close/neutralize (exit both)
    """

    def __init__(self, name: str, parameters: Dict[str, Any], timeframes: Optional[List[str]] = None):
        super().__init__(name, parameters, timeframes)
        self.second_symbol = parameters.get('second_symbol')
        self.lookback_window = int(parameters.get('lookback_period', 20))
        self.z_threshold = float(parameters.get('z_score_threshold', 2.0))
        self.pricing_service = get_pricing_service()

        if not self.second_symbol:
            logger.warning(f"PairsTradingStrategy {name} initialized without 'second_symbol'")

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Check for pairs trading signals."""
        if not self.second_symbol or not self.enabled:
            return None

        # Only process on the primary symbol's tick to avoid double-processing
        # (assumes the strategy is assigned to the primary symbol in the DB).

        # 1. Fetch historical data for BOTH symbols (needed for the Z-score)
        try:
            # Fetch for primary (symbol A)
            ohlcv_a = self.pricing_service.get_ohlcv(
                symbol=symbol,
                timeframe=timeframe,
                limit=self.lookback_window + 5
            )

            # Fetch for secondary (symbol B)
            ohlcv_b = self.pricing_service.get_ohlcv(
                symbol=self.second_symbol,
                timeframe=timeframe,
                limit=self.lookback_window + 5
            )

            if not ohlcv_a or not ohlcv_b:
                return None

            # Convert to Series for pandas calculations.
            # OHLCV format: [timestamp, open, high, low, close, volume]
            closes_a = pd.Series([float(c[4]) for c in ohlcv_a])
            closes_b = pd.Series([float(c[4]) for c in ohlcv_b])

            # Ensure equal length if the fetched data differs slightly
            min_len = min(len(closes_a), len(closes_b))
            closes_a = closes_a.iloc[-min_len:]
            closes_b = closes_b.iloc[-min_len:]

            # 2. Calculate spread and Z-score (spread = A / B)
            spread = closes_a / closes_b

            # Rolling statistics
            rolling_mean = spread.rolling(window=self.lookback_window).mean()
            rolling_std = spread.rolling(window=self.lookback_window).std()

            current_spread = spread.iloc[-1]
            current_mean = rolling_mean.iloc[-1]
            current_std = rolling_std.iloc[-1]

            if pd.isna(current_std) or current_std == 0:
                return None

            z_score = (current_spread - current_mean) / current_std

            self.logger.info(
                f"Pairs {symbol}/{self.second_symbol}: Spread={current_spread:.4f}, Z-Score={z_score:.2f}"
            )

            # 3. Generate signals:
            #    Z > threshold  -> spread too high (A expensive, B cheap) -> sell A, buy B
            #    Z < -threshold -> spread too low (A cheap, B expensive)  -> buy A, sell B
            #    abs(Z) < 0.5   -> mean reverted -> close positions (optional)

            signal_type = None
            primary_side = "hold"
            secondary_side = "hold"

            if z_score > self.z_threshold:
                # Sell A, buy B
                signal_type = SignalType.SELL
                primary_side = "sell"
                secondary_side = "buy"
            elif z_score < -self.z_threshold:
                # Buy A, sell B
                signal_type = SignalType.BUY
                primary_side = "buy"
                secondary_side = "sell"

            # TODO: logic for closing when Z ~ 0 can be added here

            if signal_type:
                # Create signal for the primary symbol
                signal = StrategySignal(
                    signal_type=signal_type,
                    symbol=symbol,
                    strength=min(abs(z_score) / self.z_threshold, 1.0),  # capped at 1.0
                    price=price,
                    metadata={
                        "strategy": "pairs_trading",
                        "z_score": float(z_score),
                        "spread": float(current_spread),
                        "secondary_symbol": self.second_symbol,
                        "secondary_action": secondary_side,
                        "description": f"Pairs Arbitrage: Z-Score {z_score:.2f} triggered {primary_side} {symbol} / {secondary_side} {self.second_symbol}"
                    }
                )
                return signal

        except Exception as e:
            self.logger.error(f"Error in PairsTradingStrategy: {e}")
            return None

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal (pass-through)."""
        return signal
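The spread/z-score arithmetic from on_tick, reproduced on synthetic closes so the numbers can be checked by hand (pandas only, no project dependencies):

import pandas as pd

closes_a = pd.Series([100.0, 101.0, 102.0, 103.0, 110.0])
closes_b = pd.Series([50.0, 50.5, 51.0, 51.5, 52.0])

spread = closes_a / closes_b              # ratio spread, as in the strategy
mean = spread.rolling(window=4).mean().iloc[-1]
std = spread.rolling(window=4).std().iloc[-1]
z = (spread.iloc[-1] - mean) / std
print(f"z-score = {z:.2f}")               # ~1.50 here: A rich vs B, but below
                                          # the default 2.0 threshold, so no signal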
67
src/strategies/technical/rsi_strategy.py
Normal file
@@ -0,0 +1,67 @@
"""RSI (Relative Strength Index) strategy."""

import pandas as pd
from decimal import Decimal
from typing import Optional, Dict, Any
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.data.indicators import get_indicators


class RSIStrategy(BaseStrategy):
    """RSI-based trading strategy."""

    def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, timeframes: Optional[list] = None):
        """Initialize RSI strategy.

        Parameters:
            rsi_period: RSI period (default 14)
            oversold: Oversold threshold (default 30)
            overbought: Overbought threshold (default 70)
        """
        super().__init__(name, parameters, timeframes)
        self.rsi_period = self.parameters.get('rsi_period', 14)
        self.oversold = self.parameters.get('oversold', 30)
        self.overbought = self.parameters.get('overbought', 70)
        self.indicators = get_indicators()
        self._price_history = []

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Generate signal based on RSI."""
        # Add price to history
        self._price_history.append(float(price))
        if len(self._price_history) < self.rsi_period + 1:
            return None

        # Calculate RSI over the most recent window
        prices = pd.Series(self._price_history[-self.rsi_period - 1:])
        rsi = self.indicators.rsi(prices, self.rsi_period)

        if len(rsi) == 0:
            return None

        current_rsi = rsi.iloc[-1]

        # Generate signals: BUY when oversold, SELL when overbought
        if current_rsi < self.oversold:
            return StrategySignal(
                signal_type=SignalType.BUY,
                symbol=symbol,
                strength=1.0 - (current_rsi / self.oversold),
                price=price,
                metadata={'rsi': float(current_rsi)}
            )
        elif current_rsi > self.overbought:
            return StrategySignal(
                signal_type=SignalType.SELL,
                symbol=symbol,
                strength=(current_rsi - self.overbought) / (100 - self.overbought),
                price=price,
                metadata={'rsi': float(current_rsi)}
            )

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal."""
        return signal if self.should_execute(signal) else None
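The strength formulas above scale linearly from 0.0 at the threshold to 1.0 at the RSI extreme. Isolated for a quick check:

oversold, overbought = 30, 70

def buy_strength(rsi: float) -> float:
    return 1.0 - (rsi / oversold)               # rsi=30 -> 0.0, rsi=0 -> 1.0

def sell_strength(rsi: float) -> float:
    return (rsi - overbought) / (100 - overbought)  # rsi=70 -> 0.0, rsi=100 -> 1.0

print(buy_strength(15.0))   # 0.5: deeply oversold, moderately strong BUY
print(sell_strength(85.0))  # 0.5: strongly overbought, moderately strong SELL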
177
src/strategies/technical/volatility_breakout.py
Normal file
@@ -0,0 +1,177 @@
"""Volatility Breakout Strategy - Captures explosive moves after consolidation."""

from decimal import Decimal
from typing import Dict, Any, Optional, List
import pandas as pd

from ..base import BaseStrategy, StrategySignal, SignalType
from src.data.pricing_service import get_pricing_service
from src.data.indicators import get_indicators
from src.core.logger import get_logger

logger = get_logger(__name__)


class VolatilityBreakoutStrategy(BaseStrategy):
    """
    Volatility breakout strategy that identifies and trades breakouts from consolidation.

    Logic:
    1. Detect consolidation using Bollinger Band width (squeeze).
    2. Confirm the breakout when price exits the bands with volume confirmation.
    3. Use ADX to ensure the trend is strong enough to follow.

    Entry conditions (BUY):
    - Bollinger Band width < squeeze threshold (consolidation detected)
    - Price breaks above the upper Bollinger Band
    - Volume > 20-bar average volume (confirmation)
    - ADX > 25 (strong trend)

    Entry conditions (SELL):
    - Price breaks below the lower Bollinger Band
    - Volume > 20-bar average volume
    - ADX > 25 (strong trend)
    """

    def __init__(self, name: str, parameters: Dict[str, Any], timeframes: Optional[List[str]] = None):
        super().__init__(name, parameters, timeframes)

        # Strategy parameters
        self.bb_period = int(parameters.get('bb_period', 20))
        self.bb_std_dev = float(parameters.get('bb_std_dev', 2.0))
        self.squeeze_threshold = float(parameters.get('squeeze_threshold', 0.1))
        self.volume_multiplier = float(parameters.get('volume_multiplier', 1.5))
        self.adx_period = int(parameters.get('adx_period', 14))
        self.min_adx = float(parameters.get('min_adx', 25.0))
        self.use_volume_filter = parameters.get('use_volume_filter', True)

        self.pricing_service = get_pricing_service()
        self.indicators = get_indicators()

        # Track squeeze state
        self._in_squeeze = False

    async def on_tick(self, symbol: str, price: Decimal, timeframe: str, data: Dict[str, Any]) -> Optional[StrategySignal]:
        """Check for volatility breakout signals."""
        if not self.enabled:
            return None

        try:
            # Fetch OHLCV data
            ohlcv = self.pricing_service.get_ohlcv(
                symbol=symbol,
                timeframe=timeframe,
                limit=max(self.bb_period, self.adx_period) + 30
            )

            if not ohlcv or len(ohlcv) < self.bb_period + 10:
                return None

            # Convert to DataFrame
            df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])

            close = df['close'].astype(float)
            high = df['high'].astype(float)
            low = df['low'].astype(float)
            volume = df['volume'].astype(float)

            # Calculate Bollinger Bands
            bb_upper, bb_middle, bb_lower = self.indicators.bollinger_bands(
                close, period=self.bb_period, std_dev=self.bb_std_dev
            )

            # Bollinger Band width (squeeze indicator)
            bb_width = (bb_upper - bb_lower) / bb_middle
            current_width = bb_width.iloc[-1]

            # ADX for trend strength
            adx = self.indicators.adx(high, low, close, period=self.adx_period)
            current_adx = adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 0.0

            # Volume metrics
            avg_volume = volume.rolling(window=20).mean().iloc[-1]
            current_volume = volume.iloc[-1]
            volume_ratio = current_volume / avg_volume if avg_volume > 0 else 0

            current_price = float(price)
            upper_band = bb_upper.iloc[-1]
            lower_band = bb_lower.iloc[-1]

            # Check for squeeze (consolidation)
            is_squeeze = current_width < self.squeeze_threshold
            was_in_squeeze = self._in_squeeze
            self._in_squeeze = is_squeeze

            self.logger.debug(
                f"Volatility Breakout {symbol}: Width={current_width:.4f}, "
                f"ADX={current_adx:.1f}, Vol Ratio={volume_ratio:.2f}, "
                f"Squeeze={is_squeeze}"
            )

            # No signal unless we are breaking out of a squeeze
            if not was_in_squeeze:
                return None

            # Volume filter
            if self.use_volume_filter and volume_ratio < self.volume_multiplier:
                self.logger.debug(f"Volume filter: {volume_ratio:.2f} < {self.volume_multiplier}")
                return None

            # ADX filter - ensure the trend is strong
            if current_adx < self.min_adx:
                self.logger.debug(f"ADX filter: {current_adx:.1f} < {self.min_adx}")
                return None

            # Check for breakout
            signal_type = None

            # Bullish breakout - price breaks above the upper band
            if current_price > upper_band:
                signal_type = SignalType.BUY
                self.logger.info(
                    f"BULLISH BREAKOUT: {symbol} @ {current_price:.2f} > Upper Band {upper_band:.2f}"
                )
            # Bearish breakout - price breaks below the lower band
            elif current_price < lower_band:
                signal_type = SignalType.SELL
                self.logger.info(
                    f"BEARISH BREAKOUT: {symbol} @ {current_price:.2f} < Lower Band {lower_band:.2f}"
                )

            if signal_type:
                # Signal strength blends trend strength, volume, and a breakout base
                strength = min(1.0, (
                    (current_adx / 50) * 0.4 +      # ADX contribution
                    (volume_ratio / 3) * 0.4 +      # volume contribution
                    0.2                             # base strength for the breakout
                ))

                signal = StrategySignal(
                    signal_type=signal_type,
                    symbol=symbol,
                    strength=strength,
                    price=price,
                    metadata={
                        "strategy": "volatility_breakout",
                        "bb_width": float(current_width),
                        "adx": float(current_adx),
                        "volume_ratio": float(volume_ratio),
                        "upper_band": float(upper_band),
                        "lower_band": float(lower_band),
                        "was_squeeze": was_in_squeeze,
                        "description": f"Volatility Breakout: ADX={current_adx:.1f}, Vol={volume_ratio:.1f}x"
                    }
                )
                return signal

        except Exception as e:
            self.logger.error(f"Error in VolatilityBreakoutStrategy: {e}")
            return None

        return None

    def on_signal(self, signal: StrategySignal) -> Optional[StrategySignal]:
        """Process signal (pass-through)."""
        return signal
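The squeeze gauge is just the normalized band width. The strategy delegates to get_indicators().bollinger_bands(); this sketch re-derives the same quantities with pandas alone so the threshold comparison is concrete:

import pandas as pd

# Twenty flat-ish closes: a consolidation range around 100
close = pd.Series([100, 100.2, 99.9, 100.1, 100.0, 100.3, 99.8, 100.1,
                   100.2, 99.9, 100.0, 100.1, 99.9, 100.2, 100.0, 100.1,
                   99.8, 100.0, 100.2, 100.1], dtype=float)

period, std_dev = 20, 2.0
mid = close.rolling(period).mean()
sd = close.rolling(period).std()
upper, lower = mid + std_dev * sd, mid - std_dev * sd

width = (upper - lower) / mid               # normalized band width
print(f"BB width = {width.iloc[-1]:.4f}")   # well under the 0.1 default => squeeze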
103
src/strategies/timeframe_manager.py
Normal file
@@ -0,0 +1,103 @@
"""Multi-timeframe strategy framework with synchronization."""

from typing import Dict, List, Optional, Any
from datetime import datetime
from src.core.logger import get_logger

logger = get_logger(__name__)


class TimeframeManager:
    """Manages multiple timeframes for strategies."""

    def __init__(self, timeframes: List[str]):
        """Initialize timeframe manager.

        Args:
            timeframes: List of timeframes (e.g., ['1h', '15m'])
        """
        # Sort from the highest timeframe down to the lowest
        self.timeframes = sorted(timeframes, key=self._timeframe_to_seconds, reverse=True)
        self.data: Dict[str, Dict[str, Any]] = {tf: {} for tf in self.timeframes}
        self.last_update: Dict[str, datetime] = {tf: datetime.min for tf in self.timeframes}

    def _timeframe_to_seconds(self, tf: str) -> int:
        """Convert a timeframe string to seconds for sorting.

        Args:
            tf: Timeframe string (e.g., '1h', '15m')

        Returns:
            Duration in seconds (0 for unrecognized suffixes)
        """
        if tf.endswith('m'):
            return int(tf[:-1]) * 60
        elif tf.endswith('h'):
            return int(tf[:-1]) * 3600
        elif tf.endswith('d'):
            return int(tf[:-1]) * 86400
        return 0

    def update(self, timeframe: str, symbol: str, data: Dict[str, Any]):
        """Update data for a timeframe.

        Args:
            timeframe: Timeframe
            symbol: Trading symbol
            data: Market data
        """
        if timeframe not in self.timeframes:
            logger.warning(f"Unknown timeframe: {timeframe}")
            return

        if symbol not in self.data[timeframe]:
            self.data[timeframe][symbol] = {}

        self.data[timeframe][symbol].update(data)
        self.last_update[timeframe] = datetime.utcnow()

    def get_data(self, timeframe: str, symbol: str) -> Optional[Dict[str, Any]]:
        """Get data for a timeframe.

        Args:
            timeframe: Timeframe
            symbol: Trading symbol

        Returns:
            Data dictionary or None
        """
        return self.data.get(timeframe, {}).get(symbol)

    def get_all_timeframes(self, symbol: str) -> Dict[str, Dict[str, Any]]:
        """Get data for all timeframes.

        Args:
            symbol: Trading symbol

        Returns:
            Dictionary of timeframe -> data
        """
        return {
            tf: self.data[tf].get(symbol, {})
            for tf in self.timeframes
        }

    def is_synchronized(self, symbol: str, max_age_seconds: int = 300) -> bool:
        """Check if all timeframes are synchronized (recently updated).

        Args:
            symbol: Trading symbol
            max_age_seconds: Maximum age in seconds

        Returns:
            True if all timeframes are synchronized
        """
        now = datetime.utcnow()
        for tf in self.timeframes:
            if symbol not in self.data[tf]:
                return False
            age = (now - self.last_update[tf]).total_seconds()
            if age > max_age_seconds:
                return False
        return True
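A short usage sketch of the manager above, with hypothetical tick data:

tfm = TimeframeManager(['1h', '15m'])

tfm.update('1h', 'BTC/USDT', {'close': 65000.0})
tfm.update('15m', 'BTC/USDT', {'close': 64980.0})

if tfm.is_synchronized('BTC/USDT', max_age_seconds=300):
    snapshot = tfm.get_all_timeframes('BTC/USDT')
    # Higher timeframe first (the constructor sorts descending by duration):
    # {'1h': {'close': 65000.0}, '15m': {'close': 64980.0}}
    print(snapshot)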
0
src/trading/__init__.py
Normal file
409
src/trading/advanced_orders.py
Normal file
@@ -0,0 +1,409 @@
"""Advanced order types: take-profit, trailing stop, OCO, iceberg orders."""

from decimal import Decimal
from typing import Optional, Dict, List
from src.core.database import OrderType, OrderSide
from src.core.logger import get_logger
from .order_manager import get_order_manager

logger = get_logger(__name__)


class TakeProfitOrder:
    """Take-profit order - automatically sells when price reaches the target."""

    def __init__(
        self,
        base_order_id: int,
        target_price: Decimal,
        quantity: Optional[Decimal] = None
    ):
        """Initialize take-profit order.

        Args:
            base_order_id: ID of the base position/order
            target_price: Target price to trigger take-profit
            quantity: Quantity to sell (None = all)
        """
        self.base_order_id = base_order_id
        self.target_price = target_price
        self.quantity = quantity
        self.triggered = False

    def check_trigger(self, current_price: Decimal) -> bool:
        """Check if the take-profit should trigger.

        Args:
            current_price: Current market price

        Returns:
            True if it should trigger
        """
        if self.triggered:
            return False

        # Trigger if price reaches or exceeds the target
        if current_price >= self.target_price:
            self.triggered = True
            return True
        return False


class TrailingStopOrder:
    """Trailing stop-loss order - adjusts the stop price as price moves favorably."""

    def __init__(
        self,
        base_order_id: int,
        initial_stop_price: Decimal,
        trail_percent: Decimal,
        quantity: Optional[Decimal] = None
    ):
        """Initialize trailing stop order.

        Args:
            base_order_id: ID of the base position/order
            initial_stop_price: Initial stop price
            trail_percent: Percentage to trail (e.g., 0.02 for 2%)
            quantity: Quantity to sell (None = all)
        """
        self.base_order_id = base_order_id
        self.current_stop_price = initial_stop_price
        self.trail_percent = trail_percent
        self.quantity = quantity
        self.triggered = False
        # Tracks the most favorable price seen (highest for longs, lowest for shorts)
        self.highest_price = initial_stop_price

    def update(self, current_price: Decimal, is_long: bool = True):
        """Update the trailing stop based on the current price.

        Args:
            current_price: Current market price
            is_long: True for a long position, False for short
        """
        if self.triggered:
            return

        if is_long:
            # For long positions, trail upward
            if current_price > self.highest_price:
                self.highest_price = current_price
                # Adjust stop to trail below the highest price
                self.current_stop_price = current_price * (1 - self.trail_percent)
        else:
            # For short positions, trail downward
            if current_price < self.highest_price or self.highest_price == self.current_stop_price:
                self.highest_price = current_price
                # Adjust stop to trail above the lowest price
                self.current_stop_price = current_price * (1 + self.trail_percent)

    def check_trigger(self, current_price: Decimal, is_long: bool = True) -> bool:
        """Check if the trailing stop should trigger.

        Args:
            current_price: Current market price
            is_long: True for a long position, False for short

        Returns:
            True if it should trigger
        """
        if self.triggered:
            return False

        if is_long:
            # Trigger if price falls to or below the stop
            if current_price <= self.current_stop_price:
                self.triggered = True
                return True
        else:
            # Trigger if price rises to or above the stop
            if current_price >= self.current_stop_price:
                self.triggered = True
                return True

        return False


class OCOOrder:
    """One-Cancels-Other order - two orders where filling one cancels the other."""

    def __init__(
        self,
        order1_id: int,
        order2_id: int
    ):
        """Initialize OCO order.

        Args:
            order1_id: First order ID
            order2_id: Second order ID
        """
        self.order1_id = order1_id
        self.order2_id = order2_id
        self.executed = False

    async def on_order_filled(self, filled_order_id: int):
        """Handle one order being filled.

        Args:
            filled_order_id: ID of the filled order
        """
        if self.executed:
            return

        order_manager = get_order_manager()

        # Cancel the other order
        if filled_order_id == self.order1_id:
            await order_manager.cancel_order(self.order2_id)
        elif filled_order_id == self.order2_id:
            await order_manager.cancel_order(self.order1_id)

        self.executed = True
        logger.info(f"OCO order executed: order {filled_order_id} filled, other cancelled")


class IcebergOrder:
    """Iceberg order - a large order split into smaller visible parts."""

    def __init__(
        self,
        total_quantity: Decimal,
        visible_quantity: Decimal,
        symbol: str,
        side: OrderSide,
        price: Optional[Decimal] = None
    ):
        """Initialize iceberg order.

        Args:
            total_quantity: Total quantity to execute
            visible_quantity: Visible quantity per order
            symbol: Trading symbol
            side: Buy or sell
            price: Limit price (None for market)
        """
        self.total_quantity = total_quantity
        self.visible_quantity = visible_quantity
        self.symbol = symbol
        self.side = side
        self.price = price
        self.remaining_quantity = total_quantity
        self.orders: List[int] = []
        self.completed = False

    async def create_next_order(self, exchange_id: int) -> Optional[int]:
        """Create the next visible order.

        Args:
            exchange_id: Exchange ID

        Returns:
            Order ID, or None if completed
        """
        if self.completed or self.remaining_quantity <= 0:
            return None

        order_manager = get_order_manager()
        order_type = OrderType.LIMIT if self.price else OrderType.MARKET

        quantity = min(self.visible_quantity, self.remaining_quantity)

        order = await order_manager.create_order(
            exchange_id=exchange_id,
            strategy_id=None,
            symbol=self.symbol,
            order_type=order_type,
            side=self.side,
            quantity=quantity,
            price=self.price
        )

        self.orders.append(order.id)
        self.remaining_quantity -= quantity

        if self.remaining_quantity <= 0:
            self.completed = True

        return order.id

    async def on_order_filled(self, order_id: int):
        """Handle an order being filled.

        Args:
            order_id: Filled order ID
        """
        if order_id in self.orders:
            # Create the next order if quantity remains
            if not self.completed and self.remaining_quantity > 0:
                # This would need the exchange_id - it would have to be stored on the instance
                logger.info(f"Iceberg order {order_id} filled, {self.remaining_quantity} remaining")


class AdvancedOrderManager:
    """Manages advanced order types."""

    def __init__(self):
        """Initialize advanced order manager."""
        self.take_profit_orders: Dict[int, TakeProfitOrder] = {}
        self.trailing_stops: Dict[int, TrailingStopOrder] = {}
        self.oco_orders: Dict[int, OCOOrder] = {}
        self.iceberg_orders: Dict[int, IcebergOrder] = {}

    def create_take_profit(
        self,
        base_order_id: int,
        target_price: Decimal,
        quantity: Optional[Decimal] = None
    ) -> TakeProfitOrder:
        """Create a take-profit order.

        Args:
            base_order_id: Base order/position ID
            target_price: Target price
            quantity: Quantity (None = all)

        Returns:
            TakeProfitOrder instance
        """
        tp_order = TakeProfitOrder(base_order_id, target_price, quantity)
        self.take_profit_orders[base_order_id] = tp_order
        logger.info(f"Created take-profit for order {base_order_id} at {target_price}")
        return tp_order

    def create_trailing_stop(
        self,
        base_order_id: int,
        initial_stop_price: Decimal,
        trail_percent: Decimal,
        quantity: Optional[Decimal] = None
    ) -> TrailingStopOrder:
        """Create a trailing stop order.

        Args:
            base_order_id: Base order/position ID
            initial_stop_price: Initial stop price
            trail_percent: Trail percentage
            quantity: Quantity (None = all)

        Returns:
            TrailingStopOrder instance
        """
        trailing = TrailingStopOrder(base_order_id, initial_stop_price, trail_percent, quantity)
        self.trailing_stops[base_order_id] = trailing
        logger.info(f"Created trailing stop for order {base_order_id}")
        return trailing

    def create_oco(
        self,
        order1_id: int,
        order2_id: int
    ) -> OCOOrder:
        """Create an OCO order.

        Args:
            order1_id: First order ID
            order2_id: Second order ID

        Returns:
            OCOOrder instance
        """
        oco = OCOOrder(order1_id, order2_id)
        self.oco_orders[order1_id] = oco
        self.oco_orders[order2_id] = oco
        logger.info(f"Created OCO order: {order1_id} <-> {order2_id}")
        return oco

    def create_iceberg(
        self,
        total_quantity: Decimal,
        visible_quantity: Decimal,
        symbol: str,
        side: OrderSide,
        price: Optional[Decimal] = None
    ) -> IcebergOrder:
        """Create an iceberg order.

        Args:
            total_quantity: Total quantity
            visible_quantity: Visible quantity per order
            symbol: Trading symbol
            side: Buy or sell
            price: Limit price (None for market)

        Returns:
            IcebergOrder instance
        """
        iceberg = IcebergOrder(total_quantity, visible_quantity, symbol, side, price)
        # Store by a unique ID (could use a counter or hash)
        iceberg_id = id(iceberg)
        self.iceberg_orders[iceberg_id] = iceberg
        logger.info(f"Created iceberg order: {total_quantity} {symbol} ({visible_quantity} visible)")
        return iceberg

    def update_trailing_stops(self, prices: Dict[int, Decimal], is_long: Dict[int, bool]):
        """Update all trailing stops with current prices.

        Args:
            prices: Dictionary of order_id -> current_price
            is_long: Dictionary of order_id -> is_long_position
        """
        for order_id, trailing in self.trailing_stops.items():
            if order_id in prices:
                trailing.update(prices[order_id], is_long.get(order_id, True))

    def check_triggers(self, prices: Dict[int, Decimal], is_long: Dict[int, bool]) -> List[int]:
        """Check all triggers and return triggered order IDs.

        Args:
            prices: Dictionary of order_id -> current_price
            is_long: Dictionary of order_id -> is_long_position

        Returns:
            List of triggered order IDs
        """
        triggered = []

        # Check take-profit orders
        for order_id, tp in self.take_profit_orders.items():
            if order_id in prices and tp.check_trigger(prices[order_id]):
                triggered.append(order_id)

        # Check trailing stops
        for order_id, trailing in self.trailing_stops.items():
            if order_id in prices and trailing.check_trigger(
                prices[order_id],
                is_long.get(order_id, True)
            ):
                triggered.append(order_id)

        return triggered

    async def on_order_filled(self, order_id: int):
        """Handle order fills for advanced orders.

        Args:
            order_id: Filled order ID
        """
        # Check OCO orders
        if order_id in self.oco_orders:
            await self.oco_orders[order_id].on_order_filled(order_id)

        # Check iceberg orders
        for iceberg_id, iceberg in self.iceberg_orders.items():
            if order_id in iceberg.orders:
                await iceberg.on_order_filled(order_id)


# Global advanced order manager
_advanced_order_manager: Optional[AdvancedOrderManager] = None


def get_advanced_order_manager() -> AdvancedOrderManager:
    """Get the global advanced order manager instance."""
    global _advanced_order_manager
    if _advanced_order_manager is None:
        _advanced_order_manager = AdvancedOrderManager()
    return _advanced_order_manager
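A worked example of the trailing-stop ratchet for a long position, using only TrailingStopOrder from this module (the prices are synthetic):

from decimal import Decimal

ts = TrailingStopOrder(base_order_id=1,
                       initial_stop_price=Decimal('95'),
                       trail_percent=Decimal('0.02'))

for px in map(Decimal, ('100', '105', '110', '107.5')):
    ts.update(px, is_long=True)
    if ts.check_trigger(px, is_long=True):
        print(f"stop hit at {px}, stop was {ts.current_stop_price}")
        break
# The stop ratchets up: 98.0 after 100, 102.9 after 105, 107.8 after 110;
# then 107.5 <= 107.8 triggers the exit.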
245
src/trading/engine.py
Normal file
@@ -0,0 +1,245 @@
"""Main trading engine with order execution and position management."""

from decimal import Decimal
from typing import Optional, Dict, Any, List
from src.core.database import Order, OrderStatus, OrderSide, OrderType, get_database
from src.core.logger import get_logger
from src.core.repositories import OrderRepository
from src.exchanges import get_exchange
from .order_manager import get_order_manager
from .paper_trading import get_paper_trading
from .fee_calculator import get_fee_calculator
from src.risk.manager import get_risk_manager

logger = get_logger(__name__)


class TradingEngine:
    """Main trading engine orchestrator."""

    def __init__(self):
        """Initialize trading engine."""
        self.db = get_database()
        self.order_manager = get_order_manager()
        self.paper_trading = get_paper_trading()
        self.risk_manager = get_risk_manager()
        self.logger = get_logger(__name__)
        self._exchanges: Dict[int, Any] = {}  # exchange_id -> adapter

    async def get_exchange_adapter(self, exchange_id: int):
        """Get or create an exchange adapter.

        Args:
            exchange_id: Exchange ID

        Returns:
            Exchange adapter or None
        """
        if exchange_id not in self._exchanges:
            # get_exchange is awaited because resolving the adapter may do I/O;
            # the returned adapter exposes async methods (connect, get_ticker, ...).
            adapter = await get_exchange(exchange_id)
            if adapter:
                self._exchanges[exchange_id] = adapter
                if not adapter._connected:
                    await adapter.connect()
        return self._exchanges.get(exchange_id)

    async def execute_order(
        self,
        exchange_id: int,
        strategy_id: Optional[int],
        symbol: str,
        side: OrderSide,
        order_type: OrderType,
        quantity: Decimal,
        price: Optional[Decimal] = None,
        paper_trading: bool = True
    ) -> Optional[Order]:
        """Execute a trading order."""
        try:
            # Get exchange adapter
            adapter = await self.get_exchange_adapter(exchange_id)
            if not adapter:
                self.logger.error(f"Exchange {exchange_id} not available")
                return None

            # Get current balance/price for risk checks
            balance = self.paper_trading.get_balance() if paper_trading else Decimal(0)
            ticker = await adapter.get_ticker(symbol)
            current_price = ticker.get('last', price or Decimal(0))

            # Risk checks (includes fee validation)
            allowed, reason = await self.risk_manager.check_order_risk(
                symbol, side.value, quantity, current_price, balance, adapter
            )
            if not allowed:
                self.logger.warning(f"Order rejected: {reason}")
                return None

            # Persist the order via OrderRepository so session handling stays in one place
            async with self.db.get_session() as session:
                repo = OrderRepository(session)
                order = Order(
                    exchange_id=exchange_id,
                    strategy_id=strategy_id,
                    symbol=symbol,
                    order_type=order_type,
                    side=side,
                    quantity=quantity,
                    price=price,
                    paper_trading=paper_trading
                )
                order = await repo.create(order)

            if paper_trading:
                # Execute in paper trading
                fill_price = current_price  # use current price for market orders
                if order_type == OrderType.LIMIT and price:
                    fill_price = price

                is_maker = (order_type == OrderType.LIMIT)

                # Use the paper trading fee calculation (configured fee_exchange)
                fee_calculator = get_fee_calculator()
                fee = fee_calculator.calculate_fee_for_paper_trading(
                    quantity=quantity,
                    price=fill_price,
                    order_type=order_type,
                    is_maker=is_maker
                )

                if await self.paper_trading.execute_order(order, fill_price, fee):
                    self.logger.info(f"Paper trading order {order.id} executed with fee: {fee}")
                    return order
                else:
                    self.logger.warning(f"Paper trading order {order.id} rejected (insufficient funds or no position)")
                    async with self.db.get_session() as session:
                        repo = OrderRepository(session)
                        await repo.update_status(order.id, OrderStatus.REJECTED)
                    return None
            else:
                # Execute live order
                from src.exchanges.base import Order as ExchangeOrder
                exchange_order = ExchangeOrder(
                    symbol=symbol,
                    side=side.value,
                    order_type=order_type.value,
                    quantity=quantity,
                    price=price
                )

                result = await adapter.place_order(exchange_order)
                if result.get('id'):
                    actual_fee = adapter.extract_fee_from_order_response(result)

                    if actual_fee is None:
                        # Exchange did not report the fee; estimate it
                        fee_calculator = get_fee_calculator()
                        is_maker = (order_type == OrderType.LIMIT)
                        estimated_price = price or current_price
                        actual_fee = fee_calculator.calculate_fee(
                            quantity=quantity,
                            price=estimated_price,
                            order_type=order_type,
                            exchange_adapter=adapter,
                            is_maker=is_maker
                        )

                    async with self.db.get_session() as session:
                        repo = OrderRepository(session)
                        await repo.update_status(
                            order.id,
                            OrderStatus.OPEN,
                            exchange_order_id=result['id'],
                            fee=actual_fee
                        )
                    return order
                else:
                    async with self.db.get_session() as session:
                        repo = OrderRepository(session)
                        await repo.update_status(order.id, OrderStatus.REJECTED)
                    return None

        except Exception as e:
            self.logger.error(f"Failed to execute order: {e}")
            return None

    async def cancel_order(self, order_id: int) -> bool:
        """Cancel an order."""
        async with self.db.get_session() as session:
            repo = OrderRepository(session)
            order = await repo.get_by_id(order_id)

            if not order:
                return False

            # Orders already in a final state cannot be cancelled
            if order.status in [OrderStatus.FILLED, OrderStatus.CANCELLED, OrderStatus.REJECTED, OrderStatus.EXPIRED]:
                self.logger.warning(f"Order {order_id} is already in state {order.status}, cannot cancel")
                return False

            # Paper trading orders only need a status update in the DB
            if order.paper_trading:
                await repo.update_status(order_id, OrderStatus.CANCELLED)
                self.logger.info(f"Paper trading order {order_id} cancelled")
                return True

            # Live orders are cancelled on the exchange first
            adapter = await self.get_exchange_adapter(order.exchange_id)
            if not adapter:
                self.logger.error(f"Exchange adapter not found for exchange_id {order.exchange_id}")
                return False

            try:
                if await adapter.cancel_order(order.exchange_order_id or str(order.id), order.symbol):
                    # Update status in the DB
                    await repo.update_status(order_id, OrderStatus.CANCELLED)
                    self.logger.info(f"Live order {order_id} cancelled on exchange")
                    return True
                else:
                    self.logger.error(f"Failed to cancel order {order_id} on exchange")
            except Exception as e:
                self.logger.error(f"Error cancelling order {order_id} on exchange: {e}")

            return False

    async def get_positions(self, exchange_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get open positions."""
        if exchange_id:
            adapter = await self.get_exchange_adapter(exchange_id)
            if adapter:
                return await adapter.get_positions()

        # Fall back to paper trading positions
        positions = self.paper_trading.get_positions()
        return [
            {
                'symbol': pos.symbol,
                'side': pos.side,
                'quantity': pos.quantity,
                'entry_price': pos.entry_price,
                'current_price': pos.current_price,
                'unrealized_pnl': pos.unrealized_pnl,
            }
            for pos in positions
        ]


# Global trading engine
_trading_engine: Optional[TradingEngine] = None


def get_trading_engine() -> TradingEngine:
    """Get the global trading engine instance."""
    global _trading_engine
    if _trading_engine is None:
        _trading_engine = TradingEngine()
    return _trading_engine
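A minimal caller sketch for the engine above. It assumes an exchange with ID 1 exists in the database and that the surrounding config (paper trading balance, risk limits) is in place; exchange_id=1 and the symbol are illustrative:

import asyncio
from decimal import Decimal

async def main():
    engine = get_trading_engine()
    order = await engine.execute_order(
        exchange_id=1,              # hypothetical configured exchange
        strategy_id=None,
        symbol='BTC/USDT',
        side=OrderSide.BUY,
        order_type=OrderType.MARKET,
        quantity=Decimal('0.01'),
        paper_trading=True,
    )
    if order:
        print(f"order {order.id} accepted")

asyncio.run(main())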
278
src/trading/fee_calculator.py
Normal file
@@ -0,0 +1,278 @@
"""Centralized fee calculation service for trading operations."""

from decimal import Decimal
from typing import Dict, Optional, Any
from src.core.logger import get_logger
from src.core.config import get_config
from src.exchanges.base import BaseExchangeAdapter
from src.core.database import OrderType

logger = get_logger(__name__)


class FeeCalculator:
    """Centralized fee calculation service."""

    def __init__(self):
        """Initialize fee calculator."""
        self.config = get_config()
        self.logger = get_logger(__name__)

    def calculate_fee(
        self,
        quantity: Decimal,
        price: Decimal,
        order_type: OrderType,
        exchange_adapter: Optional[BaseExchangeAdapter] = None,
        is_maker: Optional[bool] = None
    ) -> Decimal:
        """Calculate the trading fee for an order.

        Args:
            quantity: Trade quantity
            price: Trade price
            order_type: Order type (MARKET or LIMIT)
            exchange_adapter: Exchange adapter for the fee structure (optional)
            is_maker: Explicit maker/taker flag (if None, determined from order_type)

        Returns:
            Trading fee amount
        """
        if quantity <= 0 or price <= 0:
            return Decimal(0)

        # Determine maker/taker if not explicitly provided
        if is_maker is None:
            is_maker = self._is_maker_order(order_type)

        # Get fee structure
        fee_structure = self._get_fee_structure(exchange_adapter)
        fee_rate = fee_structure['maker'] if is_maker else fee_structure['taker']

        # Calculate fee
        trade_value = quantity * price
        fee = trade_value * Decimal(str(fee_rate))

        # Apply the minimum fee if configured
        min_fee = fee_structure.get('minimum', Decimal(0))
        if min_fee > 0 and fee < min_fee:
            fee = min_fee

        return fee

    def estimate_round_trip_fee(
        self,
        quantity: Decimal,
        price: Decimal,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> Decimal:
        """Estimate total fees for a round-trip trade (buy + sell).

        Args:
            quantity: Trade quantity
            price: Trade price
            exchange_adapter: Exchange adapter for the fee structure (optional)

        Returns:
            Total estimated round-trip fee
        """
        # Estimate using taker fees (worst case)
        buy_fee = self.calculate_fee(quantity, price, OrderType.MARKET, exchange_adapter, is_maker=False)
        sell_fee = self.calculate_fee(quantity, price, OrderType.MARKET, exchange_adapter, is_maker=False)

        return buy_fee + sell_fee

    def get_minimum_profit_threshold(
        self,
        quantity: Decimal,
        price: Decimal,
        exchange_adapter: Optional[BaseExchangeAdapter] = None,
        multiplier: float = 2.0
    ) -> Decimal:
        """Calculate the minimum profit needed to clear fees with a safety margin.

        Args:
            quantity: Trade quantity
            price: Trade price
            exchange_adapter: Exchange adapter for the fee structure (optional)
            multiplier: Multiplier for the minimum profit (default 2.0 = 2x fees)

        Returns:
            Minimum profit threshold
        """
        round_trip_fee = self.estimate_round_trip_fee(quantity, price, exchange_adapter)
        return round_trip_fee * Decimal(str(multiplier))

    def calculate_fee_reserve(
        self,
        position_value: Decimal,
        exchange_adapter: Optional[BaseExchangeAdapter] = None,
        reserve_percent: Optional[float] = None
    ) -> Decimal:
        """Calculate the fee reserve amount for position sizing.

        Args:
            position_value: Intended position value
            exchange_adapter: Exchange adapter for the fee structure (optional)
            reserve_percent: Override reserve percentage (default: 0.4% for a round trip)

        Returns:
            Fee reserve amount
        """
        if reserve_percent is None:
            # Default: 0.4% for a round trip (conservative estimate)
            reserve_percent = 0.004

        return position_value * Decimal(str(reserve_percent))

    def _is_maker_order(self, order_type: OrderType) -> bool:
        """Determine whether an order type is typically a maker order.

        Args:
            order_type: Order type

        Returns:
            True if maker order, False if taker order
        """
        # Limit orders that add liquidity = maker
        # Market orders that take liquidity = taker
        return order_type == OrderType.LIMIT

    def _get_fee_structure(
        self,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> Dict[str, Any]:
        """Get the fee structure from the exchange or config defaults.

        Args:
            exchange_adapter: Exchange adapter (optional)

        Returns:
            Fee structure dictionary with 'maker', 'taker', and optional 'minimum'
        """
        # Try the exchange adapter first
        if exchange_adapter:
            try:
                fee_structure = exchange_adapter.get_fee_structure()
                if fee_structure:
                    return fee_structure
            except Exception as e:
                self.logger.warning(f"Failed to get fee structure from exchange: {e}")

        # Fall back to config defaults
        default_fees = self.config.get("trading.default_fees", {
            "maker": 0.001,  # 0.1%
            "taker": 0.001,  # 0.1%
            "minimum": 0.0
        })

        # Check for exchange-specific fees if an adapter was provided
        if exchange_adapter:
            exchange_name = exchange_adapter.name.lower()
            exchange_fees = self.config.get(f"trading.exchanges.{exchange_name}.fees")
            if exchange_fees:
                default_fees.update(exchange_fees)

        return default_fees

    def get_fee_structure_by_exchange_name(
        self,
        exchange_name: str
    ) -> Dict[str, Any]:
        """Get the fee structure for a specific exchange by name (for paper trading).

        Args:
            exchange_name: Exchange name (e.g., 'coinbase', 'kraken', 'binance')

        Returns:
            Fee structure dictionary with 'maker', 'taker', and optional 'minimum'
        """
        # Get default fees
        default_fees = self.config.get("trading.default_fees", {
            "maker": 0.001,  # 0.1%
            "taker": 0.001,  # 0.1%
            "minimum": 0.0
        })

        # Check for exchange-specific fees
        exchange_fees = self.config.get(f"trading.exchanges.{exchange_name.lower()}.fees")
        if exchange_fees:
            return {**default_fees, **exchange_fees}

        return default_fees

    def calculate_fee_for_paper_trading(
        self,
        quantity: Decimal,
        price: Decimal,
        order_type: OrderType,
        is_maker: Optional[bool] = None
    ) -> Decimal:
        """Calculate the trading fee for paper trading using the configured exchange.

        Args:
            quantity: Trade quantity
            price: Trade price
            order_type: Order type (MARKET or LIMIT)
            is_maker: Explicit maker/taker flag (if None, determined from order_type)

        Returns:
            Trading fee amount
        """
        if quantity <= 0 or price <= 0:
            return Decimal(0)

        # Determine maker/taker if not explicitly provided
        if is_maker is None:
            is_maker = self._is_maker_order(order_type)

        # Get the fee exchange from config
        fee_exchange = self.config.get("paper_trading.fee_exchange", "coinbase")
        fee_structure = self.get_fee_structure_by_exchange_name(fee_exchange)
        fee_rate = fee_structure['maker'] if is_maker else fee_structure['taker']

        # Calculate fee
        trade_value = quantity * price
        fee = trade_value * Decimal(str(fee_rate))

        # Apply the minimum fee if configured
        min_fee = fee_structure.get('minimum', Decimal(0))
        if min_fee > 0 and fee < min_fee:
            fee = min_fee

        return fee

    def get_fee_percentage(
        self,
        order_type: OrderType,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> float:
        """Get the fee percentage for an order type.

        Args:
            order_type: Order type
            exchange_adapter: Exchange adapter (optional)

        Returns:
            Fee percentage (e.g., 0.001 for 0.1%)
        """
        is_maker = self._is_maker_order(order_type)
        fee_structure = self._get_fee_structure(exchange_adapter)
        return fee_structure['maker'] if is_maker else fee_structure['taker']


# Global fee calculator instance
_fee_calculator: Optional[FeeCalculator] = None


def get_fee_calculator() -> FeeCalculator:
    """Get the global fee calculator instance.

    Returns:
        FeeCalculator instance
    """
    global _fee_calculator
    if _fee_calculator is None:
        _fee_calculator = FeeCalculator()
    return _fee_calculator
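With the default structure above (0.1% maker/taker, no minimum), the fee arithmetic works out as follows; the sketch assumes no exchange adapter is passed, so the config defaults apply:

from decimal import Decimal

calc = get_fee_calculator()
qty, px = Decimal('0.5'), Decimal('60000')

fee = calc.calculate_fee(qty, px, OrderType.MARKET)      # taker side
# trade value = 30000, fee = 30000 * 0.001 = 30
round_trip = calc.estimate_round_trip_fee(qty, px)       # 2 taker legs = 60
threshold = calc.get_minimum_profit_threshold(qty, px)   # 2x fees = 120
print(fee, round_trip, threshold)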
122
src/trading/futures.py
Normal file
@@ -0,0 +1,122 @@
"""Futures and leverage trading support with margin calculations."""

from decimal import Decimal
from typing import Optional
from src.core.logger import get_logger

logger = get_logger(__name__)


class FuturesManager:
    """Manages futures and leverage trading."""

    def __init__(self):
        """Initialize futures manager."""
        self.logger = get_logger(__name__)

    def calculate_margin(
        self,
        quantity: Decimal,
        price: Decimal,
        leverage: int,
        margin_type: str = "isolated"
    ) -> Decimal:
        """Calculate the required margin.

        Args:
            quantity: Position quantity
            price: Entry price
            leverage: Leverage multiplier
            margin_type: "isolated" or "cross"

        Returns:
            Required margin
        """
        position_value = quantity * price
        margin = position_value / Decimal(leverage)
        return margin

    def calculate_liquidation_price(
        self,
        entry_price: Decimal,
        leverage: int,
        side: str,  # "long" or "short"
        maintenance_margin: Decimal = Decimal("0.01")  # 1%
    ) -> Decimal:
        """Calculate the liquidation price.

        Args:
            entry_price: Entry price
            leverage: Leverage multiplier
            side: Position side
            maintenance_margin: Maintenance margin rate

        Returns:
            Liquidation price
        """
        # Keep everything in Decimal: mixing the float result of (1 / leverage)
        # with a Decimal would raise a TypeError.
        inverse_leverage = Decimal(1) / Decimal(leverage)
        if side == "long":
            # For longs: liquidation when price drops too far
            liquidation = entry_price * (Decimal(1) - inverse_leverage + maintenance_margin)
        else:
            # For shorts: liquidation when price rises too far
            liquidation = entry_price * (Decimal(1) + inverse_leverage - maintenance_margin)

        return liquidation

    def calculate_funding_rate(
        self,
        mark_price: Decimal,
        index_price: Decimal
    ) -> Decimal:
        """Calculate the funding rate for perpetual futures.

        Args:
            mark_price: Mark price
            index_price: Index price

        Returns:
            Funding rate (8-hour rate)
        """
        premium = (mark_price - index_price) / index_price
        funding_rate = premium * Decimal("0.01")  # simplified model
        return funding_rate

    def calculate_unrealized_pnl(
        self,
        entry_price: Decimal,
        current_price: Decimal,
        quantity: Decimal,
        side: str,
        leverage: int = 1
    ) -> Decimal:
        """Calculate the unrealized P&L for a futures position.

        Args:
            entry_price: Entry price
            current_price: Current mark price
            quantity: Position quantity
            side: "long" or "short"
            leverage: Leverage multiplier

        Returns:
            Unrealized P&L (price move times quantity, scaled by leverage)
        """
        if side == "long":
            pnl = (current_price - entry_price) * quantity * leverage
        else:
            pnl = (entry_price - current_price) * quantity * leverage

        return pnl


# Global futures manager
_futures_manager: Optional[FuturesManager] = None


def get_futures_manager() -> FuturesManager:
    """Get the global futures manager instance."""
    global _futures_manager
    if _futures_manager is None:
        _futures_manager = FuturesManager()
    return _futures_manager
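A worked example of the liquidation formula above, for a 10x long entered at 100 with the default 1% maintenance margin:

from decimal import Decimal

fm = get_futures_manager()
liq = fm.calculate_liquidation_price(
    entry_price=Decimal('100'),
    leverage=10,
    side='long',
    maintenance_margin=Decimal('0.01'),
)
# 100 * (1 - 1/10 + 0.01) = 91: a 9% adverse move exhausts the margin buffer.
print(liq)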
Some files were not shown because too many files have changed in this diff.