Local changes: updated model training, removed debug instrumentation, and improved configuration
27
src/data/__init__.py
Normal file
@@ -0,0 +1,27 @@
"""Data collection and storage module.

Provides:
- DataCollector: Real-time market data collection
- NewsCollector: Crypto news headline aggregation for sentiment analysis
- TechnicalIndicators: Technical analysis indicators
- DataStorage: Data persistence utilities
- DataQualityManager: Data quality checks
"""

from .collector import DataCollector
from .news_collector import NewsCollector, NewsItem, NewsSource, get_news_collector
from .indicators import TechnicalIndicators, get_indicators
from .storage import DataStorage
from .quality import DataQualityManager

__all__ = [
    "DataCollector",
    "NewsCollector",
    "NewsItem",
    "NewsSource",
    "get_news_collector",
    "TechnicalIndicators",
    "get_indicators",
    "DataStorage",
    "DataQualityManager",
]
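A quick sketch of the import surface this package exposes (illustrative call site, not part of the commit):

    from src.data import DataCollector, get_indicators, get_news_collector

    collector = DataCollector()      # real-time market data collection
    indicators = get_indicators()    # shared TechnicalIndicators instance
    news = get_news_collector()      # shared NewsCollector instance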
221
src/data/cache_manager.py
Normal file
@@ -0,0 +1,221 @@
"""Intelligent caching system for pricing data."""

import time
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timedelta
from collections import OrderedDict
from src.core.logger import get_logger

logger = get_logger(__name__)


class CacheEntry:
    """Cache entry with TTL support."""

    def __init__(self, data: Any, ttl: float):
        """Initialize cache entry.

        Args:
            data: Cached data
            ttl: Time to live in seconds
        """
        self.data = data
        self.expires_at = time.time() + ttl
        self.created_at = time.time()
        self.access_count = 0
        self.last_accessed = time.time()

    def is_expired(self) -> bool:
        """Check if entry is expired.

        Returns:
            True if expired
        """
        return time.time() > self.expires_at

    def touch(self):
        """Update access statistics."""
        self.access_count += 1
        self.last_accessed = time.time()

    def age(self) -> float:
        """Get age of entry in seconds.

        Returns:
            Age in seconds
        """
        return time.time() - self.created_at


class CacheManager:
    """Intelligent cache manager with TTL and size limits.

    Implements LRU (Least Recently Used) eviction when the size limit is reached.
    """

    def __init__(
        self,
        default_ttl: float = 60.0,
        max_size: int = 1000,
        ticker_ttl: float = 2.0,
        ohlcv_ttl: float = 60.0
    ):
        """Initialize cache manager.

        Args:
            default_ttl: Default TTL in seconds
            max_size: Maximum number of cache entries
            ticker_ttl: TTL for ticker data in seconds
            ohlcv_ttl: TTL for OHLCV data in seconds
        """
        self.default_ttl = default_ttl
        self.max_size = max_size
        self.ticker_ttl = ticker_ttl
        self.ohlcv_ttl = ohlcv_ttl

        # Use OrderedDict for LRU eviction
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._hits = 0
        self._misses = 0
        self._evictions = 0

        self.logger = get_logger(__name__)

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        # Clean expired entries
        self._cleanup_expired()

        if key not in self._cache:
            self._misses += 1
            return None

        entry = self._cache[key]

        if entry.is_expired():
            # Remove expired entry
            del self._cache[key]
            self._misses += 1
            return None

        # Update access (move to end for LRU)
        entry.touch()
        self._cache.move_to_end(key)
        self._hits += 1
        return entry.data

    def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[float] = None,
        cache_type: Optional[str] = None
    ):
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (uses type-specific or default if None)
            cache_type: Type of cache ('ticker' or 'ohlcv') for type-specific TTL
        """
        # Determine TTL
        if ttl is None:
            if cache_type == 'ticker':
                ttl = self.ticker_ttl
            elif cache_type == 'ohlcv':
                ttl = self.ohlcv_ttl
            else:
                ttl = self.default_ttl

        # Check if we need to evict
        if len(self._cache) >= self.max_size and key not in self._cache:
            self._evict_lru()

        # Create entry
        entry = CacheEntry(value, ttl)

        # Add or update
        if key in self._cache:
            self._cache.move_to_end(key)
        self._cache[key] = entry

    def delete(self, key: str) -> bool:
        """Delete entry from cache.

        Args:
            key: Cache key

        Returns:
            True if entry was deleted, False if not found
        """
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self):
        """Clear all cache entries."""
        self._cache.clear()
        self.logger.info("Cache cleared")

    def _cleanup_expired(self):
        """Remove expired entries from cache."""
        expired_keys = [
            key for key, entry in self._cache.items()
            if entry.is_expired()
        ]
        for key in expired_keys:
            del self._cache[key]

    def _evict_lru(self):
        """Evict least recently used entry."""
        if self._cache:
            # Remove oldest (first) entry
            self._cache.popitem(last=False)
            self._evictions += 1

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        total_requests = self._hits + self._misses
        hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0

        # Calculate average age
        if self._cache:
            avg_age = sum(entry.age() for entry in self._cache.values()) / len(self._cache)
        else:
            avg_age = 0.0

        return {
            'size': len(self._cache),
            'max_size': self.max_size,
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate': round(hit_rate, 2),
            'evictions': self._evictions,
            'avg_age_seconds': round(avg_age, 2),
        }

    def invalidate_pattern(self, pattern: str):
        """Invalidate entries matching a pattern.

        Args:
            pattern: String pattern to match (simple substring match)
        """
        keys_to_delete = [key for key in self._cache.keys() if pattern in key]
        for key in keys_to_delete:
            del self._cache[key]

        if keys_to_delete:
            self.logger.info(f"Invalidated {len(keys_to_delete)} cache entries matching '{pattern}'")
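A usage sketch of the cache above, exercising the type-specific TTLs, substring invalidation, and LRU stats; the key format ("ticker:BTC/USD") is an assumption for the example:

    cache = CacheManager(max_size=500)
    cache.set("ticker:BTC/USD", {"last": 64250.0}, cache_type="ticker")   # 2s TTL
    cache.set("ohlcv:BTC/USD:1h", [[1700000000000, 1.0, 2.0, 0.5, 1.5, 10.0]], cache_type="ohlcv")  # 60s TTL

    price = cache.get("ticker:BTC/USD")   # hit while inside the TTL, None afterwards
    cache.invalidate_pattern("BTC/USD")   # substring match drops both entries
    print(cache.get_stats())              # size, hits, misses, hit_rate, evictions, avg age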
139
src/data/collector.py
Normal file
@@ -0,0 +1,139 @@
"""Real-time data collection system with WebSocket support."""

import asyncio
from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy import select
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
from .pricing_service import get_pricing_service

logger = get_logger(__name__)


class DataCollector:
    """Collects real-time market data using the unified pricing service."""

    def __init__(self):
        """Initialize data collector."""
        self.db = get_database()
        self.logger = get_logger(__name__)
        self._callbacks: Dict[str, List[Callable]] = {}
        self._running = False
        self._pricing_service = get_pricing_service()

    def subscribe(
        self,
        exchange_id: Optional[int] = None,
        symbol: str = "",
        callback: Optional[Callable] = None
    ):
        """Subscribe to real-time data.

        Args:
            exchange_id: Exchange ID (deprecated, kept for backward compatibility)
            symbol: Trading symbol
            callback: Callback function(data)
        """
        if not symbol or not callback:
            logger.warning("subscribe called without symbol or callback")
            return

        key = f"pricing:{symbol}"
        if key not in self._callbacks:
            self._callbacks[key] = []
        self._callbacks[key].append(callback)

        # Subscribe via pricing service
        def wrapped_callback(data):
            """Wrap callback to maintain backward compatibility."""
            for cb in self._callbacks.get(key, []):
                try:
                    cb(data)
                except Exception as e:
                    logger.error(f"Callback error for {symbol}: {e}")

        self._pricing_service.subscribe_ticker(symbol, wrapped_callback)

    async def store_ohlcv(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        ohlcv_data: List[List]
    ):
        """Store OHLCV data in database.

        Args:
            exchange: Exchange name (can be provider name like 'CCXT-Kraken' or 'CoinGecko')
            symbol: Trading symbol
            timeframe: Timeframe
            ohlcv_data: List of [timestamp, open, high, low, close, volume]
        """
        async with self.db.get_session() as session:
            try:
                for candle in ohlcv_data:
                    timestamp = datetime.fromtimestamp(candle[0] / 1000)

                    # Check if already exists
                    stmt = select(MarketData).filter_by(
                        exchange=exchange,
                        symbol=symbol,
                        timeframe=timeframe,
                        timestamp=timestamp
                    )
                    result = await session.execute(stmt)
                    existing = result.scalar_one_or_none()

                    if not existing:
                        market_data = MarketData(
                            exchange=exchange,
                            symbol=symbol,
                            timeframe=timeframe,
                            timestamp=timestamp,
                            open=Decimal(str(candle[1])),
                            high=Decimal(str(candle[2])),
                            low=Decimal(str(candle[3])),
                            close=Decimal(str(candle[4])),
                            volume=Decimal(str(candle[5]))
                        )
                        session.add(market_data)

                await session.commit()
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to store OHLCV data: {e}")

    def get_ohlcv_from_pricing_service(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data from pricing service.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            since: Start datetime
            limit: Number of candles

        Returns:
            List of OHLCV candles
        """
        return self._pricing_service.get_ohlcv(symbol, timeframe, since, limit)


# Global data collector
_data_collector: Optional[DataCollector] = None


def get_data_collector() -> DataCollector:
    """Get global data collector instance."""
    global _data_collector
    if _data_collector is None:
        _data_collector = DataCollector()
    return _data_collector
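A sketch of how the collector is wired up, assuming a configured pricing service behind it; the symbol and candle values are illustrative:

    import asyncio

    collector = get_data_collector()

    def on_tick(data):
        # Invoked by the pricing service for every ticker update
        print("BTC/USD tick:", data)

    collector.subscribe(symbol="BTC/USD", callback=on_tick)

    # Persist one hourly candle: [timestamp_ms, open, high, low, close, volume]
    candle = [1700000000000, 64000, 64500, 63800, 64250, 12.5]
    asyncio.run(collector.store_ohlcv("CCXT-Kraken", "BTC/USD", "1h", [candle]))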
317
src/data/health_monitor.py
Normal file
@@ -0,0 +1,317 @@
"""Health monitoring and failover management for pricing providers."""

import time
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import deque
from src.core.logger import get_logger

logger = get_logger(__name__)


class HealthStatus(Enum):
    """Provider health status."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"


@dataclass
class HealthMetrics:
    """Health metrics for a provider."""
    status: HealthStatus = HealthStatus.UNKNOWN
    last_check: Optional[datetime] = None
    last_success: Optional[datetime] = None
    last_failure: Optional[datetime] = None
    success_count: int = 0
    failure_count: int = 0
    response_times: deque = field(default_factory=lambda: deque(maxlen=100))
    consecutive_failures: int = 0
    circuit_breaker_open: bool = False
    circuit_breaker_opened_at: Optional[datetime] = None

    def record_success(self, response_time: float):
        """Record a successful operation.

        Args:
            response_time: Response time in seconds
        """
        self.status = HealthStatus.HEALTHY
        self.last_check = datetime.utcnow()
        self.last_success = datetime.utcnow()
        self.success_count += 1
        self.response_times.append(response_time)
        self.consecutive_failures = 0
        self.circuit_breaker_open = False
        self.circuit_breaker_opened_at = None

    def record_failure(self):
        """Record a failed operation."""
        self.last_check = datetime.utcnow()
        self.last_failure = datetime.utcnow()
        self.failure_count += 1
        self.consecutive_failures += 1

        # Open circuit breaker after 5 consecutive failures
        if self.consecutive_failures >= 5:
            if not self.circuit_breaker_open:
                self.circuit_breaker_open = True
                self.circuit_breaker_opened_at = datetime.utcnow()
                logger.warning(f"Circuit breaker opened after {self.consecutive_failures} failures")

        # Update status based on failure rate
        total = self.success_count + self.failure_count
        if total > 0:
            failure_rate = self.failure_count / total
            if failure_rate > 0.5:
                self.status = HealthStatus.UNHEALTHY
            elif failure_rate > 0.2:
                self.status = HealthStatus.DEGRADED

    def get_avg_response_time(self) -> float:
        """Get average response time.

        Returns:
            Average response time in seconds, or 0.0 if no data
        """
        if not self.response_times:
            return 0.0
        return sum(self.response_times) / len(self.response_times)

    def should_attempt(self, circuit_breaker_timeout: int = 60) -> bool:
        """Check if we should attempt to use this provider.

        Args:
            circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens

        Returns:
            True if we should attempt, False otherwise
        """
        if not self.circuit_breaker_open:
            return True

        # Check if timeout has passed
        if self.circuit_breaker_opened_at:
            elapsed = (datetime.utcnow() - self.circuit_breaker_opened_at).total_seconds()
            if elapsed >= circuit_breaker_timeout:
                # Half-open state: allow one attempt
                logger.info("Circuit breaker half-open, allowing attempt")
                return True

        return False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API responses."""
        return {
            'status': self.status.value,
            'last_check': self.last_check.isoformat() if self.last_check else None,
            'last_success': self.last_success.isoformat() if self.last_success else None,
            'last_failure': self.last_failure.isoformat() if self.last_failure else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'avg_response_time': round(self.get_avg_response_time(), 3),
            'consecutive_failures': self.consecutive_failures,
            'circuit_breaker_open': self.circuit_breaker_open,
            'circuit_breaker_opened_at': (
                self.circuit_breaker_opened_at.isoformat()
                if self.circuit_breaker_opened_at else None
            ),
        }


class HealthMonitor:
    """Monitors health of pricing providers and manages failover."""

    def __init__(
        self,
        circuit_breaker_timeout: int = 60,
        min_success_rate: float = 0.8,
        max_avg_response_time: float = 5.0
    ):
        """Initialize health monitor.

        Args:
            circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
            min_success_rate: Minimum success rate to be considered healthy (0.0-1.0)
            max_avg_response_time: Maximum average response time in seconds to be considered healthy
        """
        self.circuit_breaker_timeout = circuit_breaker_timeout
        self.min_success_rate = min_success_rate
        self.max_avg_response_time = max_avg_response_time

        self._metrics: Dict[str, HealthMetrics] = {}
        self.logger = get_logger(__name__)

    def record_success(self, provider_name: str, response_time: float):
        """Record a successful operation for a provider.

        Args:
            provider_name: Name of the provider
            response_time: Response time in seconds
        """
        if provider_name not in self._metrics:
            self._metrics[provider_name] = HealthMetrics()

        self._metrics[provider_name].record_success(response_time)
        self.logger.debug(f"Recorded success for {provider_name} ({response_time:.3f}s)")

    def record_failure(self, provider_name: str):
        """Record a failed operation for a provider.

        Args:
            provider_name: Name of the provider
        """
        if provider_name not in self._metrics:
            self._metrics[provider_name] = HealthMetrics()

        self._metrics[provider_name].record_failure()
        self.logger.warning(f"Recorded failure for {provider_name} "
                            f"(consecutive: {self._metrics[provider_name].consecutive_failures})")

    def is_healthy(self, provider_name: str) -> bool:
        """Check if a provider is healthy.

        Args:
            provider_name: Name of the provider

        Returns:
            True if provider is healthy
        """
        if provider_name not in self._metrics:
            return True  # Assume healthy if no metrics yet

        metrics = self._metrics[provider_name]

        # Check circuit breaker
        if not metrics.should_attempt(self.circuit_breaker_timeout):
            return False

        # Check status
        if metrics.status == HealthStatus.UNHEALTHY:
            return False

        # Check success rate
        total = metrics.success_count + metrics.failure_count
        if total > 10:  # Need minimum data points
            success_rate = metrics.success_count / total
            if success_rate < self.min_success_rate:
                return False

        # Check response time
        if metrics.response_times:
            avg_response_time = metrics.get_avg_response_time()
            if avg_response_time > self.max_avg_response_time:
                return False

        return True

    def get_health_status(self, provider_name: str) -> HealthStatus:
        """Get health status for a provider.

        Args:
            provider_name: Name of the provider

        Returns:
            Health status
        """
        if provider_name not in self._metrics:
            return HealthStatus.UNKNOWN

        return self._metrics[provider_name].status

    def get_metrics(self, provider_name: str) -> Optional[HealthMetrics]:
        """Get health metrics for a provider.

        Args:
            provider_name: Name of the provider

        Returns:
            Health metrics or None if not found
        """
        return self._metrics.get(provider_name)

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Get all provider metrics.

        Returns:
            Dictionary mapping provider names to their metrics
        """
        return {
            name: metrics.to_dict()
            for name, metrics in self._metrics.items()
        }

    def select_healthiest(self, provider_names: List[str]) -> Optional[str]:
        """Select the healthiest provider from a list.

        Args:
            provider_names: List of provider names to choose from

        Returns:
            Name of healthiest provider, or None if none are healthy
        """
        healthy_providers = [
            name for name in provider_names
            if self.is_healthy(name)
        ]

        if not healthy_providers:
            return None

        # Sort by health metrics (better providers first)
        def health_score(name: str) -> tuple:
            metrics = self._metrics.get(name)
            if not metrics:
                return (1, 0, 0)  # Unknown providers get lowest priority

            # Score: (status_weight, -avg_response_time, success_rate)
            status_weights = {
                HealthStatus.HEALTHY: 3,
                HealthStatus.DEGRADED: 2,
                HealthStatus.UNHEALTHY: 1,
                HealthStatus.UNKNOWN: 0,
            }

            success_rate = (
                metrics.success_count / (metrics.success_count + metrics.failure_count)
                if (metrics.success_count + metrics.failure_count) > 0
                else 0.0
            )

            return (
                status_weights.get(metrics.status, 0),
                -metrics.get_avg_response_time(),
                success_rate
            )

        sorted_providers = sorted(healthy_providers, key=health_score, reverse=True)
        return sorted_providers[0] if sorted_providers else None

    def reset_circuit_breaker(self, provider_name: str):
        """Manually reset circuit breaker for a provider.

        Args:
            provider_name: Name of the provider
        """
        if provider_name in self._metrics:
            self._metrics[provider_name].circuit_breaker_open = False
            self._metrics[provider_name].circuit_breaker_opened_at = None
            self._metrics[provider_name].consecutive_failures = 0
            self.logger.info(f"Circuit breaker reset for {provider_name}")

    def reset_metrics(self, provider_name: Optional[str] = None):
        """Reset metrics for a provider or all providers.

        Args:
            provider_name: Name of provider to reset, or None to reset all
        """
        if provider_name:
            if provider_name in self._metrics:
                del self._metrics[provider_name]
                self.logger.info(f"Reset metrics for {provider_name}")
        else:
            self._metrics.clear()
            self.logger.info("Reset all provider metrics")
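A minimal sketch of the failover loop this monitor supports: record outcomes per provider, then pick the healthiest before each request (the provider names are placeholders):

    monitor = HealthMonitor(circuit_breaker_timeout=60)

    monitor.record_success("CCXT-Kraken", response_time=0.18)
    monitor.record_failure("CCXT-Binance")

    best = monitor.select_healthiest(["CCXT-Kraken", "CCXT-Binance", "CoinGecko"])
    if best is None:
        raise RuntimeError("no healthy pricing provider available")

    # After 5 consecutive failures a provider's circuit breaker opens; once
    # circuit_breaker_timeout seconds elapse it half-opens and allows one probe.
    print(monitor.get_all_metrics())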
569
src/data/indicators.py
Normal file
@@ -0,0 +1,569 @@
"""Comprehensive technical indicator library using pandas-ta and TA-Lib."""

import pandas as pd
import numpy as np
from typing import Optional, Dict, Any, List

# Try to import pandas_ta, but handle if numba is missing
try:
    import pandas_ta as ta
    PANDAS_TA_AVAILABLE = True
except ImportError:
    PANDAS_TA_AVAILABLE = False
    ta = None
    import warnings
    warnings.warn("pandas-ta not available (numba issue), using basic implementations")

try:
    import talib
    TALIB_AVAILABLE = True
except ImportError:
    TALIB_AVAILABLE = False

from src.core.logger import get_logger

logger = get_logger(__name__)


class TechnicalIndicators:
    """Technical indicators library."""

    def __init__(self):
        """Initialize indicators library."""
        self.talib_available = TALIB_AVAILABLE

    # Trend Indicators

    def sma(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Simple Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.sma(data, length=period)
        return data.rolling(window=period).mean()

    def ema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.ema(data, length=period)
        return data.ewm(span=period, adjust=False).mean()

    def wma(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Weighted Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.wma(data, length=period)
        # Basic WMA implementation
        weights = np.arange(1, period + 1)
        return data.rolling(window=period).apply(lambda x: np.dot(x, weights) / weights.sum(), raw=True)

    def dema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Double Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.dema(data, length=period)
        ema1 = self.ema(data, period)
        return 2 * ema1 - self.ema(ema1, period)

    def tema(self, data: pd.Series, period: int = 20) -> pd.Series:
        """Triple Exponential Moving Average."""
        if PANDAS_TA_AVAILABLE:
            return ta.tema(data, length=period)
        ema1 = self.ema(data, period)
        ema2 = self.ema(ema1, period)
        ema3 = self.ema(ema2, period)
        return 3 * ema1 - 3 * ema2 + ema3

    # Momentum Indicators

    def rsi(self, data: pd.Series, period: int = 14) -> pd.Series:
        """Relative Strength Index."""
        if self.talib_available:
            return pd.Series(talib.RSI(data.values, timeperiod=period), index=data.index)
        if PANDAS_TA_AVAILABLE:
            return ta.rsi(data, length=period)
        # Basic RSI implementation
        delta = data.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))

    def macd(
        self,
        data: pd.Series,
        fast: int = 12,
        slow: int = 26,
        signal: int = 9
    ) -> Dict[str, pd.Series]:
        """MACD (Moving Average Convergence Divergence)."""
        if self.talib_available:
            macd, signal_line, histogram = talib.MACD(
                data.values, fastperiod=fast, slowperiod=slow, signalperiod=signal
            )
            return {
                'macd': pd.Series(macd, index=data.index),
                'signal': pd.Series(signal_line, index=data.index),
                'histogram': pd.Series(histogram, index=data.index),
            }
        if not PANDAS_TA_AVAILABLE or ta is None:
            # Basic MACD implementation fallback
            ema_fast = self.ema(data, fast)
            ema_slow = self.ema(data, slow)
            macd_line = ema_fast - ema_slow
            signal_line = self.ema(macd_line.dropna(), signal)
            histogram = macd_line - signal_line
            return {
                'macd': macd_line,
                'signal': signal_line,
                'histogram': histogram,
            }
        result = ta.macd(data, fast=fast, slow=slow, signal=signal)
        return {
            'macd': result[f'MACD_{fast}_{slow}_{signal}'],
            'signal': result[f'MACDs_{fast}_{slow}_{signal}'],
            'histogram': result[f'MACDh_{fast}_{slow}_{signal}'],
        }

    def stochastic(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        k_period: int = 14,
        d_period: int = 3
    ) -> Dict[str, pd.Series]:
        """Stochastic Oscillator."""
        if self.talib_available:
            slowk, slowd = talib.STOCH(
                high.values, low.values, close.values,
                fastk_period=k_period, slowk_period=d_period, slowd_period=d_period
            )
            return {
                'k': pd.Series(slowk, index=close.index),
                'd': pd.Series(slowd, index=close.index),
            }
        if PANDAS_TA_AVAILABLE and ta is not None:
            result = ta.stoch(high, low, close, k=k_period, d=d_period)
            return {
                'k': result[f'STOCHk_{k_period}_{d_period}_{d_period}'],
                'd': result[f'STOCHd_{k_period}_{d_period}_{d_period}'],
            }
        # Basic Stochastic implementation
        lowest_low = low.rolling(window=k_period).min()
        highest_high = high.rolling(window=k_period).max()
        k = 100 * ((close - lowest_low) / (highest_high - lowest_low))
        d = k.rolling(window=d_period).mean()
        return {'k': k, 'd': d}

    def cci(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 20
    ) -> pd.Series:
        """Commodity Channel Index."""
        if self.talib_available:
            return pd.Series(
                talib.CCI(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE:
            return ta.cci(high, low, close, length=period)
        # Basic CCI implementation
        tp = (high + low + close) / 3
        sma_tp = tp.rolling(window=period).mean()
        mad = tp.rolling(window=period).apply(lambda x: np.abs(x - x.mean()).mean(), raw=True)
        return (tp - sma_tp) / (0.015 * mad)

    def williams_r(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Williams %R."""
        if self.talib_available:
            return pd.Series(
                talib.WILLR(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE:
            return ta.willr(high, low, close, length=period)
        # Basic Williams %R implementation
        highest_high = high.rolling(window=period).max()
        lowest_low = low.rolling(window=period).min()
        return -100 * ((highest_high - close) / (highest_high - lowest_low))

    # Volatility Indicators

    def bollinger_bands(
        self,
        data: pd.Series,
        period: int = 20,
        std_dev: float = 2.0
    ) -> Dict[str, pd.Series]:
        """Bollinger Bands."""
        if self.talib_available:
            upper, middle, lower = talib.BBANDS(
                data.values, timeperiod=period, nbdevup=std_dev, nbdevdn=std_dev
            )
            return {
                'upper': pd.Series(upper, index=data.index),
                'middle': pd.Series(middle, index=data.index),
                'lower': pd.Series(lower, index=data.index),
            }
        if PANDAS_TA_AVAILABLE:
            result = ta.bbands(data, length=period, std=std_dev)
            return {
                'upper': result[f'BBU_{period}_{std_dev}'],
                'middle': result[f'BBM_{period}_{std_dev}'],
                'lower': result[f'BBL_{period}_{std_dev}'],
            }
        # Basic Bollinger Bands implementation
        middle = self.sma(data, period)
        std = data.rolling(window=period).std()
        return {
            'upper': middle + (std * std_dev),
            'middle': middle,
            'lower': middle - (std * std_dev),
        }

    def atr(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Average True Range."""
        if self.talib_available:
            return pd.Series(
                talib.ATR(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.atr(high, low, close, length=period)
        # Basic ATR implementation
        prev_close = close.shift(1)
        tr1 = high - low
        tr2 = abs(high - prev_close)
        tr3 = abs(low - prev_close)
        tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
        return tr.rolling(window=period).mean()

    def keltner_channels(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 20,
        multiplier: float = 2.0
    ) -> Dict[str, pd.Series]:
        """Keltner Channels."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            # ta.kc returns a DataFrame (lower, basis, upper); map it to the
            # dict shape this method promises instead of leaking the DataFrame
            result = ta.kc(high, low, close, length=period, scalar=multiplier)
            lower_col, middle_col, upper_col = result.columns[:3]
            return {
                'lower': result[lower_col],
                'middle': result[middle_col],
                'upper': result[upper_col],
            }
        # Basic Keltner Channels implementation
        middle = self.ema(close, period)
        atr_val = self.atr(high, low, close, period)
        return {
            'lower': middle - (multiplier * atr_val),
            'middle': middle,
            'upper': middle + (multiplier * atr_val),
        }

    # Volume Indicators

    def obv(self, close: pd.Series, volume: pd.Series) -> pd.Series:
        """On-Balance Volume."""
        if self.talib_available:
            return pd.Series(
                talib.OBV(close.values, volume.values),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.obv(close, volume)
        # Basic OBV implementation
        price_change = close.diff()
        obv = pd.Series(index=close.index, dtype=float)
        obv.iloc[0] = 0
        for i in range(1, len(close)):
            if price_change.iloc[i] > 0:
                obv.iloc[i] = obv.iloc[i-1] + volume.iloc[i]
            elif price_change.iloc[i] < 0:
                obv.iloc[i] = obv.iloc[i-1] - volume.iloc[i]
            else:
                obv.iloc[i] = obv.iloc[i-1]
        return obv

    def vwap(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        volume: pd.Series
    ) -> pd.Series:
        """Volume Weighted Average Price."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            return ta.vwap(high, low, close, volume)
        # Basic VWAP implementation
        typical_price = (high + low + close) / 3
        cumulative_tp_vol = (typical_price * volume).cumsum()
        cumulative_vol = volume.cumsum()
        return cumulative_tp_vol / cumulative_vol

    # Advanced Indicators

    def ichimoku(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        tenkan: int = 9,
        kijun: int = 26,
        senkou: int = 52
    ) -> Dict[str, pd.Series]:
        """Ichimoku Cloud."""
        if PANDAS_TA_AVAILABLE and ta is not None:
            # ta.ichimoku returns two DataFrames (lines, forward-shifted spans);
            # the first carries the columns we need, named after the periods
            result, _spans = ta.ichimoku(high, low, close, tenkan=tenkan, kijun=kijun, senkou=senkou)
            return {
                'tenkan': result[f'ITS_{tenkan}'],
                'kijun': result[f'IKS_{kijun}'],
                'senkou_a': result[f'ISA_{tenkan}'],
                'senkou_b': result[f'ISB_{kijun}'],
                'chikou': result[f'ICS_{kijun}'],
            }
        # Basic Ichimoku implementation
        tenkan_sen = (high.rolling(window=tenkan).max() + low.rolling(window=tenkan).min()) / 2
        kijun_sen = (high.rolling(window=kijun).max() + low.rolling(window=kijun).min()) / 2
        senkou_a = ((tenkan_sen + kijun_sen) / 2).shift(kijun)
        senkou_b = ((high.rolling(window=senkou).max() + low.rolling(window=senkou).min()) / 2).shift(kijun)
        chikou = close.shift(-kijun)
        return {
            'tenkan': tenkan_sen,
            'kijun': kijun_sen,
            'senkou_a': senkou_a,
            'senkou_b': senkou_b,
            'chikou': chikou,
        }

    def adx(
        self,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
        period: int = 14
    ) -> pd.Series:
        """Average Directional Index."""
        if self.talib_available:
            return pd.Series(
                talib.ADX(high.values, low.values, close.values, timeperiod=period),
                index=close.index
            )
        if PANDAS_TA_AVAILABLE and ta is not None:
            # ta.adx returns a DataFrame (ADX, +DI, -DI); return just the ADX line
            return ta.adx(high, low, close, length=period)[f'ADX_{period}']
        # Basic ADX implementation using Wilder's directional movement:
        # +DM/-DM take the up/down move only when it dominates and is positive
        up_move = high.diff()
        down_move = -low.diff()
        plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
        minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0)

        atr_val = self.atr(high, low, close, period)
        plus_di = 100 * (plus_dm.rolling(window=period).mean() / atr_val)
        minus_di = 100 * (minus_dm.rolling(window=period).mean() / atr_val)

        dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di)
        adx = dx.rolling(window=period).mean()
        return adx

    def detect_divergence(
        self,
        prices: pd.Series,
        indicator: pd.Series,
        lookback: int = 20,
        min_swings: int = 2
    ) -> Dict[str, Any]:
        """Detect divergence between price and indicator.

        Divergence occurs when price makes new highs/lows but indicator doesn't,
        or vice versa. This is a powerful reversal signal.

        Args:
            prices: Price series
            indicator: Indicator series (e.g., RSI, MACD)
            lookback: Lookback period for finding swings
            min_swings: Minimum number of swings to detect divergence

        Returns:
            Dictionary with divergence information:
            {
                'type': 'bullish', 'bearish', or None
                'confidence': 0.0 to 1.0
                'price_swing_high': positions of the two price swing highs, or None
                'price_swing_low': positions of the two price swing lows, or None
                'indicator_swing_high': positions of the two indicator swing highs, or None
                'indicator_swing_low': positions of the two indicator swing lows, or None
            }
        """
        if len(prices) < lookback * 2 or len(indicator) < lookback * 2:
            return {
                'type': None,
                'confidence': 0.0,
                'price_swing_high': None,
                'price_swing_low': None,
                'indicator_swing_high': None,
                'indicator_swing_low': None
            }

        # Find local extrema (swings)
        def find_swings(series: pd.Series, lookback: int):
            """Find local maxima and minima."""
            highs = []
            lows = []

            for i in range(lookback, len(series) - lookback):
                window = series.iloc[i-lookback:i+lookback+1]
                center = series.iloc[i]

                # Local maximum
                if center == window.max():
                    highs.append((i, center))
                # Local minimum
                elif center == window.min():
                    lows.append((i, center))

            return highs, lows

        price_highs, price_lows = find_swings(prices, lookback)
        indicator_highs, indicator_lows = find_swings(indicator, lookback)

        # Need at least min_swings swings to detect divergence
        if len(price_highs) < min_swings or len(price_lows) < min_swings:
            return {
                'type': None,
                'confidence': 0.0,
                'price_swing_high': None,
                'price_swing_low': None,
                'indicator_swing_high': None,
                'indicator_swing_low': None
            }

        # Check for bearish divergence (price makes higher high, indicator makes lower high)
        if len(price_highs) >= 2 and len(indicator_highs) >= 2:
            recent_price_high = price_highs[-1][1]
            prev_price_high = price_highs[-2][1]
            recent_indicator_high = indicator_highs[-1][1]
            prev_indicator_high = indicator_highs[-2][1]

            # Price higher high but indicator lower high = bearish divergence
            if recent_price_high > prev_price_high and recent_indicator_high < prev_indicator_high:
                confidence = min(1.0, abs(recent_price_high - prev_price_high) / prev_price_high * 10)
                return {
                    'type': 'bearish',
                    'confidence': confidence,
                    'price_swing_high': (price_highs[-2][0], price_highs[-1][0]),
                    'price_swing_low': None,
                    'indicator_swing_high': (indicator_highs[-2][0], indicator_highs[-1][0]),
                    'indicator_swing_low': None
                }

        # Check for bullish divergence (price makes lower low, indicator makes higher low)
        if len(price_lows) >= 2 and len(indicator_lows) >= 2:
            recent_price_low = price_lows[-1][1]
            prev_price_low = price_lows[-2][1]
            recent_indicator_low = indicator_lows[-1][1]
            prev_indicator_low = indicator_lows[-2][1]

            # Price lower low but indicator higher low = bullish divergence
            if recent_price_low < prev_price_low and recent_indicator_low > prev_indicator_low:
                confidence = min(1.0, abs(prev_price_low - recent_price_low) / prev_price_low * 10)
                return {
                    'type': 'bullish',
                    'confidence': confidence,
                    'price_swing_high': None,
                    'price_swing_low': (price_lows[-2][0], price_lows[-1][0]),
                    'indicator_swing_high': None,
                    'indicator_swing_low': (indicator_lows[-2][0], indicator_lows[-1][0])
                }

        return {
            'type': None,
            'confidence': 0.0,
            'price_swing_high': None,
            'price_swing_low': None,
            'indicator_swing_high': None,
            'indicator_swing_low': None
        }

    def calculate_all(
        self,
        df: pd.DataFrame,
        indicators: Optional[List[str]] = None
    ) -> pd.DataFrame:
        """Calculate multiple indicators at once.

        Args:
            df: DataFrame with OHLCV data (columns: open, high, low, close, volume)
            indicators: List of indicator names to calculate (None = default set)

        Returns:
            DataFrame with added indicator columns
        """
        result = df.copy()

        if 'close' not in result.columns:
            raise ValueError("DataFrame must have 'close' column")

        close = result['close']
        high = result.get('high', close)
        low = result.get('low', close)
        volume = result.get('volume', pd.Series(1, index=close.index))

        # Default indicators if none specified
        if indicators is None:
            indicators = [
                'sma_20', 'ema_20', 'rsi', 'macd', 'bollinger_bands',
                'atr', 'obv', 'adx'
            ]

        for indicator in indicators:
            try:
                if indicator.startswith('sma_'):
                    period = int(indicator.split('_')[1])
                    result[f'SMA_{period}'] = self.sma(close, period)
                elif indicator.startswith('ema_'):
                    period = int(indicator.split('_')[1])
                    result[f'EMA_{period}'] = self.ema(close, period)
                elif indicator == 'rsi':
                    result['RSI'] = self.rsi(close)
                elif indicator == 'macd':
                    macd_data = self.macd(close)
                    result['MACD'] = macd_data['macd']
                    result['MACD_Signal'] = macd_data['signal']
                    result['MACD_Histogram'] = macd_data['histogram']
                elif indicator == 'bollinger_bands':
                    bb_data = self.bollinger_bands(close)
                    result['BB_Upper'] = bb_data['upper']
                    result['BB_Middle'] = bb_data['middle']
                    result['BB_Lower'] = bb_data['lower']
                elif indicator == 'atr':
                    result['ATR'] = self.atr(high, low, close)
                elif indicator == 'obv':
                    result['OBV'] = self.obv(close, volume)
                elif indicator == 'adx':
                    result['ADX'] = self.adx(high, low, close)
            except Exception as e:
                logger.warning(f"Failed to calculate indicator {indicator}: {e}")

        return result


# Global indicators instance
_indicators: Optional[TechnicalIndicators] = None


def get_indicators() -> TechnicalIndicators:
    """Get global technical indicators instance."""
    global _indicators
    if _indicators is None:
        _indicators = TechnicalIndicators()
    return _indicators
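A usage sketch, assuming an OHLCV DataFrame with the lowercase open/high/low/close/volume columns that calculate_all expects (the values are made up):

    import pandas as pd

    ind = get_indicators()
    df = pd.DataFrame({
        "open":   [100.0, 101.0, 102.0, 103.0],
        "high":   [101.0, 102.0, 103.0, 104.0],
        "low":    [ 99.0, 100.0, 101.0, 102.0],
        "close":  [100.5, 101.5, 102.5, 103.5],
        "volume": [ 10.0,  12.0,   9.0,  14.0],
    })

    enriched = ind.calculate_all(df, indicators=["sma_20", "rsi", "macd"])
    # Adds SMA_20, RSI, MACD, MACD_Signal, MACD_Histogram columns
    # (NaN until enough history accumulates)
    divergence = ind.detect_divergence(enriched["close"], enriched["RSI"])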
447
src/data/news_collector.py
Normal file
@@ -0,0 +1,447 @@
"""News collector for crypto sentiment analysis.

Collects headlines from multiple sources:
- RSS feeds (CoinDesk, CoinTelegraph, Decrypt, etc.)
- CryptoPanic API (optional, requires API key)

Headlines are cached and refreshed periodically to avoid rate limits.
"""

import asyncio
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum

from src.core.logger import get_logger

logger = get_logger(__name__)


class NewsSource(str, Enum):
    """Supported news sources."""
    COINDESK = "coindesk"
    COINTELEGRAPH = "cointelegraph"
    DECRYPT = "decrypt"
    BITCOIN_MAGAZINE = "bitcoin_magazine"
    THE_BLOCK = "the_block"
    MESSARI = "messari"
    CRYPTOPANIC = "cryptopanic"


@dataclass
class NewsItem:
    """A single news item."""
    title: str
    source: NewsSource
    published: datetime
    url: Optional[str] = None
    summary: Optional[str] = None
    symbols: List[str] = field(default_factory=list)


# RSS feed URLs for major crypto news sources
RSS_FEEDS: Dict[NewsSource, str] = {
    NewsSource.COINDESK: "https://www.coindesk.com/arc/outboundfeeds/rss/",
    NewsSource.COINTELEGRAPH: "https://cointelegraph.com/rss",
    NewsSource.DECRYPT: "https://decrypt.co/feed",
    NewsSource.BITCOIN_MAGAZINE: "https://bitcoinmagazine.com/.rss/full/",
    NewsSource.THE_BLOCK: "https://www.theblock.co/rss.xml",
    NewsSource.MESSARI: "https://messari.io/rss",
}

# Common crypto symbols to detect in headlines
CRYPTO_SYMBOLS = {
    "BTC": ["bitcoin", "btc"],
    "ETH": ["ethereum", "eth", "ether"],
    "SOL": ["solana", "sol"],
    "XRP": ["ripple", "xrp"],
    "ADA": ["cardano", "ada"],
    "DOGE": ["dogecoin", "doge"],
    "DOT": ["polkadot", "dot"],
    "AVAX": ["avalanche", "avax"],
    "MATIC": ["polygon", "matic"],
    "LINK": ["chainlink", "link"],
    "UNI": ["uniswap", "uni"],
    "ATOM": ["cosmos", "atom"],
    "LTC": ["litecoin", "ltc"],
}


class NewsCollector:
    """Collects crypto news headlines for sentiment analysis.

    Features:
    - Aggregates news from multiple RSS feeds
    - Caches headlines to reduce network requests
    - Filters by crypto symbol
    - Optional CryptoPanic integration for more sources

    Usage:
        collector = NewsCollector()
        headlines = await collector.fetch_headlines()
        # Or filter by symbol
        btc_headlines = await collector.fetch_headlines(symbols=["BTC"])
    """

    # Minimum time between fetches (in seconds)
    MIN_FETCH_INTERVAL = 300  # 5 minutes

    def __init__(
        self,
        sources: Optional[List[NewsSource]] = None,
        cryptopanic_api_key: Optional[str] = None,
        cache_duration: int = 600,  # 10 minutes
        max_headlines: int = 50
    ):
        """Initialize NewsCollector.

        Args:
            sources: List of news sources to use. Defaults to all RSS feeds.
            cryptopanic_api_key: Optional API key for CryptoPanic.
            cache_duration: How long to cache headlines (seconds).
            max_headlines: Maximum headlines to keep in cache.
        """
        self.sources = sources or list(RSS_FEEDS.keys())
        self.cryptopanic_api_key = cryptopanic_api_key
        self.cache_duration = cache_duration
        self.max_headlines = max_headlines

        self._cache: List[NewsItem] = []
        self._last_fetch: Optional[datetime] = None
        self._fetching = False

        self.logger = get_logger(__name__)

        # Check if feedparser is available
        try:
            import feedparser
            self._feedparser = feedparser
            self._feedparser_available = True
        except ImportError:
            self._feedparser_available = False
            self.logger.warning(
                "feedparser not installed. Install with: pip install feedparser"
            )

    def _extract_symbols(self, text: str) -> List[str]:
        """Extract crypto symbols mentioned in text.

        Args:
            text: Text to search (headline, summary)

        Returns:
            List of detected symbol codes (e.g., ["BTC", "ETH"])
        """
        text_lower = text.lower()
        detected = []

        for symbol, keywords in CRYPTO_SYMBOLS.items():
            for keyword in keywords:
                if keyword in text_lower:
                    detected.append(symbol)
                    break

        return detected

    async def _fetch_rss_feed(self, source: NewsSource) -> List[NewsItem]:
        """Fetch and parse a single RSS feed.

        Args:
            source: News source to fetch

        Returns:
            List of NewsItems from the feed
        """
        if not self._feedparser_available:
            return []

        url = RSS_FEEDS.get(source)
        if not url:
            return []

        try:
            # Run feedparser in thread pool to avoid blocking
            loop = asyncio.get_event_loop()
            feed = await loop.run_in_executor(
                None,
                self._feedparser.parse,
                url
            )

            items = []
            for entry in feed.entries[:20]:  # Limit entries per feed
                # Parse publication date
                published = datetime.now()
                if hasattr(entry, 'published_parsed') and entry.published_parsed:
                    try:
                        published = datetime(*entry.published_parsed[:6])
                    except (TypeError, ValueError):
                        pass

                title = entry.get('title', '')
                summary = entry.get('summary', '')

                # Clean HTML from summary
                summary = re.sub(r'<[^>]+>', '', summary)[:200]

                item = NewsItem(
                    title=title,
                    source=source,
                    published=published,
                    url=entry.get('link'),
                    summary=summary,
                    symbols=self._extract_symbols(f"{title} {summary}")
                )
                items.append(item)

            self.logger.debug(f"Fetched {len(items)} items from {source.value}")
            return items

        except Exception as e:
            self.logger.warning(f"Failed to fetch {source.value} RSS: {e}")
            return []

    async def _fetch_cryptopanic(self, symbols: Optional[List[str]] = None) -> List[NewsItem]:
        """Fetch news from CryptoPanic API.

        Args:
            symbols: Optional list of symbols to filter

        Returns:
            List of NewsItems from CryptoPanic
        """
        if not self.cryptopanic_api_key:
            return []

        try:
            import aiohttp

            url = "https://cryptopanic.com/api/v1/posts/"
            params = {
                "auth_token": self.cryptopanic_api_key,
                "public": "true",
            }

            if symbols:
                params["currencies"] = ",".join(symbols)

            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=10) as response:
                    if response.status != 200:
                        self.logger.warning(f"CryptoPanic API error: {response.status}")
                        return []

                    data = await response.json()

            items = []
            for post in data.get("results", [])[:20]:
                published = datetime.now()
                if post.get("published_at"):
                    try:
                        # Strip tzinfo so items compare cleanly against the
                        # naive datetimes produced by the RSS path and cache
                        published = datetime.fromisoformat(
                            post["published_at"].replace("Z", "+00:00")
                        ).replace(tzinfo=None)
                    except (ValueError, TypeError):
                        pass

                item = NewsItem(
                    title=post.get("title", ""),
                    source=NewsSource.CRYPTOPANIC,
                    published=published,
                    url=post.get("url"),
                    symbols=[c["code"] for c in post.get("currencies", [])]
                )
                items.append(item)

            self.logger.debug(f"Fetched {len(items)} items from CryptoPanic")
            return items

        except ImportError:
            self.logger.warning("aiohttp not installed for CryptoPanic API")
            return []
        except Exception as e:
            self.logger.warning(f"Failed to fetch CryptoPanic: {e}")
            return []

    async def fetch_news(
        self,
        symbols: Optional[List[str]] = None,
        force_refresh: bool = False
    ) -> List[NewsItem]:
        """Fetch news items from all sources.

        Args:
            symbols: Optional list of symbols to filter (e.g., ["BTC", "ETH"])
            force_refresh: Force a refresh even if cache is valid

        Returns:
            List of NewsItems sorted by publication date (newest first)
        """
        now = datetime.now()

        # Check cache validity
        cache_valid = (
            self._last_fetch is not None and
            (now - self._last_fetch).total_seconds() < self.cache_duration and
            len(self._cache) > 0
        )

        if cache_valid and not force_refresh:
            self.logger.debug("Using cached news items")
            items = self._cache
        else:
            # Prevent concurrent fetches
            if self._fetching:
                self.logger.debug("Fetch already in progress, using cache")
                items = self._cache
            else:
                self._fetching = True
                try:
                    items = await self._fetch_all_sources()
                    self._cache = items
                    self._last_fetch = now
                finally:
                    self._fetching = False

        # Filter by symbols if specified
        if symbols:
            symbols_upper = [s.upper() for s in symbols]
            items = [
                item for item in items
                if any(s in symbols_upper for s in item.symbols) or not item.symbols
            ]

        return items

    async def _fetch_all_sources(self) -> List[NewsItem]:
        """Fetch from all configured sources concurrently."""
        tasks = []

        # RSS feeds
        for source in self.sources:
            if source in RSS_FEEDS:
                tasks.append(self._fetch_rss_feed(source))

        # CryptoPanic
        if self.cryptopanic_api_key:
            tasks.append(self._fetch_cryptopanic())

        if not tasks:
            self.logger.warning("No news sources configured")
            return []

        # Fetch all concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Combine results
        all_items = []
        for result in results:
            if isinstance(result, list):
                all_items.extend(result)
            elif isinstance(result, Exception):
                self.logger.warning(f"Source fetch failed: {result}")

        # Sort by publication date (newest first)
        all_items.sort(key=lambda x: x.published, reverse=True)

        # Limit total items
        all_items = all_items[:self.max_headlines]

        self.logger.info(f"Fetched {len(all_items)} total news items")
        return all_items

    async def fetch_headlines(
        self,
        symbols: Optional[List[str]] = None,
        max_age_hours: int = 24,
        force_refresh: bool = False
    ) -> List[str]:
        """Fetch headlines as strings for sentiment analysis.

        This is the main method to use with SentimentScanner.

        Args:
            symbols: Optional list of symbols to filter
            max_age_hours: Only include headlines from the last N hours
            force_refresh: Force a refresh even if cache is valid

        Returns:
            List of headline strings
        """
        items = await self.fetch_news(symbols=symbols, force_refresh=force_refresh)

        # Filter by age
        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        recent_items = [item for item in items if item.published > cutoff]

        # Extract just the titles
        headlines = [item.title for item in recent_items if item.title]

        self.logger.debug(f"Returning {len(headlines)} headlines for analysis")
        return headlines

    def get_cached_headlines(self, symbols: Optional[List[str]] = None) -> List[str]:
        """Get cached headlines synchronously (no fetch).

        Args:
            symbols: Optional list of symbols to filter

        Returns:
            List of cached headline strings
        """
        items = self._cache

        if symbols:
            symbols_upper = [s.upper() for s in symbols]
            items = [
                item for item in items
                if any(s in symbols_upper for s in item.symbols) or not item.symbols
            ]

        return [item.title for item in items if item.title]

    def get_status(self) -> Dict[str, Any]:
        """Get collector status information.

        Returns:
            Dictionary with status info
        """
        return {
            "sources": [s.value for s in self.sources],
            "cryptopanic_enabled": self.cryptopanic_api_key is not None,
            "feedparser_available": self._feedparser_available,
            "cached_items": len(self._cache),
            "last_fetch": self._last_fetch.isoformat() if self._last_fetch else None,
            "cache_age_seconds": (
                (datetime.now() - self._last_fetch).total_seconds()
                if self._last_fetch else None
            ),
        }

    def clear_cache(self):
        """Clear the headline cache."""
        self._cache = []
        self._last_fetch = None
        self.logger.info("News cache cleared")


# Global instance
_news_collector: Optional[NewsCollector] = None


def get_news_collector(**kwargs) -> NewsCollector:
    """Get or create the global NewsCollector instance.

    Args:
        **kwargs: Arguments passed to NewsCollector constructor

    Returns:
        NewsCollector instance
    """
    global _news_collector
    if _news_collector is None:
        _news_collector = NewsCollector(**kwargs)
        logger.info("Created global NewsCollector instance")
    return _news_collector
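A sketch of the intended consumption path for sentiment analysis; it assumes the optional feedparser dependency is installed and network access is available:

    import asyncio

    async def main():
        collector = get_news_collector(cache_duration=600)
        # Plain headline strings for BTC/ETH from the last 12 hours
        headlines = await collector.fetch_headlines(symbols=["BTC", "ETH"], max_age_hours=12)
        for title in headlines[:5]:
            print("-", title)
        print(collector.get_status())

    asyncio.run(main())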
406
src/data/pricing_service.py
Normal file
@@ -0,0 +1,406 @@
"""Unified pricing data service with multi-provider support and automatic failover."""

import time
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from decimal import Decimal

from .providers.base_provider import BasePricingProvider
from .providers.ccxt_provider import CCXTProvider
from .providers.coingecko_provider import CoinGeckoProvider
from .cache_manager import CacheManager
from .health_monitor import HealthMonitor, HealthStatus
from src.core.config import get_config
from src.core.logger import get_logger

logger = get_logger(__name__)


class PricingService:
    """Unified pricing data service with multi-provider support.

    Manages multiple pricing providers with automatic failover, caching,
    and health monitoring. Provides a single consistent API for accessing
    market data regardless of the underlying provider.
    """

    def __init__(self):
        """Initialize pricing service."""
        self.config = get_config()
        self.logger = get_logger(__name__)

        # Initialize components
        self.cache = CacheManager(
            default_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
            max_size=self.config.get("data_providers.caching.max_cache_size", 1000),
            ticker_ttl=self.config.get("data_providers.caching.ticker_ttl", 2),
            ohlcv_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
        )

        self.health_monitor = HealthMonitor()

        # Provider instances
        self._providers: Dict[str, BasePricingProvider] = {}
        self._active_provider: Optional[str] = None
        self._provider_priority: List[str] = []

        # Subscriptions
        self._subscriptions: Dict[str, List[Callable]] = {}

        # Initialize providers
        self._initialize_providers()

    def _initialize_providers(self):
        """Initialize providers from configuration."""
        # Get primary providers from config
        primary_config = self.config.get("data_providers.primary", [])
        if not primary_config:
            # Default configuration
            primary_config = [
                {'name': 'kraken', 'enabled': True, 'priority': 1},
                {'name': 'coinbase', 'enabled': True, 'priority': 2},
                {'name': 'binance', 'enabled': True, 'priority': 3},
            ]

        # Sort by priority
        primary_config = sorted(
            [p for p in primary_config if p.get('enabled', True)],
            key=lambda x: x.get('priority', 999)
        )

        # Create CCXT providers for each exchange
        for provider_config in primary_config:
            exchange_name = provider_config.get('name')
            try:
                provider = CCXTProvider(exchange_name=exchange_name)
                provider_name = provider.name

                if provider.connect():
                    self._providers[provider_name] = provider
                    self._provider_priority.append(provider_name)
                    self.logger.info(f"Initialized provider: {provider_name}")
                else:
                    self.logger.warning(f"Failed to connect provider: {provider_name}")
            except Exception as e:
                self.logger.error(f"Error initializing provider {exchange_name}: {e}")

        # Add fallback provider (CoinGecko)
        fallback_config = self.config.get("data_providers.fallback", {})
        if fallback_config.get('enabled', True):
            try:
                coingecko = CoinGeckoProvider(api_key=fallback_config.get('api_key'))
                if coingecko.connect():
                    self._providers[coingecko.name] = coingecko
                    self._provider_priority.append(coingecko.name)
                    self.logger.info(f"Initialized fallback provider: {coingecko.name}")
                else:
                    self.logger.warning("Failed to connect CoinGecko fallback provider")
            except Exception as e:
                self.logger.error(f"Error initializing CoinGecko provider: {e}")

        # Select initial active provider
        self._select_active_provider()

    def _select_active_provider(self) -> Optional[str]:
        """Select the best available provider.

        Returns:
            Name of selected provider or None
        """
        # Filter to healthy providers
        healthy_providers = [
            name for name in self._provider_priority
            if name in self._providers
            and self.health_monitor.is_healthy(name)
        ]

        if not healthy_providers:
            # Fall back to any available provider if none are healthy
            healthy_providers = list(self._providers.keys())

        if not healthy_providers:
            self.logger.error("No providers available")
            self._active_provider = None
            return None

        # Select first healthy provider (already sorted by priority)
        self._active_provider = healthy_providers[0]
        self.logger.info(f"Selected active provider: {self._active_provider}")
        return self._active_provider

    def _get_provider(self, provider_name: Optional[str] = None) -> Optional[BasePricingProvider]:
        """Get a provider instance.

        Args:
            provider_name: Name of provider, or None to use active provider

        Returns:
            Provider instance or None
        """
        if provider_name:
            return self._providers.get(provider_name)

        # Use active provider, or select one if none active
        if not self._active_provider:
            self._select_active_provider()

        return self._providers.get(self._active_provider) if self._active_provider else None

    def _execute_with_failover(
        self,
        operation: Callable[[BasePricingProvider], Any],
        operation_name: str
    ) -> Any:
        """Execute an operation with automatic failover.

        Args:
            operation: Function that takes a provider and returns a result
            operation_name: Name of operation for logging

        Returns:
            Operation result or None if all providers fail
        """
        # Try active provider first
        providers_to_try = [self._active_provider] if self._active_provider else []

        # Add other providers in priority order
        for provider_name in self._provider_priority:
            if provider_name != self._active_provider and provider_name in self._providers:
                providers_to_try.append(provider_name)

        last_error = None

        for provider_name in providers_to_try:
            provider = self._providers.get(provider_name)
            if not provider:
                continue

            # Check health
            if not self.health_monitor.is_healthy(provider_name):
                self.logger.debug(f"Skipping unhealthy provider: {provider_name}")
                continue

            try:
                start_time = time.time()
                result = operation(provider)
                response_time = time.time() - start_time

                # Record success
                self.health_monitor.record_success(provider_name, response_time)

                # Update active provider if we used a different one
                if provider_name != self._active_provider:
                    self.logger.info(f"Switched to provider: {provider_name}")
                    self._active_provider = provider_name

                return result

            except Exception as e:
                last_error = e
                self.logger.warning(f"{operation_name} failed on {provider_name}: {e}")
                self.health_monitor.record_failure(provider_name)

                # Try next provider
                continue

        # All providers failed
        self.logger.error(f"{operation_name} failed on all providers")
        if last_error:
            raise last_error

        return None

    def get_ticker(self, symbol: str, use_cache: bool = True) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading pair symbol (e.g., 'BTC/USD')
            use_cache: Whether to use cache

        Returns:
            Ticker data dictionary
        """
        cache_key = f"ticker:{symbol}"

        # Check cache
        if use_cache:
            cached = self.cache.get(cache_key)
            if cached:
                return cached

        # Fetch from provider
        def fetch_ticker(provider: BasePricingProvider):
            return provider.get_ticker(symbol)

        ticker_data = self._execute_with_failover(fetch_ticker, f"get_ticker({symbol})")

        if ticker_data:
            # Cache the result
            if use_cache:
                self.cache.set(cache_key, ticker_data, cache_type='ticker')

            return ticker_data

        return {}

    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100,
        use_cache: bool = True
    ) -> List[List]:
        """Get OHLCV candlestick data.

        Args:
            symbol: Trading pair symbol
            timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
            since: Start datetime
            limit: Number of candles
            use_cache: Whether to use cache

        Returns:
            List of [timestamp_ms, open, high, low, close, volume]
        """
        cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"

        # Check cache (only if no 'since' parameter, as it changes the result)
        if use_cache and not since:
            cached = self.cache.get(cache_key)
            if cached:
                return cached

        # Fetch from provider
        def fetch_ohlcv(provider: BasePricingProvider):
            return provider.get_ohlcv(symbol, timeframe, since, limit)

        ohlcv_data = self._execute_with_failover(
            fetch_ohlcv,
            f"get_ohlcv({symbol}, {timeframe})"
        )

        if ohlcv_data:
            # Cache the result (only if no 'since' parameter)
            if use_cache and not since:
                self.cache.set(cache_key, ohlcv_data, cache_type='ohlcv')

            return ohlcv_data

        return []

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(data) called on price updates

        Returns:
            True if subscription successful
        """
        key = f"ticker:{symbol}"

        # Add callback. Track whether this is the first subscriber for the key
        # so the fan-out wrapper is only registered with the provider once;
        # otherwise each call would stack another wrapper and every subscriber
        # would receive duplicate updates.
        first_subscription = key not in self._subscriptions
        if first_subscription:
            self._subscriptions[key] = []
        if callback not in self._subscriptions[key]:
            self._subscriptions[key].append(callback)

        if not first_subscription:
            # The provider already delivers updates to our fan-out callback
            return True

        # Wrap callback to handle failover
        def wrapped_callback(data):
            for cb in self._subscriptions.get(key, []):
                try:
                    cb(data)
                except Exception as e:
                    self.logger.error(f"Callback error for {symbol}: {e}")

        # Subscribe via active provider
        provider = self._get_provider()
        if provider:
            try:
                success = provider.subscribe_ticker(symbol, wrapped_callback)
                if success:
                    self.logger.info(f"Subscribed to ticker updates for {symbol}")
                    return True
            except Exception as e:
                self.logger.error(f"Failed to subscribe to ticker for {symbol}: {e}")

        return False

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Specific callback to remove, or None to remove all
        """
        key = f"ticker:{symbol}"

        # Remove callback
        if key in self._subscriptions:
            if callback:
                if callback in self._subscriptions[key]:
                    self._subscriptions[key].remove(callback)
                if not self._subscriptions[key]:
                    del self._subscriptions[key]
            else:
                del self._subscriptions[key]

        # Providers hold the fan-out wrapper rather than the caller's callback,
        # so tear down provider-side subscriptions only once no service-level
        # callbacks remain for this symbol
        if key not in self._subscriptions:
            for provider in self._providers.values():
                try:
                    provider.unsubscribe_ticker(symbol, None)
                except Exception:
                    pass

        self.logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def get_active_provider(self) -> Optional[str]:
        """Get name of active provider.

        Returns:
            Provider name or None
        """
        return self._active_provider

    def get_provider_health(self, provider_name: Optional[str] = None) -> Dict[str, Any]:
        """Get health status for a provider or all providers.

        Args:
            provider_name: Provider name, or None for all providers

        Returns:
            Health status dictionary
        """
        if provider_name:
            metrics = self.health_monitor.get_metrics(provider_name)
            if metrics:
                return metrics.to_dict()
            return {}

        return self.health_monitor.get_all_metrics()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Cache statistics dictionary
        """
        return self.cache.get_stats()


# Global pricing service instance
_pricing_service: Optional[PricingService] = None


def get_pricing_service() -> PricingService:
    """Get global pricing service instance.

    Returns:
        PricingService instance
    """
    global _pricing_service
    if _pricing_service is None:
        _pricing_service = PricingService()
    return _pricing_service
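A minimal usage sketch for the service above, using only the API shown in this file; the symbol and the callback body are illustrative.

from src.data.pricing_service import get_pricing_service

service = get_pricing_service()

# Cached ticker with automatic provider failover
ticker = service.get_ticker("BTC/USD")
print(ticker.get("last"), "via", service.get_active_provider())

# Last 24 hourly candles: [timestamp_ms, open, high, low, close, volume]
candles = service.get_ohlcv("BTC/USD", timeframe="1h", limit=24)

# Push updates: the callback receives a dict with 'price', 'bid', 'ask', ...
service.subscribe_ticker("BTC/USD", lambda data: print(data["price"]))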
7
src/data/providers/__init__.py
Normal file
7
src/data/providers/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""Pricing data providers package."""

from .base_provider import BasePricingProvider
from .ccxt_provider import CCXTProvider
from .coingecko_provider import CoinGeckoProvider

__all__ = ['BasePricingProvider', 'CCXTProvider', 'CoinGeckoProvider']
150
src/data/providers/base_provider.py
Normal file
150
src/data/providers/base_provider.py
Normal file
@@ -0,0 +1,150 @@
"""Base pricing provider interface."""

from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime

from src.core.logger import get_logger

logger = get_logger(__name__)


class BasePricingProvider(ABC):
    """Base class for pricing data providers.

    Pricing providers are responsible for fetching market data (prices, OHLCV)
    without requiring API keys. They differ from exchange adapters which handle
    trading operations.
    """

    def __init__(self):
        """Initialize pricing provider."""
        self.logger = get_logger(f"provider.{self.__class__.__name__}")
        self._connected = False
        self._subscribers: Dict[str, List[Callable]] = {}

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name."""
        pass

    @property
    @abstractmethod
    def supports_websocket(self) -> bool:
        """Whether this provider supports WebSocket connections."""
        pass

    @abstractmethod
    def connect(self) -> bool:
        """Connect to the provider.

        Returns:
            True if connection successful
        """
        pass

    @abstractmethod
    def disconnect(self):
        """Disconnect from the provider."""
        pass

    @abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading pair symbol (e.g., 'BTC/USD')

        Returns:
            Dictionary with ticker data:
            - 'symbol': str
            - 'bid': Decimal
            - 'ask': Decimal
            - 'last': Decimal (last price)
            - 'high': Decimal (24h high)
            - 'low': Decimal (24h low)
            - 'volume': Decimal (24h volume)
            - 'timestamp': int (Unix timestamp in milliseconds)
        """
        pass

    @abstractmethod
    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV (candlestick) data.

        Args:
            symbol: Trading pair symbol
            timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
            since: Start datetime
            limit: Number of candles

        Returns:
            List of [timestamp_ms, open, high, low, close, volume]
        """
        pass

    @abstractmethod
    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Callback function(data) called on price updates

        Returns:
            True if subscription successful
        """
        pass

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates.

        Args:
            symbol: Trading pair symbol
            callback: Specific callback to remove, or None to remove all
        """
        key = f"ticker_{symbol}"
        if key in self._subscribers:
            if callback:
                if callback in self._subscribers[key]:
                    self._subscribers[key].remove(callback)
                if not self._subscribers[key]:
                    del self._subscribers[key]
            else:
                del self._subscribers[key]

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format for this provider.

        Args:
            symbol: Symbol to normalize

        Returns:
            Normalized symbol
        """
        # Default: uppercase and replace dashes with slashes
        return symbol.upper().replace('-', '/')

    def is_connected(self) -> bool:
        """Check if provider is connected.

        Returns:
            True if connected
        """
        return self._connected

    def get_supported_symbols(self) -> List[str]:
        """Get list of supported trading symbols.

        Returns:
            List of supported symbols, or empty list if not available
        """
        # Default implementation - override in subclasses
        return []
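To make the abstract contract concrete, here is a hedged sketch of the smallest possible subclass; "StaticProvider" and its fixed quote are invented for illustration and are not part of the codebase.

# Hypothetical example: a minimal provider satisfying BasePricingProvider,
# useful as a test double. Real providers fetch live data instead.
import time
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional
from datetime import datetime

from src.data.providers.base_provider import BasePricingProvider


class StaticProvider(BasePricingProvider):
    """Serves a single hard-coded quote."""

    @property
    def name(self) -> str:
        return "Static"

    @property
    def supports_websocket(self) -> bool:
        return False

    def connect(self) -> bool:
        self._connected = True
        return True

    def disconnect(self):
        self._connected = False

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        price = Decimal("1")  # fixed quote for every symbol
        return {'symbol': symbol, 'bid': price, 'ask': price, 'last': price,
                'high': price, 'low': price, 'volume': Decimal("0"),
                'timestamp': int(time.time() * 1000)}

    def get_ohlcv(self, symbol: str, timeframe: str = '1h',
                  since: Optional[datetime] = None, limit: int = 100) -> List[List]:
        return []  # no historical data

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        return False  # no streaming support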
333
src/data/providers/ccxt_provider.py
Normal file
333
src/data/providers/ccxt_provider.py
Normal file
@@ -0,0 +1,333 @@
"""CCXT-based pricing provider with multi-exchange support and WebSocket capabilities."""

import ccxt
import threading
import time
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime

from .base_provider import BasePricingProvider
from src.core.logger import get_logger

logger = get_logger(__name__)


class CCXTProvider(BasePricingProvider):
    """CCXT-based pricing provider with multi-exchange fallback support.

    This provider uses CCXT to connect to multiple exchanges (Kraken, Coinbase, Binance)
    as primary data sources. It supports WebSocket where available and falls back to
    polling if WebSocket is not available.
    """

    def __init__(self, exchange_name: Optional[str] = None):
        """Initialize CCXT provider.

        Args:
            exchange_name: Specific exchange to use ('kraken', 'coinbase', 'binance'),
                or None to try multiple exchanges
        """
        super().__init__()
        self.exchange_name = exchange_name
        self.exchange = None
        self._selected_exchange_id = None
        self._polling_threads: Dict[str, threading.Thread] = {}
        self._stop_polling: Dict[str, bool] = {}

        # Exchange priority order (try first to last)
        if exchange_name:
            self._exchange_options = [(exchange_name.lower(), exchange_name.capitalize())]
        else:
            self._exchange_options = [
                ('kraken', 'Kraken'),
                ('coinbase', 'Coinbase'),
                ('binance', 'Binance'),
            ]

    @property
    def name(self) -> str:
        """Provider name."""
        if self._selected_exchange_id:
            return f"CCXT-{self._selected_exchange_id.capitalize()}"
        return "CCXT Provider"

    @property
    def supports_websocket(self) -> bool:
        """Check if current exchange supports WebSocket."""
        if not self.exchange:
            return False

        # Check if exchange has WebSocket support
        # Most modern CCXT exchanges support WebSocket, but we check explicitly
        exchange_id = getattr(self.exchange, 'id', '').lower()

        # Known WebSocket-capable exchanges
        ws_capable = ['kraken', 'coinbase', 'binance', 'binanceus', 'okx']

        return exchange_id in ws_capable

    def connect(self) -> bool:
        """Connect to an exchange via CCXT.

        Tries multiple exchanges in order until one succeeds.

        Returns:
            True if connection successful
        """
        for exchange_id, exchange_display_name in self._exchange_options:
            try:
                # Get exchange class from CCXT
                exchange_class = getattr(ccxt, exchange_id, None)
                if not exchange_class:
                    logger.warning(f"Exchange {exchange_id} not found in CCXT")
                    continue

                # Create exchange instance
                self.exchange = exchange_class({
                    'enableRateLimit': True,
                    'options': {
                        'defaultType': 'spot',
                    }
                })

                # Load markets to test connection
                if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
                    try:
                        self.exchange.load_markets()
                    except Exception as e:
                        logger.warning(f"Failed to load markets for {exchange_id}: {e}")
                        continue

                # Test connection with a common symbol
                test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
                ticker_result = None

                for test_symbol in test_symbols:
                    try:
                        ticker_result = self.exchange.fetch_ticker(test_symbol)
                        if ticker_result:
                            break
                    except Exception:
                        continue

                if not ticker_result:
                    logger.warning(f"Could not fetch ticker from {exchange_id}")
                    continue

                # Success!
                self._selected_exchange_id = exchange_id
                self._connected = True
                logger.info(f"Connected to {exchange_display_name} via CCXT")
                return True

            except Exception as e:
                logger.warning(f"Failed to connect to {exchange_id}: {e}")
                continue

        # All exchanges failed
        logger.error("Failed to connect to any CCXT exchange")
        self._connected = False
        self.exchange = None
        return False

    def disconnect(self):
        """Disconnect from exchange."""
        # Stop all polling threads
        for symbol in list(self._stop_polling.keys()):
            self._stop_polling[symbol] = True

        # Wait a bit for threads to stop
        for symbol, thread in self._polling_threads.items():
            if thread.is_alive():
                thread.join(timeout=2.0)

        self._polling_threads.clear()
        self._stop_polling.clear()
        self._subscribers.clear()
        self._connected = False
        self.exchange = None
        self._selected_exchange_id = None
        logger.info(f"Disconnected from {self.name}")

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data."""
        if not self._connected or not self.exchange:
            logger.error("Provider not connected")
            return {}

        try:
            normalized_symbol = self.normalize_symbol(symbol)
            ticker = self.exchange.fetch_ticker(normalized_symbol)

            return {
                'symbol': symbol,
                'bid': Decimal(str(ticker.get('bid', 0) or 0)),
                'ask': Decimal(str(ticker.get('ask', 0) or 0)),
                'last': Decimal(str(ticker.get('last', 0) or ticker.get('close', 0) or 0)),
                'high': Decimal(str(ticker.get('high', 0) or 0)),
                'low': Decimal(str(ticker.get('low', 0) or 0)),
                'volume': Decimal(str(ticker.get('quoteVolume', ticker.get('volume', 0)) or 0)),
                'timestamp': ticker.get('timestamp', int(time.time() * 1000)),
            }
        except Exception as e:
            logger.error(f"Failed to get ticker for {symbol}: {e}")
            return {}

    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV candlestick data."""
        if not self._connected or not self.exchange:
            logger.error("Provider not connected")
            return []

        try:
            since_timestamp = int(since.timestamp() * 1000) if since else None
            normalized_symbol = self.normalize_symbol(symbol)

            # Most exchanges support up to 1000 candles per request
            max_limit = min(limit, 1000)

            ohlcv = self.exchange.fetch_ohlcv(
                normalized_symbol,
                timeframe,
                since_timestamp,
                max_limit
            )

            logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
            return ohlcv

        except Exception as e:
            logger.error(f"Failed to get OHLCV for {symbol}: {e}")
            return []

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates.

        Uses WebSocket if available, otherwise falls back to polling.
        """
        if not self._connected or not self.exchange:
            logger.error("Provider not connected")
            return False

        key = f"ticker_{symbol}"

        # Add callback to subscribers
        if key not in self._subscribers:
            self._subscribers[key] = []
        if callback not in self._subscribers[key]:
            self._subscribers[key].append(callback)

        # Try WebSocket first if supported
        if self.supports_websocket:
            try:
                # CCXT WebSocket support varies by exchange
                # For now, use polling as fallback since WebSocket implementation
                # in CCXT requires exchange-specific handling
                # TODO: Implement native WebSocket when CCXT adds better support
                pass
            except Exception as e:
                logger.warning(f"WebSocket subscription failed, falling back to polling: {e}")

        # Use polling as primary method (more reliable across exchanges)
        if key not in self._polling_threads or not self._polling_threads[key].is_alive():
            self._stop_polling[key] = False

            def poll_ticker():
                """Poll ticker every 2 seconds."""
                while not self._stop_polling.get(key, False):
                    try:
                        ticker_data = self.get_ticker(symbol)
                        if ticker_data and 'last' in ticker_data:
                            # Call all callbacks for this symbol
                            for cb in self._subscribers.get(key, []):
                                try:
                                    cb({
                                        'symbol': symbol,
                                        'price': ticker_data['last'],
                                        'bid': ticker_data.get('bid', 0),
                                        'ask': ticker_data.get('ask', 0),
                                        'volume': ticker_data.get('volume', 0),
                                        'timestamp': ticker_data.get('timestamp'),
                                    })
                                except Exception as e:
                                    logger.error(f"Callback error for {symbol}: {e}")

                        time.sleep(2)  # Poll every 2 seconds
                    except Exception as e:
                        logger.error(f"Ticker polling error for {symbol}: {e}")
                        time.sleep(5)  # Wait longer on error

            thread = threading.Thread(target=poll_ticker, daemon=True)
            thread.start()
            self._polling_threads[key] = thread
            logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
            return True

        return True

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"

        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)

        # Stop the polling thread only when no subscribers remain for the key;
        # removing one of several callbacks must not kill the feed for the rest
        if key not in self._subscribers:
            if key in self._stop_polling:
                self._stop_polling[key] = True
            self._polling_threads.pop(key, None)

        logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol for the selected exchange."""
        if not self.exchange:
            return symbol.replace('-', '/').upper()

        # Basic normalization
        normalized = symbol.replace('-', '/').upper()

        # Try to use exchange's markets to find correct symbol
        try:
            if hasattr(self.exchange, 'markets') and self.exchange.markets:
                # Check if symbol exists
                if normalized in self.exchange.markets:
                    return normalized

                # Try alternative formats
                if '/USD' in normalized:
                    alt_symbol = normalized.replace('/USD', '/USDT')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol

                # For Kraken: try XBT instead of BTC
                if normalized.startswith('BTC/'):
                    alt_symbol = normalized.replace('BTC/', 'XBT/')
                    if alt_symbol in self.exchange.markets:
                        return alt_symbol
        except Exception:
            pass

        # Fallback: return normalized (let CCXT handle errors)
        return normalized

    def get_supported_symbols(self) -> List[str]:
        """Get list of supported trading symbols."""
        if not self._connected or not self.exchange:
            return []

        try:
            markets = self.exchange.load_markets()
            return list(markets.keys())
        except Exception as e:
            logger.error(f"Failed to get supported symbols: {e}")
            return []
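A short sketch of driving this provider directly (normally it is reached via PricingService); the callback body is illustrative.

from src.data.providers.ccxt_provider import CCXTProvider

provider = CCXTProvider(exchange_name="kraken")
if provider.connect():
    print(provider.get_ticker("BTC/USD").get("last"))

    def on_tick(update):
        print(update["symbol"], update["price"])

    provider.subscribe_ticker("BTC/USD", on_tick)  # background polling thread
    # ... later ...
    provider.unsubscribe_ticker("BTC/USD")
    provider.disconnect()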
376
src/data/providers/coingecko_provider.py
Normal file
376
src/data/providers/coingecko_provider.py
Normal file
@@ -0,0 +1,376 @@
"""CoinGecko pricing provider for fallback market data."""

import httpx
import time
import threading
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime

from .base_provider import BasePricingProvider
from src.core.logger import get_logger

logger = get_logger(__name__)


class CoinGeckoProvider(BasePricingProvider):
    """CoinGecko API pricing provider.

    This provider uses CoinGecko's free API tier as a fallback when CCXT
    providers are unavailable. It uses simple REST endpoints that don't
    require authentication.
    """

    BASE_URL = "https://api.coingecko.com/api/v3"

    # CoinGecko coin ID mapping for common symbols
    COIN_ID_MAP = {
        'BTC': 'bitcoin',
        'ETH': 'ethereum',
        'BNB': 'binancecoin',
        'SOL': 'solana',
        'ADA': 'cardano',
        'XRP': 'ripple',
        'DOGE': 'dogecoin',
        'DOT': 'polkadot',
        'MATIC': 'matic-network',
        'AVAX': 'avalanche-2',
        'LINK': 'chainlink',
        'USDT': 'tether',
        'USDC': 'usd-coin',
        'DAI': 'dai',
    }

    # Currency mapping
    CURRENCY_MAP = {
        'USD': 'usd',
        'EUR': 'eur',
        'GBP': 'gbp',
        'JPY': 'jpy',
        'USDT': 'usd',  # CoinGecko uses USD for stablecoins
    }

    def __init__(self, api_key: Optional[str] = None):
        """Initialize CoinGecko provider.

        Args:
            api_key: Optional API key for higher rate limits (free tier doesn't require it)
        """
        super().__init__()
        self.api_key = api_key
        self._client = None
        self._polling_threads: Dict[str, threading.Thread] = {}
        self._stop_polling: Dict[str, bool] = {}
        self._rate_limit_delay = 1.0  # Free tier: ~30-50 calls/minute

    @property
    def name(self) -> str:
        """Provider name."""
        return "CoinGecko"

    @property
    def supports_websocket(self) -> bool:
        """CoinGecko free tier doesn't support WebSocket."""
        return False

    def connect(self) -> bool:
        """Connect to CoinGecko API.

        Returns:
            True if connection successful (just validates API access)
        """
        try:
            self._client = httpx.Client(timeout=10.0)

            # Test connection by fetching Bitcoin price
            response = self._client.get(
                f"{self.BASE_URL}/simple/price",
                params={
                    'ids': 'bitcoin',
                    'vs_currencies': 'usd',
                }
            )

            if response.status_code == 200:
                data = response.json()
                if 'bitcoin' in data:
                    self._connected = True
                    logger.info("Connected to CoinGecko API")
                    return True

            logger.warning("CoinGecko API test failed")
            self._connected = False
            return False

        except Exception as e:
            logger.error(f"Failed to connect to CoinGecko: {e}")
            self._connected = False
            return False

    def disconnect(self):
        """Disconnect from CoinGecko API."""
        # Stop all polling threads
        for symbol in list(self._stop_polling.keys()):
            self._stop_polling[symbol] = True

        # Wait for threads to stop
        for symbol, thread in self._polling_threads.items():
            if thread.is_alive():
                thread.join(timeout=2.0)

        self._polling_threads.clear()
        self._stop_polling.clear()
        self._subscribers.clear()

        if self._client:
            self._client.close()
            self._client = None

        self._connected = False
        logger.info("Disconnected from CoinGecko API")

    def _parse_symbol(self, symbol: str) -> Tuple[Optional[str], Optional[str]]:
        """Parse symbol into coin_id and currency.

        Args:
            symbol: Trading pair like 'BTC/USD' or 'BTC/USDT'

        Returns:
            Tuple of (coin_id, currency) or (None, None) if not found
        """
        parts = symbol.upper().replace('-', '/').split('/')
        if len(parts) != 2:
            return None, None

        base, quote = parts

        # Get coin ID
        coin_id = self.COIN_ID_MAP.get(base)
        if not coin_id:
            # Try lowercase base as fallback
            coin_id = base.lower()

        # Get currency
        currency = self.CURRENCY_MAP.get(quote, quote.lower())

        return coin_id, currency

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data from CoinGecko."""
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return {}

        try:
            coin_id, currency = self._parse_symbol(symbol)
            if not coin_id:
                logger.error(f"Unknown symbol format: {symbol}")
                return {}

            # Fetch current price
            response = self._client.get(
                f"{self.BASE_URL}/simple/price",
                params={
                    'ids': coin_id,
                    'vs_currencies': currency,
                    'include_24hr_change': 'true',
                    'include_24hr_vol': 'true',
                }
            )

            if response.status_code != 200:
                logger.error(f"CoinGecko API error: {response.status_code}")
                return {}

            data = response.json()
            if coin_id not in data:
                logger.error(f"Coin {coin_id} not found in CoinGecko response")
                return {}

            coin_data = data[coin_id]
            price_key = currency.lower()

            if price_key not in coin_data:
                logger.error(f"Currency {currency} not found for {coin_id}")
                return {}

            price = Decimal(str(coin_data[price_key]))

            # Calculate high/low from 24h change if available
            change_key = f"{price_key}_24h_change"
            vol_key = f"{price_key}_24h_vol"

            change_24h = coin_data.get(change_key, 0) or 0
            volume_24h = coin_data.get(vol_key, 0) or 0

            # Estimate high/low from current price and 24h change
            # This is approximate since CoinGecko free tier doesn't provide exact high/low
            current_price = float(price)
            if change_24h:
                estimated_high = current_price * (1 + abs(change_24h / 100) / 2)
                estimated_low = current_price * (1 - abs(change_24h / 100) / 2)
            else:
                estimated_high = current_price
                estimated_low = current_price

            return {
                'symbol': symbol,
                'bid': price,  # CoinGecko doesn't provide bid/ask, use last price
                'ask': price,
                'last': price,
                'high': Decimal(str(estimated_high)),
                'low': Decimal(str(estimated_low)),
                'volume': Decimal(str(volume_24h)),
                'timestamp': int(time.time() * 1000),
            }

        except Exception as e:
            logger.error(f"Failed to get ticker for {symbol}: {e}")
            return {}

    def get_ohlcv(
        self,
        symbol: str,
        timeframe: str = '1h',
        since: Optional[datetime] = None,
        limit: int = 100
    ) -> List[List]:
        """Get OHLCV data from CoinGecko.

        Note: CoinGecko free tier has limited historical data access.
        This method may return empty data for some timeframes.
        """
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return []

        try:
            coin_id, currency = self._parse_symbol(symbol)
            if not coin_id:
                logger.error(f"Unknown symbol format: {symbol}")
                return []

            # CoinGecko uses a days parameter instead of a timeframe.
            # Map timeframe to days
            timeframe_days_map = {
                '1m': 1,   # Last 24 hours, minute data (not available in free tier)
                '5m': 1,
                '15m': 1,
                '30m': 1,
                '1h': 1,   # Last 24 hours, hourly data
                '4h': 7,   # Last 7 days
                '1d': 30,  # Last 30 days
                '1w': 90,  # Last 90 days
            }

            days = timeframe_days_map.get(timeframe, 7)

            # Fetch OHLC data
            response = self._client.get(
                f"{self.BASE_URL}/coins/{coin_id}/ohlc",
                params={
                    'vs_currency': currency,
                    'days': days,
                }
            )

            if response.status_code != 200:
                logger.warning(f"CoinGecko OHLC API returned {response.status_code}")
                return []

            data = response.json()

            # CoinGecko returns: [timestamp_ms, open, high, low, close]
            # We need to convert to: [timestamp_ms, open, high, low, close, volume]
            # Note: CoinGecko OHLC endpoint doesn't include volume

            # Filter by since if provided
            if since:
                since_timestamp = int(since.timestamp() * 1000)
                data = [candle for candle in data if candle[0] >= since_timestamp]

            # Limit results
            data = data[-limit:] if limit else data

            # Add volume as 0 (CoinGecko doesn't provide it in OHLC endpoint)
            ohlcv = [candle + [0] for candle in data]

            logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} from CoinGecko")
            return ohlcv

        except Exception as e:
            logger.error(f"Failed to get OHLCV for {symbol}: {e}")
            return []

    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        """Subscribe to ticker updates via polling."""
        if not self._connected or not self._client:
            logger.error("Provider not connected")
            return False

        key = f"ticker_{symbol}"

        # Add callback to subscribers
        if key not in self._subscribers:
            self._subscribers[key] = []
        if callback not in self._subscribers[key]:
            self._subscribers[key].append(callback)

        # Start polling thread if not already running
        if key not in self._polling_threads or not self._polling_threads[key].is_alive():
            self._stop_polling[key] = False

            def poll_ticker():
                """Poll ticker respecting rate limits."""
                while not self._stop_polling.get(key, False):
                    try:
                        ticker_data = self.get_ticker(symbol)
                        if ticker_data and 'last' in ticker_data:
                            # Call all callbacks
                            for cb in self._subscribers.get(key, []):
                                try:
                                    cb({
                                        'symbol': symbol,
                                        'price': ticker_data['last'],
                                        'bid': ticker_data.get('bid', 0),
                                        'ask': ticker_data.get('ask', 0),
                                        'volume': ticker_data.get('volume', 0),
                                        'timestamp': ticker_data.get('timestamp'),
                                    })
                                except Exception as e:
                                    logger.error(f"Callback error for {symbol}: {e}")

                        # Rate limit: wait between requests (free tier: ~30-50 calls/min)
                        time.sleep(self._rate_limit_delay * 2)  # Poll every 2 seconds
                    except Exception as e:
                        logger.error(f"Ticker polling error for {symbol}: {e}")
                        time.sleep(10)  # Wait longer on error

            thread = threading.Thread(target=poll_ticker, daemon=True)
            thread.start()
            self._polling_threads[key] = thread
            logger.info(f"Subscribed to ticker updates for {symbol} (CoinGecko polling)")
            return True

        return True

    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"

        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)

        # Stop the polling thread only when no subscribers remain for the key;
        # removing one of several callbacks must not kill the feed for the rest
        if key not in self._subscribers:
            if key in self._stop_polling:
                self._stop_polling[key] = True
            self._polling_threads.pop(key, None)

        logger.info(f"Unsubscribed from ticker updates for {symbol}")

    def normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format."""
        # CoinGecko uses coin IDs, so we just normalize the format
        return symbol.replace('-', '/').upper()
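To illustrate the symbol resolution above, a short sketch; the pairs shown are examples, and 'FOO/USD' is deliberately fictitious.

# How _parse_symbol resolves pairs:
#   'BTC/USD'  -> ('bitcoin', 'usd')    # via COIN_ID_MAP / CURRENCY_MAP
#   'ETH-USDT' -> ('ethereum', 'usd')   # dashes normalized, USDT mapped to usd
#   'FOO/USD'  -> ('foo', 'usd')        # unknown bases fall back to lowercase
from src.data.providers.coingecko_provider import CoinGeckoProvider

provider = CoinGeckoProvider()
if provider.connect():
    ticker = provider.get_ticker("BTC/USD")  # bid/ask mirror the last price
    provider.disconnect()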
116
src/data/quality.py
Normal file
116
src/data/quality.py
Normal file
@@ -0,0 +1,116 @@
"""Data quality validation, gap filling, and retention policies."""

from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any

from sqlalchemy.orm import Session

from src.core.database import get_database, MarketData
from src.core.config import get_config
from src.core.logger import get_logger

logger = get_logger(__name__)


class DataQualityManager:
    """Manages data quality and retention."""

    def __init__(self):
        """Initialize data quality manager."""
        self.db = get_database()
        self.config = get_config()
        self.logger = get_logger(__name__)

    def validate_data_quality(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime
    ) -> Dict[str, Any]:
        """Validate data quality.

        Args:
            exchange: Exchange name
            symbol: Trading symbol
            timeframe: Timeframe
            start_date: Start date
            end_date: End date

        Returns:
            Quality report
        """
        session = self.db.get_session()
        try:
            data = session.query(MarketData).filter(
                MarketData.exchange == exchange,
                MarketData.symbol == symbol,
                MarketData.timeframe == timeframe,
                MarketData.timestamp >= start_date,
                MarketData.timestamp <= end_date
            ).order_by(MarketData.timestamp).all()

            if len(data) == 0:
                return {"valid": False, "reason": "No data"}

            # Check for gaps
            gaps = self._detect_gaps(data, timeframe)

            # Check for anomalies
            anomalies = self._detect_anomalies(data)

            return {
                "valid": len(gaps) == 0 and len(anomalies) == 0,
                "total_records": len(data),
                "gaps": len(gaps),
                "anomalies": len(anomalies),
            }
        finally:
            session.close()

    def _detect_gaps(self, data: List[MarketData], timeframe: str) -> List[datetime]:
        """Detect gaps in data.

        Args:
            data: List of market data
            timeframe: Timeframe

        Returns:
            List of gap timestamps
        """
        gaps = []
        # Simplified gap detection
        return gaps

    def _detect_anomalies(self, data: List[MarketData]) -> List[int]:
        """Detect data anomalies.

        Args:
            data: List of market data

        Returns:
            List of anomaly indices
        """
        anomalies = []
        # Simplified anomaly detection
        return anomalies

    def cleanup_old_data(self, days_to_keep: int = 365):
        """Clean up old data based on retention policy.

        Args:
            days_to_keep: Days of data to keep
        """
        session = self.db.get_session()
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep)
            deleted = session.query(MarketData).filter(
                MarketData.timestamp < cutoff_date
            ).delete()
            session.commit()
            logger.info(f"Cleaned up {deleted} old data records")
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to cleanup old data: {e}")
        finally:
            session.close()
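The two _detect_* hooks above are intentionally simplified stubs, so validate_data_quality currently always reports zero gaps and anomalies. For illustration, a hedged sketch of one way gap detection could be filled in; the timeframe-to-timedelta mapping is an assumption for this example, not part of the module.

# Illustrative sketch only: candle-gap detection over sorted timestamps.
from datetime import datetime, timedelta

TIMEFRAME_DELTAS = {  # assumed mapping for this example
    '1m': timedelta(minutes=1),
    '5m': timedelta(minutes=5),
    '15m': timedelta(minutes=15),
    '1h': timedelta(hours=1),
    '1d': timedelta(days=1),
}


def detect_gaps(timestamps, timeframe):
    """Return expected-but-missing timestamps between consecutive candles."""
    step = TIMEFRAME_DELTAS.get(timeframe)
    if step is None or len(timestamps) < 2:
        return []

    gaps = []
    for prev, curr in zip(timestamps, timestamps[1:]):
        expected = prev + step
        while expected < curr:  # every missing candle counts as one gap
            gaps.append(expected)
            expected += step
    return gaps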
225
src/data/redis_cache.py
Normal file
225
src/data/redis_cache.py
Normal file
@@ -0,0 +1,225 @@
"""Redis-based caching for market data and API responses."""

from typing import Any, Optional
import json
from datetime import datetime

from src.core.redis import get_redis_client
from src.core.logger import get_logger

logger = get_logger(__name__)


class RedisCache:
    """Redis-based cache for market data and API responses."""

    # Default TTL values (seconds)
    TTL_TICKER = 5         # Ticker prices are very volatile
    TTL_OHLCV = 60         # OHLCV can be cached longer
    TTL_ORDERBOOK = 2      # Order books change rapidly
    TTL_API_RESPONSE = 30  # General API response cache

    def __init__(self):
        """Initialize Redis cache."""
        self.redis = get_redis_client()

    async def get_ticker(self, symbol: str) -> Optional[dict]:
        """Get cached ticker data.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USD')

        Returns:
            Cached ticker data or None
        """
        key = f"cache:ticker:{symbol.replace('/', '_')}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for ticker:{symbol}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_ticker(self, symbol: str, data: dict, ttl: Optional[int] = None) -> bool:
        """Cache ticker data.

        Args:
            symbol: Trading symbol
            data: Ticker data
            ttl: Time-to-live in seconds (default: TTL_TICKER)

        Returns:
            True if cached successfully
        """
        key = f"cache:ticker:{symbol.replace('/', '_')}"
        ttl = ttl or self.TTL_TICKER
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached ticker:{symbol} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def get_ohlcv(self, symbol: str, timeframe: str, limit: int = 100) -> Optional[list]:
        """Get cached OHLCV data.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            limit: Number of candles

        Returns:
            Cached OHLCV data or None
        """
        key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for ohlcv:{symbol}:{timeframe}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_ohlcv(self, symbol: str, timeframe: str, data: list, limit: int = 100, ttl: Optional[int] = None) -> bool:
        """Cache OHLCV data.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            data: OHLCV data
            limit: Number of candles
            ttl: Time-to-live in seconds

        Returns:
            True if cached successfully
        """
        key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
        ttl = ttl or self.TTL_OHLCV
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached ohlcv:{symbol}:{timeframe} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def get_api_response(self, cache_key: str) -> Optional[dict]:
        """Get cached API response.

        Args:
            cache_key: Unique cache key

        Returns:
            Cached response or None
        """
        key = f"cache:api:{cache_key}"
        try:
            client = self.redis.get_client()
            data = await client.get(key)
            if data:
                logger.debug(f"Cache hit for api:{cache_key}")
                return json.loads(data)
            return None
        except Exception as e:
            logger.warning(f"Redis cache get failed: {e}")
            return None

    async def set_api_response(self, cache_key: str, data: dict, ttl: Optional[int] = None) -> bool:
        """Cache API response.

        Args:
            cache_key: Unique cache key
            data: Response data
            ttl: Time-to-live in seconds

        Returns:
            True if cached successfully
        """
        key = f"cache:api:{cache_key}"
        ttl = ttl or self.TTL_API_RESPONSE
        try:
            client = self.redis.get_client()
            await client.setex(key, ttl, json.dumps(data))
            logger.debug(f"Cached api:{cache_key} for {ttl}s")
            return True
        except Exception as e:
            logger.warning(f"Redis cache set failed: {e}")
            return False

    async def invalidate(self, pattern: str) -> int:
        """Invalidate cache entries matching pattern.

        Args:
            pattern: Redis key pattern (e.g., 'cache:ticker:*')

        Returns:
            Number of keys deleted
        """
        try:
            client = self.redis.get_client()
            keys = []
            async for key in client.scan_iter(match=pattern):
                keys.append(key)

            if keys:
                deleted = await client.delete(*keys)
                logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
                return deleted
            return 0
        except Exception as e:
            logger.warning(f"Redis cache invalidation failed: {e}")
            return 0

    async def get_stats(self) -> dict:
        """Get cache statistics.

        Returns:
            Cache statistics
        """
        try:
            client = self.redis.get_client()
            info = await client.info('memory')

            # Count cached items by type
            ticker_count = 0
            ohlcv_count = 0
            api_count = 0

            async for key in client.scan_iter(match='cache:ticker:*'):
                ticker_count += 1
            async for key in client.scan_iter(match='cache:ohlcv:*'):
                ohlcv_count += 1
            async for key in client.scan_iter(match='cache:api:*'):
                api_count += 1

            return {
                "memory_used": info.get('used_memory_human', 'N/A'),
                "ticker_entries": ticker_count,
                "ohlcv_entries": ohlcv_count,
                "api_entries": api_count,
                "total_entries": ticker_count + ohlcv_count + api_count
            }
        except Exception as e:
            logger.warning(f"Failed to get cache stats: {e}")
            return {"error": str(e)}


# Global cache instance
_redis_cache: Optional[RedisCache] = None


def get_redis_cache() -> RedisCache:
    """Get global Redis cache instance."""
    global _redis_cache
    if _redis_cache is None:
        _redis_cache = RedisCache()
    return _redis_cache
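A hedged cache-aside sketch for the class above; fetch_live is a hypothetical callable supplied by the caller, and its result must be JSON-serializable since set_ticker uses json.dumps.

# Illustrative: try Redis first, fall back to a live fetch on a miss.
from src.data.redis_cache import get_redis_cache


async def get_ticker_cached(symbol: str, fetch_live):
    cache = get_redis_cache()
    ticker = await cache.get_ticker(symbol)
    if ticker is None:
        ticker = fetch_live(symbol)        # e.g. a live provider lookup
        await cache.set_ticker(symbol, ticker)  # values must be JSON-serializable
    return ticker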
75
src/data/storage.py
Normal file
75
src/data/storage.py
Normal file
@@ -0,0 +1,75 @@
"""Data persistence."""

from decimal import Decimal
from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from src.core.database import get_database, MarketData
from src.core.logger import get_logger

logger = get_logger(__name__)


class DataStorage:
    """Manages data storage and persistence."""

    def __init__(self):
        """Initialize data storage."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    def store_ohlcv(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        timestamp: datetime,
        open: Decimal,
        high: Decimal,
        low: Decimal,
        close: Decimal,
        volume: Decimal
    ):
        """Store OHLCV data.

        Args:
            exchange: Exchange name
            symbol: Trading symbol
            timeframe: Timeframe
            timestamp: Timestamp
            open: Open price
            high: High price
            low: Low price
            close: Close price
            volume: Volume
        """
        session = self.db.get_session()
        try:
            # Check if exists
            existing = session.query(MarketData).filter_by(
                exchange=exchange,
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp
            ).first()

            if not existing:
                market_data = MarketData(
                    exchange=exchange,
                    symbol=symbol,
                    timeframe=timeframe,
                    timestamp=timestamp,
                    open=open,
                    high=high,
                    low=low,
                    close=close,
                    volume=volume
                )
                session.add(market_data)
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to store OHLCV data: {e}")
        finally:
            session.close()
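Finally, a small sketch of persisting one candle through DataStorage; the candle values and symbol are made up for the example.

from datetime import datetime
from decimal import Decimal

from src.data.storage import DataStorage

storage = DataStorage()
storage.store_ohlcv(
    exchange="kraken",
    symbol="BTC/USD",
    timeframe="1h",
    timestamp=datetime(2024, 1, 1, 12, 0),
    open=Decimal("42000.0"),
    high=Decimal("42350.5"),
    low=Decimal("41880.0"),
    close=Decimal("42120.25"),
    volume=Decimal("318.4"),
)  # duplicate (exchange, symbol, timeframe, timestamp) rows are skipped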