Local changes: update model training, remove debug instrumentation, and improve configuration

kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

src/data/__init__.py Normal file

@@ -0,0 +1,27 @@
"""Data collection and storage module.
Provides:
- DataCollector: Real-time market data collection
- NewsCollector: Crypto news headline aggregation for sentiment analysis
- TechnicalIndicators: Technical analysis indicators
- DataStorage: Data persistence utilities
- DataQualityManager: Data quality checks
"""
from .collector import DataCollector
from .news_collector import NewsCollector, NewsItem, NewsSource, get_news_collector
from .indicators import TechnicalIndicators, get_indicators
from .storage import DataStorage
from .quality import DataQualityManager
__all__ = [
"DataCollector",
"NewsCollector",
"NewsItem",
"NewsSource",
"get_news_collector",
"TechnicalIndicators",
"get_indicators",
"DataStorage",
"DataQualityManager",
]
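For orientation, consumers are expected to import the shared singletons through the package root rather than the individual submodules. A minimal sketch, assuming `src` is on the import path:

```python
# Hypothetical usage; assumes the src package is importable
from src.data import get_indicators, get_news_collector

indicators = get_indicators()  # shared TechnicalIndicators instance
news = get_news_collector()    # shared NewsCollector instance
print(type(indicators).__name__, type(news).__name__)
```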

src/data/cache_manager.py Normal file

@@ -0,0 +1,221 @@
"""Intelligent caching system for pricing data."""
import time
from typing import Dict, Any, Optional
from collections import OrderedDict
from src.core.logger import get_logger
logger = get_logger(__name__)
class CacheEntry:
"""Cache entry with TTL support."""
def __init__(self, data: Any, ttl: float):
"""Initialize cache entry.
Args:
data: Cached data
ttl: Time to live in seconds
"""
self.data = data
self.expires_at = time.time() + ttl
self.created_at = time.time()
self.access_count = 0
self.last_accessed = time.time()
def is_expired(self) -> bool:
"""Check if entry is expired.
Returns:
True if expired
"""
return time.time() > self.expires_at
def touch(self):
"""Update access statistics."""
self.access_count += 1
self.last_accessed = time.time()
def age(self) -> float:
"""Get age of entry in seconds.
Returns:
Age in seconds
"""
return time.time() - self.created_at
class CacheManager:
"""Intelligent cache manager with TTL and size limits.
Implements LRU (Least Recently Used) eviction when size limit is reached.
"""
def __init__(
self,
default_ttl: float = 60.0,
max_size: int = 1000,
ticker_ttl: float = 2.0,
ohlcv_ttl: float = 60.0
):
"""Initialize cache manager.
Args:
default_ttl: Default TTL in seconds
max_size: Maximum number of cache entries
ticker_ttl: TTL for ticker data in seconds
ohlcv_ttl: TTL for OHLCV data in seconds
"""
self.default_ttl = default_ttl
self.max_size = max_size
self.ticker_ttl = ticker_ttl
self.ohlcv_ttl = ohlcv_ttl
# Use OrderedDict for LRU eviction
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._hits = 0
self._misses = 0
self._evictions = 0
self.logger = get_logger(__name__)
def get(self, key: str) -> Optional[Any]:
"""Get value from cache.
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
# Clean expired entries
self._cleanup_expired()
if key not in self._cache:
self._misses += 1
return None
entry = self._cache[key]
if entry.is_expired():
# Remove expired entry
del self._cache[key]
self._misses += 1
return None
# Update access (move to end for LRU)
entry.touch()
self._cache.move_to_end(key)
self._hits += 1
return entry.data
def set(
self,
key: str,
value: Any,
ttl: Optional[float] = None,
cache_type: Optional[str] = None
):
"""Set value in cache.
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds (uses type-specific or default if None)
cache_type: Type of cache ('ticker' or 'ohlcv') for type-specific TTL
"""
# Determine TTL
if ttl is None:
if cache_type == 'ticker':
ttl = self.ticker_ttl
elif cache_type == 'ohlcv':
ttl = self.ohlcv_ttl
else:
ttl = self.default_ttl
# Check if we need to evict
if len(self._cache) >= self.max_size and key not in self._cache:
self._evict_lru()
# Create entry
entry = CacheEntry(value, ttl)
# Add or update
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = entry
def delete(self, key: str) -> bool:
"""Delete entry from cache.
Args:
key: Cache key
Returns:
True if entry was deleted, False if not found
"""
if key in self._cache:
del self._cache[key]
return True
return False
def clear(self):
"""Clear all cache entries."""
self._cache.clear()
self.logger.info("Cache cleared")
def _cleanup_expired(self):
"""Remove expired entries from cache."""
expired_keys = [
key for key, entry in self._cache.items()
if entry.is_expired()
]
for key in expired_keys:
del self._cache[key]
def _evict_lru(self):
"""Evict least recently used entry."""
if self._cache:
# Remove oldest (first) entry
self._cache.popitem(last=False)
self._evictions += 1
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
total_requests = self._hits + self._misses
hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0
# Calculate average age
if self._cache:
avg_age = sum(entry.age() for entry in self._cache.values()) / len(self._cache)
else:
avg_age = 0.0
return {
'size': len(self._cache),
'max_size': self.max_size,
'hits': self._hits,
'misses': self._misses,
'hit_rate': round(hit_rate, 2),
'evictions': self._evictions,
'avg_age_seconds': round(avg_age, 2),
}
def invalidate_pattern(self, pattern: str):
"""Invalidate entries matching a pattern.
Args:
pattern: String pattern to match (simple substring match)
"""
keys_to_delete = [key for key in self._cache.keys() if pattern in key]
for key in keys_to_delete:
del self._cache[key]
if keys_to_delete:
self.logger.info(f"Invalidated {len(keys_to_delete)} cache entries matching '{pattern}'")

src/data/collector.py Normal file

@@ -0,0 +1,139 @@
"""Real-time data collection system with WebSocket support."""
import asyncio
from decimal import Decimal
from typing import Dict, Optional, Callable, List
from datetime import datetime
from sqlalchemy import select
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
from .pricing_service import get_pricing_service
logger = get_logger(__name__)
class DataCollector:
"""Collects real-time market data using the unified pricing service."""
def __init__(self):
"""Initialize data collector."""
self.db = get_database()
self.logger = get_logger(__name__)
self._callbacks: Dict[str, List[Callable]] = {}
self._running = False
self._pricing_service = get_pricing_service()
def subscribe(
self,
exchange_id: Optional[int] = None,
symbol: str = "",
callback: Optional[Callable] = None
):
"""Subscribe to real-time data.
Args:
exchange_id: Exchange ID (deprecated, kept for backward compatibility)
symbol: Trading symbol
callback: Callback function(data)
"""
if not symbol or not callback:
logger.warning("subscribe called without symbol or callback")
return
key = f"pricing:{symbol}"
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
# Subscribe via pricing service
def wrapped_callback(data):
"""Wrap callback to maintain backward compatibility."""
for cb in self._callbacks.get(key, []):
try:
cb(data)
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
self._pricing_service.subscribe_ticker(symbol, wrapped_callback)
async def store_ohlcv(
self,
exchange: str,
symbol: str,
timeframe: str,
ohlcv_data: List[List]
):
"""Store OHLCV data in database.
Args:
exchange: Exchange name (can be provider name like 'CCXT-Kraken' or 'CoinGecko')
symbol: Trading symbol
timeframe: Timeframe
ohlcv_data: List of [timestamp, open, high, low, close, volume]
"""
async with self.db.get_session() as session:
try:
for candle in ohlcv_data:
timestamp = datetime.fromtimestamp(candle[0] / 1000)
# Check if already exists
stmt = select(MarketData).filter_by(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if not existing:
market_data = MarketData(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp,
open=Decimal(str(candle[1])),
high=Decimal(str(candle[2])),
low=Decimal(str(candle[3])),
close=Decimal(str(candle[4])),
volume=Decimal(str(candle[5]))
)
session.add(market_data)
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Failed to store OHLCV data: {e}")
def get_ohlcv_from_pricing_service(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data from pricing service.
Args:
symbol: Trading symbol
timeframe: Timeframe
since: Start datetime
limit: Number of candles
Returns:
List of OHLCV candles
"""
return self._pricing_service.get_ohlcv(symbol, timeframe, since, limit)
# Global data collector
_data_collector: Optional[DataCollector] = None
def get_data_collector() -> DataCollector:
"""Get global data collector instance."""
global _data_collector
if _data_collector is None:
_data_collector = DataCollector()
return _data_collector
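A sketch of the intended call pattern; the symbol, exchange label, and candle values are made up, and `store_ohlcv` additionally needs a configured database plus the pricing stack (ccxt) installed:

```python
import asyncio

from src.data.collector import get_data_collector

collector = get_data_collector()

def on_tick(data):
    # Invoked by the pricing service on each ticker update
    print("tick:", data)

collector.subscribe(symbol="BTC/USD", callback=on_tick)

# Candles are [timestamp_ms, open, high, low, close, volume]
candles = [[1735171200000, 97000.0, 97500.0, 96800.0, 97200.0, 12.5]]
asyncio.run(collector.store_ohlcv("CCXT-Kraken", "BTC/USD", "1h", candles))
```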

src/data/health_monitor.py Normal file

@@ -0,0 +1,317 @@
"""Health monitoring and failover management for pricing providers."""
import time
from typing import Dict, List, Optional, Any
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import deque
from src.core.logger import get_logger
logger = get_logger(__name__)
class HealthStatus(Enum):
"""Provider health status."""
HEALTHY = "healthy"
DEGRADED = "degraded"
UNHEALTHY = "unhealthy"
UNKNOWN = "unknown"
@dataclass
class HealthMetrics:
"""Health metrics for a provider."""
status: HealthStatus = HealthStatus.UNKNOWN
last_check: Optional[datetime] = None
last_success: Optional[datetime] = None
last_failure: Optional[datetime] = None
success_count: int = 0
failure_count: int = 0
response_times: deque = field(default_factory=lambda: deque(maxlen=100))
consecutive_failures: int = 0
circuit_breaker_open: bool = False
circuit_breaker_opened_at: Optional[datetime] = None
def record_success(self, response_time: float):
"""Record a successful operation.
Args:
response_time: Response time in seconds
"""
self.status = HealthStatus.HEALTHY
self.last_check = datetime.utcnow()
self.last_success = datetime.utcnow()
self.success_count += 1
self.response_times.append(response_time)
self.consecutive_failures = 0
self.circuit_breaker_open = False
self.circuit_breaker_opened_at = None
def record_failure(self):
"""Record a failed operation."""
self.last_check = datetime.utcnow()
self.last_failure = datetime.utcnow()
self.failure_count += 1
self.consecutive_failures += 1
# Open circuit breaker after 5 consecutive failures
if self.consecutive_failures >= 5:
if not self.circuit_breaker_open:
self.circuit_breaker_open = True
self.circuit_breaker_opened_at = datetime.utcnow()
logger.warning(f"Circuit breaker opened after {self.consecutive_failures} failures")
# Update status based on failure rate
total = self.success_count + self.failure_count
if total > 0:
failure_rate = self.failure_count / total
if failure_rate > 0.5:
self.status = HealthStatus.UNHEALTHY
elif failure_rate > 0.2:
self.status = HealthStatus.DEGRADED
def get_avg_response_time(self) -> float:
"""Get average response time.
Returns:
Average response time in seconds, or 0.0 if no data
"""
if not self.response_times:
return 0.0
return sum(self.response_times) / len(self.response_times)
def should_attempt(self, circuit_breaker_timeout: int = 60) -> bool:
"""Check if we should attempt to use this provider.
Args:
circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
Returns:
True if we should attempt, False otherwise
"""
if not self.circuit_breaker_open:
return True
# Check if timeout has passed
if self.circuit_breaker_opened_at:
elapsed = (datetime.utcnow() - self.circuit_breaker_opened_at).total_seconds()
if elapsed >= circuit_breaker_timeout:
# Half-open state: allow one attempt
logger.info("Circuit breaker half-open, allowing attempt")
return True
return False
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API responses."""
return {
'status': self.status.value,
'last_check': self.last_check.isoformat() if self.last_check else None,
'last_success': self.last_success.isoformat() if self.last_success else None,
'last_failure': self.last_failure.isoformat() if self.last_failure else None,
'success_count': self.success_count,
'failure_count': self.failure_count,
'avg_response_time': round(self.get_avg_response_time(), 3),
'consecutive_failures': self.consecutive_failures,
'circuit_breaker_open': self.circuit_breaker_open,
'circuit_breaker_opened_at': (
self.circuit_breaker_opened_at.isoformat()
if self.circuit_breaker_opened_at else None
),
}
class HealthMonitor:
"""Monitors health of pricing providers and manages failover."""
def __init__(
self,
circuit_breaker_timeout: int = 60,
min_success_rate: float = 0.8,
max_avg_response_time: float = 5.0
):
"""Initialize health monitor.
Args:
circuit_breaker_timeout: Seconds to wait before retrying after circuit breaker opens
min_success_rate: Minimum success rate to be considered healthy (0.0-1.0)
max_avg_response_time: Maximum average response time in seconds to be considered healthy
"""
self.circuit_breaker_timeout = circuit_breaker_timeout
self.min_success_rate = min_success_rate
self.max_avg_response_time = max_avg_response_time
self._metrics: Dict[str, HealthMetrics] = {}
self.logger = get_logger(__name__)
def record_success(self, provider_name: str, response_time: float):
"""Record a successful operation for a provider.
Args:
provider_name: Name of the provider
response_time: Response time in seconds
"""
if provider_name not in self._metrics:
self._metrics[provider_name] = HealthMetrics()
self._metrics[provider_name].record_success(response_time)
self.logger.debug(f"Recorded success for {provider_name} ({response_time:.3f}s)")
def record_failure(self, provider_name: str):
"""Record a failed operation for a provider.
Args:
provider_name: Name of the provider
"""
if provider_name not in self._metrics:
self._metrics[provider_name] = HealthMetrics()
self._metrics[provider_name].record_failure()
self.logger.warning(f"Recorded failure for {provider_name} "
f"(consecutive: {self._metrics[provider_name].consecutive_failures})")
def is_healthy(self, provider_name: str) -> bool:
"""Check if a provider is healthy.
Args:
provider_name: Name of the provider
Returns:
True if provider is healthy
"""
if provider_name not in self._metrics:
return True # Assume healthy if no metrics yet
metrics = self._metrics[provider_name]
# Check circuit breaker
if not metrics.should_attempt(self.circuit_breaker_timeout):
return False
# Check status
if metrics.status == HealthStatus.UNHEALTHY:
return False
# Check success rate
total = metrics.success_count + metrics.failure_count
if total > 10: # Need minimum data points
success_rate = metrics.success_count / total
if success_rate < self.min_success_rate:
return False
# Check response time
if metrics.response_times:
avg_response_time = metrics.get_avg_response_time()
if avg_response_time > self.max_avg_response_time:
return False
return True
def get_health_status(self, provider_name: str) -> HealthStatus:
"""Get health status for a provider.
Args:
provider_name: Name of the provider
Returns:
Health status
"""
if provider_name not in self._metrics:
return HealthStatus.UNKNOWN
return self._metrics[provider_name].status
def get_metrics(self, provider_name: str) -> Optional[HealthMetrics]:
"""Get health metrics for a provider.
Args:
provider_name: Name of the provider
Returns:
Health metrics or None if not found
"""
return self._metrics.get(provider_name)
def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
"""Get all provider metrics.
Returns:
Dictionary mapping provider names to their metrics
"""
return {
name: metrics.to_dict()
for name, metrics in self._metrics.items()
}
def select_healthiest(self, provider_names: List[str]) -> Optional[str]:
"""Select the healthiest provider from a list.
Args:
provider_names: List of provider names to choose from
Returns:
Name of healthiest provider, or None if none are healthy
"""
healthy_providers = [
name for name in provider_names
if self.is_healthy(name)
]
if not healthy_providers:
return None
# Sort by health metrics (better providers first)
def health_score(name: str) -> tuple:
metrics = self._metrics.get(name)
if not metrics:
return (1, 0, 0)  # No metrics recorded yet: ranks below healthy/degraded providers
# Score: (status_weight, -avg_response_time, success_rate)
status_weights = {
HealthStatus.HEALTHY: 3,
HealthStatus.DEGRADED: 2,
HealthStatus.UNHEALTHY: 1,
HealthStatus.UNKNOWN: 0,
}
success_rate = (
metrics.success_count / (metrics.success_count + metrics.failure_count)
if (metrics.success_count + metrics.failure_count) > 0
else 0.0
)
return (
status_weights.get(metrics.status, 0),
-metrics.get_avg_response_time(),
success_rate
)
sorted_providers = sorted(healthy_providers, key=health_score, reverse=True)
return sorted_providers[0] if sorted_providers else None
def reset_circuit_breaker(self, provider_name: str):
"""Manually reset circuit breaker for a provider.
Args:
provider_name: Name of the provider
"""
if provider_name in self._metrics:
self._metrics[provider_name].circuit_breaker_open = False
self._metrics[provider_name].circuit_breaker_opened_at = None
self._metrics[provider_name].consecutive_failures = 0
self.logger.info(f"Circuit breaker reset for {provider_name}")
def reset_metrics(self, provider_name: Optional[str] = None):
"""Reset metrics for a provider or all providers.
Args:
provider_name: Name of provider to reset, or None to reset all
"""
if provider_name:
if provider_name in self._metrics:
del self._metrics[provider_name]
self.logger.info(f"Reset metrics for {provider_name}")
else:
self._metrics.clear()
self.logger.info("Reset all provider metrics")

src/data/indicators.py Normal file

@@ -0,0 +1,569 @@
"""Comprehensive technical indicator library using pandas-ta and TA-Lib."""
import pandas as pd
import numpy as np
from typing import Optional, Dict, Any, List
# Try to import pandas_ta, but handle if numba is missing
try:
import pandas_ta as ta
PANDAS_TA_AVAILABLE = True
except ImportError:
PANDAS_TA_AVAILABLE = False
ta = None
import warnings
warnings.warn("pandas-ta not available (possibly a missing numba dependency); falling back to basic implementations")
try:
import talib
TALIB_AVAILABLE = True
except ImportError:
TALIB_AVAILABLE = False
from src.core.logger import get_logger
logger = get_logger(__name__)
class TechnicalIndicators:
"""Technical indicators library."""
def __init__(self):
"""Initialize indicators library."""
self.talib_available = TALIB_AVAILABLE
# Trend Indicators
def sma(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Simple Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.sma(data, length=period)
return data.rolling(window=period).mean()
def ema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.ema(data, length=period)
return data.ewm(span=period, adjust=False).mean()
def wma(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Weighted Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.wma(data, length=period)
# Basic WMA implementation
weights = np.arange(1, period + 1)
return data.rolling(window=period).apply(lambda x: np.dot(x, weights) / weights.sum(), raw=True)
def dema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Double Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.dema(data, length=period)
ema1 = self.ema(data, period)
return 2 * ema1 - self.ema(ema1, period)
def tema(self, data: pd.Series, period: int = 20) -> pd.Series:
"""Triple Exponential Moving Average."""
if PANDAS_TA_AVAILABLE:
return ta.tema(data, length=period)
ema1 = self.ema(data, period)
ema2 = self.ema(ema1, period)
ema3 = self.ema(ema2, period)
return 3 * ema1 - 3 * ema2 + ema3
# Momentum Indicators
def rsi(self, data: pd.Series, period: int = 14) -> pd.Series:
"""Relative Strength Index."""
if self.talib_available:
return pd.Series(talib.RSI(data.values, timeperiod=period), index=data.index)
if PANDAS_TA_AVAILABLE:
return ta.rsi(data, length=period)
# Basic RSI implementation
delta = data.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
return 100 - (100 / (1 + rs))
def macd(
self,
data: pd.Series,
fast: int = 12,
slow: int = 26,
signal: int = 9
) -> Dict[str, pd.Series]:
"""MACD (Moving Average Convergence Divergence)."""
if self.talib_available:
macd, signal_line, histogram = talib.MACD(
data.values, fastperiod=fast, slowperiod=slow, signalperiod=signal
)
return {
'macd': pd.Series(macd, index=data.index),
'signal': pd.Series(signal_line, index=data.index),
'histogram': pd.Series(histogram, index=data.index),
}
if not PANDAS_TA_AVAILABLE or ta is None:
# Basic MACD implementation fallback
ema_fast = self.ema(data, fast)
ema_slow = self.ema(data, slow)
macd_line = ema_fast - ema_slow
signal_line = self.ema(macd_line.dropna(), signal)
histogram = macd_line - signal_line
return {
'macd': macd_line,
'signal': signal_line,
'histogram': histogram,
}
result = ta.macd(data, fast=fast, slow=slow, signal=signal)
return {
'macd': result[f'MACD_{fast}_{slow}_{signal}'],
'signal': result[f'MACDs_{fast}_{slow}_{signal}'],
'histogram': result[f'MACDh_{fast}_{slow}_{signal}'],
}
def stochastic(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
k_period: int = 14,
d_period: int = 3
) -> Dict[str, pd.Series]:
"""Stochastic Oscillator."""
if self.talib_available:
slowk, slowd = talib.STOCH(
high.values, low.values, close.values,
fastk_period=k_period, slowk_period=d_period, slowd_period=d_period
)
return {
'k': pd.Series(slowk, index=close.index),
'd': pd.Series(slowd, index=close.index),
}
if PANDAS_TA_AVAILABLE and ta is not None:
result = ta.stoch(high, low, close, k=k_period, d=d_period)
return {
'k': result[f'STOCHk_{k_period}_{d_period}_{d_period}'],
'd': result[f'STOCHd_{k_period}_{d_period}_{d_period}'],
}
# Basic Stochastic implementation
lowest_low = low.rolling(window=k_period).min()
highest_high = high.rolling(window=k_period).max()
k = 100 * ((close - lowest_low) / (highest_high - lowest_low))
d = k.rolling(window=d_period).mean()
return {'k': k, 'd': d}
def cci(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 20
) -> pd.Series:
"""Commodity Channel Index."""
if self.talib_available:
return pd.Series(
talib.CCI(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE:
return ta.cci(high, low, close, length=period)
# Basic CCI implementation
tp = (high + low + close) / 3
sma_tp = tp.rolling(window=period).mean()
mad = tp.rolling(window=period).apply(lambda x: np.abs(x - x.mean()).mean(), raw=True)
return (tp - sma_tp) / (0.015 * mad)
def williams_r(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Williams %R."""
if self.talib_available:
return pd.Series(
talib.WILLR(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE:
return ta.willr(high, low, close, length=period)
# Basic Williams %R implementation
highest_high = high.rolling(window=period).max()
lowest_low = low.rolling(window=period).min()
return -100 * ((highest_high - close) / (highest_high - lowest_low))
# Volatility Indicators
def bollinger_bands(
self,
data: pd.Series,
period: int = 20,
std_dev: float = 2.0
) -> Dict[str, pd.Series]:
"""Bollinger Bands."""
if self.talib_available:
upper, middle, lower = talib.BBANDS(
data.values, timeperiod=period, nbdevup=std_dev, nbdevdn=std_dev
)
return {
'upper': pd.Series(upper, index=data.index),
'middle': pd.Series(middle, index=data.index),
'lower': pd.Series(lower, index=data.index),
}
if PANDAS_TA_AVAILABLE:
result = ta.bbands(data, length=period, std=std_dev)
return {
'upper': result[f'BBU_{period}_{std_dev}'],
'middle': result[f'BBM_{period}_{std_dev}'],
'lower': result[f'BBL_{period}_{std_dev}'],
}
# Basic Bollinger Bands implementation
middle = self.sma(data, period)
std = data.rolling(window=period).std()
return {
'upper': middle + (std * std_dev),
'middle': middle,
'lower': middle - (std * std_dev),
}
def atr(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Average True Range."""
if self.talib_available:
return pd.Series(
talib.ATR(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.atr(high, low, close, length=period)
# Basic ATR implementation
prev_close = close.shift(1)
tr1 = high - low
tr2 = abs(high - prev_close)
tr3 = abs(low - prev_close)
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
return tr.rolling(window=period).mean()
def keltner_channels(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 20,
multiplier: float = 2.0
) -> Dict[str, pd.Series]:
"""Keltner Channels."""
if PANDAS_TA_AVAILABLE and ta is not None:
# ta.kc returns a DataFrame (lower, basis, upper columns); map it to this method's dict API
kc = ta.kc(high, low, close, length=period, scalar=multiplier)
lower_col, middle_col, upper_col = kc.columns[:3]
return {
'lower': kc[lower_col],
'middle': kc[middle_col],
'upper': kc[upper_col],
}
# Basic Keltner Channels implementation
middle = self.ema(close, period)
atr_val = self.atr(high, low, close, period)
return {
'lower': middle - (multiplier * atr_val),
'middle': middle,
'upper': middle + (multiplier * atr_val),
}
# Volume Indicators
def obv(self, close: pd.Series, volume: pd.Series) -> pd.Series:
"""On-Balance Volume."""
if self.talib_available:
return pd.Series(
talib.OBV(close.values, volume.values),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.obv(close, volume)
# Basic OBV implementation
price_change = close.diff()
obv = pd.Series(index=close.index, dtype=float)
obv.iloc[0] = 0
for i in range(1, len(close)):
if price_change.iloc[i] > 0:
obv.iloc[i] = obv.iloc[i-1] + volume.iloc[i]
elif price_change.iloc[i] < 0:
obv.iloc[i] = obv.iloc[i-1] - volume.iloc[i]
else:
obv.iloc[i] = obv.iloc[i-1]
return obv
def vwap(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
volume: pd.Series
) -> pd.Series:
"""Volume Weighted Average Price."""
if PANDAS_TA_AVAILABLE and ta is not None:
return ta.vwap(high, low, close, volume)
# Basic VWAP implementation
typical_price = (high + low + close) / 3
cumulative_tp_vol = (typical_price * volume).cumsum()
cumulative_vol = volume.cumsum()
return cumulative_tp_vol / cumulative_vol
# Advanced Indicators
def ichimoku(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
tenkan: int = 9,
kijun: int = 26,
senkou: int = 52
) -> Dict[str, pd.Series]:
"""Ichimoku Cloud."""
if PANDAS_TA_AVAILABLE and ta is not None:
# ta.ichimoku returns (visible_df, span_df); column names track the configured periods
visible, _ = ta.ichimoku(high, low, close, tenkan=tenkan, kijun=kijun, senkou=senkou)
return {
'tenkan': visible[f'ITS_{tenkan}'],
'kijun': visible[f'IKS_{kijun}'],
'senkou_a': visible[f'ISA_{tenkan}'],
'senkou_b': visible[f'ISB_{kijun}'],
'chikou': visible[f'ICS_{kijun}'],
}
# Basic Ichimoku implementation
tenkan_sen = (high.rolling(window=tenkan).max() + low.rolling(window=tenkan).min()) / 2
kijun_sen = (high.rolling(window=kijun).max() + low.rolling(window=kijun).min()) / 2
senkou_a = ((tenkan_sen + kijun_sen) / 2).shift(kijun)
senkou_b = ((high.rolling(window=senkou).max() + low.rolling(window=senkou).min()) / 2).shift(kijun)
chikou = close.shift(-kijun)
return {
'tenkan': tenkan_sen,
'kijun': kijun_sen,
'senkou_a': senkou_a,
'senkou_b': senkou_b,
'chikou': chikou,
}
def adx(
self,
high: pd.Series,
low: pd.Series,
close: pd.Series,
period: int = 14
) -> pd.Series:
"""Average Directional Index."""
if self.talib_available:
return pd.Series(
talib.ADX(high.values, low.values, close.values, timeperiod=period),
index=close.index
)
if PANDAS_TA_AVAILABLE and ta is not None:
# ta.adx returns a DataFrame (ADX, DMP, DMN); extract just the ADX column
adx_df = ta.adx(high, low, close, length=period)
return adx_df[f'ADX_{period}']
# Basic ADX implementation
# +DM/-DM: an up (down) move counts only when it exceeds the opposite move and is positive
up_move = high.diff()
down_move = -low.diff()
plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0)
atr_val = self.atr(high, low, close, period)
plus_di = 100 * (plus_dm.rolling(window=period).mean() / atr_val)
minus_di = 100 * (minus_dm.rolling(window=period).mean() / atr_val)
dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di)
return dx.rolling(window=period).mean()
def detect_divergence(
self,
prices: pd.Series,
indicator: pd.Series,
lookback: int = 20,
min_swings: int = 2
) -> Dict[str, Any]:
"""Detect divergence between price and indicator.
Divergence occurs when price makes new highs/lows but indicator doesn't,
or vice versa. This is a powerful reversal signal.
Args:
prices: Price series
indicator: Indicator series (e.g., RSI, MACD)
lookback: Lookback period for finding swings
min_swings: Minimum number of swings to detect divergence
Returns:
Dictionary with divergence information:
{
'type': 'bullish', 'bearish', or None
'confidence': 0.0 to 1.0
'price_swing_high': price at high swing
'price_swing_low': price at low swing
'indicator_swing_high': indicator at high swing
'indicator_swing_low': indicator at low swing
}
"""
if len(prices) < lookback * 2 or len(indicator) < lookback * 2:
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
# Find local extrema (swings)
def find_swings(series: pd.Series, lookback: int):
"""Find local maxima and minima."""
highs = []
lows = []
for i in range(lookback, len(series) - lookback):
window = series.iloc[i-lookback:i+lookback+1]
center = series.iloc[i]
# Local maximum
if center == window.max():
highs.append((i, center))
# Local minimum
elif center == window.min():
lows.append((i, center))
return highs, lows
price_highs, price_lows = find_swings(prices, lookback)
indicator_highs, indicator_lows = find_swings(indicator, lookback)
# Need at least min_swings swings to detect divergence
if len(price_highs) < min_swings or len(price_lows) < min_swings:
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
# Check for bearish divergence (price makes higher high, indicator makes lower high)
if len(price_highs) >= 2 and len(indicator_highs) >= 2:
recent_price_high = price_highs[-1][1]
prev_price_high = price_highs[-2][1]
recent_indicator_high = indicator_highs[-1][1]
prev_indicator_high = indicator_highs[-2][1]
# Price higher high but indicator lower high = bearish divergence
if recent_price_high > prev_price_high and recent_indicator_high < prev_indicator_high:
confidence = min(1.0, abs(recent_price_high - prev_price_high) / prev_price_high * 10)
return {
'type': 'bearish',
'confidence': confidence,
'price_swing_high': (price_highs[-2][0], price_highs[-1][0]),
'price_swing_low': None,
'indicator_swing_high': (indicator_highs[-2][0], indicator_highs[-1][0]),
'indicator_swing_low': None
}
# Check for bullish divergence (price makes lower low, indicator makes higher low)
if len(price_lows) >= 2 and len(indicator_lows) >= 2:
recent_price_low = price_lows[-1][1]
prev_price_low = price_lows[-2][1]
recent_indicator_low = indicator_lows[-1][1]
prev_indicator_low = indicator_lows[-2][1]
# Price lower low but indicator higher low = bullish divergence
if recent_price_low < prev_price_low and recent_indicator_low > prev_indicator_low:
confidence = min(1.0, abs(prev_price_low - recent_price_low) / prev_price_low * 10)
return {
'type': 'bullish',
'confidence': confidence,
'price_swing_high': None,
'price_swing_low': (price_lows[-2][0], price_lows[-1][0]),
'indicator_swing_high': None,
'indicator_swing_low': (indicator_lows[-2][0], indicator_lows[-1][0])
}
return {
'type': None,
'confidence': 0.0,
'price_swing_high': None,
'price_swing_low': None,
'indicator_swing_high': None,
'indicator_swing_low': None
}
def calculate_all(
self,
df: pd.DataFrame,
indicators: Optional[List[str]] = None
) -> pd.DataFrame:
"""Calculate multiple indicators at once.
Args:
df: DataFrame with OHLCV data (columns: open, high, low, close, volume)
indicators: List of indicator names to calculate (None = all)
Returns:
DataFrame with added indicator columns
"""
result = df.copy()
if 'close' not in result.columns:
raise ValueError("DataFrame must have 'close' column")
close = result['close']
high = result.get('high', close)
low = result.get('low', close)
volume = result.get('volume', pd.Series(1, index=close.index))
# Default indicators if none specified
if indicators is None:
indicators = [
'sma_20', 'ema_20', 'rsi', 'macd', 'bollinger_bands',
'atr', 'obv', 'adx'
]
for indicator in indicators:
try:
if indicator.startswith('sma_'):
period = int(indicator.split('_')[1])
result[f'SMA_{period}'] = self.sma(close, period)
elif indicator.startswith('ema_'):
period = int(indicator.split('_')[1])
result[f'EMA_{period}'] = self.ema(close, period)
elif indicator == 'rsi':
result['RSI'] = self.rsi(close)
elif indicator == 'macd':
macd_data = self.macd(close)
result['MACD'] = macd_data['macd']
result['MACD_Signal'] = macd_data['signal']
result['MACD_Histogram'] = macd_data['histogram']
elif indicator == 'bollinger_bands':
bb_data = self.bollinger_bands(close)
result['BB_Upper'] = bb_data['upper']
result['BB_Middle'] = bb_data['middle']
result['BB_Lower'] = bb_data['lower']
elif indicator == 'atr':
result['ATR'] = self.atr(high, low, close)
elif indicator == 'obv':
result['OBV'] = self.obv(close, volume)
elif indicator == 'adx':
result['ADX'] = self.adx(high, low, close)
except Exception as e:
logger.warning(f"Failed to calculate indicator {indicator}: {e}")
return result
# Global indicators instance
_indicators: Optional[TechnicalIndicators] = None
def get_indicators() -> TechnicalIndicators:
"""Get global technical indicators instance."""
global _indicators
if _indicators is None:
_indicators = TechnicalIndicators()
return _indicators
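A sketch of the batch API on synthetic data; the random-walk frame is illustrative, and the output column names follow `calculate_all` above:

```python
import numpy as np
import pandas as pd

from src.data.indicators import get_indicators

rng = np.random.default_rng(0)
close = pd.Series(100 + rng.normal(0, 1, 200).cumsum())
df = pd.DataFrame({
    "open": close.shift(1).fillna(close.iloc[0]),
    "high": close + 0.5,
    "low": close - 0.5,
    "close": close,
    "volume": pd.Series(rng.uniform(10, 100, 200)),
})

ind = get_indicators()
enriched = ind.calculate_all(df, indicators=["sma_20", "rsi", "macd", "atr"])
print(enriched[["close", "SMA_20", "RSI", "MACD", "ATR"]].tail())
```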

src/data/news_collector.py Normal file

@@ -0,0 +1,447 @@
"""News collector for crypto sentiment analysis.
Collects headlines from multiple sources:
- RSS feeds (CoinDesk, CoinTelegraph, Decrypt, etc.)
- CryptoPanic API (optional, requires API key)
Headlines are cached and refreshed periodically to avoid rate limits.
"""
import asyncio
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from src.core.logger import get_logger
logger = get_logger(__name__)
class NewsSource(str, Enum):
"""Supported news sources."""
COINDESK = "coindesk"
COINTELEGRAPH = "cointelegraph"
DECRYPT = "decrypt"
BITCOIN_MAGAZINE = "bitcoin_magazine"
THE_BLOCK = "the_block"
MESSARI = "messari"
CRYPTOPANIC = "cryptopanic"
@dataclass
class NewsItem:
"""A single news item."""
title: str
source: NewsSource
published: datetime
url: Optional[str] = None
summary: Optional[str] = None
symbols: List[str] = field(default_factory=list)
# RSS feed URLs for major crypto news sources
RSS_FEEDS: Dict[NewsSource, str] = {
NewsSource.COINDESK: "https://www.coindesk.com/arc/outboundfeeds/rss/",
NewsSource.COINTELEGRAPH: "https://cointelegraph.com/rss",
NewsSource.DECRYPT: "https://decrypt.co/feed",
NewsSource.BITCOIN_MAGAZINE: "https://bitcoinmagazine.com/.rss/full/",
NewsSource.THE_BLOCK: "https://www.theblock.co/rss.xml",
NewsSource.MESSARI: "https://messari.io/rss",
}
# Common crypto symbols to detect in headlines
CRYPTO_SYMBOLS = {
"BTC": ["bitcoin", "btc"],
"ETH": ["ethereum", "eth", "ether"],
"SOL": ["solana", "sol"],
"XRP": ["ripple", "xrp"],
"ADA": ["cardano", "ada"],
"DOGE": ["dogecoin", "doge"],
"DOT": ["polkadot", "dot"],
"AVAX": ["avalanche", "avax"],
"MATIC": ["polygon", "matic"],
"LINK": ["chainlink", "link"],
"UNI": ["uniswap", "uni"],
"ATOM": ["cosmos", "atom"],
"LTC": ["litecoin", "ltc"],
}
class NewsCollector:
"""Collects crypto news headlines for sentiment analysis.
Features:
- Aggregates news from multiple RSS feeds
- Caches headlines to reduce network requests
- Filters by crypto symbol
- Optional CryptoPanic integration for more sources
Usage:
collector = NewsCollector()
headlines = await collector.fetch_headlines()
# Or filter by symbol
btc_headlines = await collector.fetch_headlines(symbols=["BTC"])
"""
# Minimum time between fetches (in seconds)
MIN_FETCH_INTERVAL = 300 # 5 minutes
def __init__(
self,
sources: Optional[List[NewsSource]] = None,
cryptopanic_api_key: Optional[str] = None,
cache_duration: int = 600, # 10 minutes
max_headlines: int = 50
):
"""Initialize NewsCollector.
Args:
sources: List of news sources to use. Defaults to all RSS feeds.
cryptopanic_api_key: Optional API key for CryptoPanic.
cache_duration: How long to cache headlines (seconds).
max_headlines: Maximum headlines to keep in cache.
"""
self.sources = sources or list(RSS_FEEDS.keys())
self.cryptopanic_api_key = cryptopanic_api_key
self.cache_duration = cache_duration
self.max_headlines = max_headlines
self._cache: List[NewsItem] = []
self._last_fetch: Optional[datetime] = None
self._fetching = False
self.logger = get_logger(__name__)
# Check if feedparser is available
try:
import feedparser
self._feedparser = feedparser
self._feedparser_available = True
except ImportError:
self._feedparser_available = False
self.logger.warning(
"feedparser not installed. Install with: pip install feedparser"
)
def _extract_symbols(self, text: str) -> List[str]:
"""Extract crypto symbols mentioned in text.
Args:
text: Text to search (headline, summary)
Returns:
List of detected symbol codes (e.g., ["BTC", "ETH"])
"""
text_lower = text.lower()
detected = []
for symbol, keywords in CRYPTO_SYMBOLS.items():
for keyword in keywords:
if keyword in text_lower:
detected.append(symbol)
break
return detected
async def _fetch_rss_feed(self, source: NewsSource) -> List[NewsItem]:
"""Fetch and parse a single RSS feed.
Args:
source: News source to fetch
Returns:
List of NewsItems from the feed
"""
if not self._feedparser_available:
return []
url = RSS_FEEDS.get(source)
if not url:
return []
try:
# Run feedparser in thread pool to avoid blocking
loop = asyncio.get_event_loop()
feed = await loop.run_in_executor(
None,
self._feedparser.parse,
url
)
items = []
for entry in feed.entries[:20]: # Limit entries per feed
# Parse publication date
published = datetime.now()
if hasattr(entry, 'published_parsed') and entry.published_parsed:
try:
published = datetime(*entry.published_parsed[:6])
except (TypeError, ValueError):
pass
title = entry.get('title', '')
summary = entry.get('summary', '')
# Clean HTML from summary
summary = re.sub(r'<[^>]+>', '', summary)[:200]
item = NewsItem(
title=title,
source=source,
published=published,
url=entry.get('link'),
summary=summary,
symbols=self._extract_symbols(f"{title} {summary}")
)
items.append(item)
self.logger.debug(f"Fetched {len(items)} items from {source.value}")
return items
except Exception as e:
self.logger.warning(f"Failed to fetch {source.value} RSS: {e}")
return []
async def _fetch_cryptopanic(self, symbols: Optional[List[str]] = None) -> List[NewsItem]:
"""Fetch news from CryptoPanic API.
Args:
symbols: Optional list of symbols to filter
Returns:
List of NewsItems from CryptoPanic
"""
if not self.cryptopanic_api_key:
return []
try:
import aiohttp
url = "https://cryptopanic.com/api/v1/posts/"
params = {
"auth_token": self.cryptopanic_api_key,
"public": "true",
}
if symbols:
params["currencies"] = ",".join(symbols)
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, timeout=aiohttp.ClientTimeout(total=10)) as response:
if response.status != 200:
self.logger.warning(f"CryptoPanic API error: {response.status}")
return []
data = await response.json()
items = []
for post in data.get("results", [])[:20]:
published = datetime.now()
if post.get("published_at"):
try:
published = datetime.fromisoformat(
post["published_at"].replace("Z", "+00:00")
)
except (ValueError, TypeError):
pass
item = NewsItem(
title=post.get("title", ""),
source=NewsSource.CRYPTOPANIC,
published=published,
url=post.get("url"),
symbols=[c["code"] for c in post.get("currencies", [])]
)
items.append(item)
self.logger.debug(f"Fetched {len(items)} items from CryptoPanic")
return items
except ImportError:
self.logger.warning("aiohttp not installed for CryptoPanic API")
return []
except Exception as e:
self.logger.warning(f"Failed to fetch CryptoPanic: {e}")
return []
async def fetch_news(
self,
symbols: Optional[List[str]] = None,
force_refresh: bool = False
) -> List[NewsItem]:
"""Fetch news items from all sources.
Args:
symbols: Optional list of symbols to filter (e.g., ["BTC", "ETH"])
force_refresh: Force a refresh even if cache is valid
Returns:
List of NewsItems sorted by publication date (newest first)
"""
now = datetime.now()
# Check cache validity
cache_valid = (
self._last_fetch is not None and
(now - self._last_fetch).total_seconds() < self.cache_duration and
len(self._cache) > 0
)
if cache_valid and not force_refresh:
self.logger.debug("Using cached news items")
items = self._cache
else:
# Prevent concurrent fetches
if self._fetching:
self.logger.debug("Fetch already in progress, using cache")
items = self._cache
else:
self._fetching = True
try:
items = await self._fetch_all_sources()
self._cache = items
self._last_fetch = now
finally:
self._fetching = False
# Filter by symbols if specified
if symbols:
symbols_upper = [s.upper() for s in symbols]
items = [
item for item in items
if any(s in symbols_upper for s in item.symbols) or not item.symbols
]
return items
async def _fetch_all_sources(self) -> List[NewsItem]:
"""Fetch from all configured sources concurrently."""
tasks = []
# RSS feeds
for source in self.sources:
if source in RSS_FEEDS:
tasks.append(self._fetch_rss_feed(source))
# CryptoPanic
if self.cryptopanic_api_key:
tasks.append(self._fetch_cryptopanic())
if not tasks:
self.logger.warning("No news sources configured")
return []
# Fetch all concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine results
all_items = []
for result in results:
if isinstance(result, list):
all_items.extend(result)
elif isinstance(result, Exception):
self.logger.warning(f"Source fetch failed: {result}")
# Sort by publication date (newest first)
all_items.sort(key=lambda x: x.published, reverse=True)
# Limit total items
all_items = all_items[:self.max_headlines]
self.logger.info(f"Fetched {len(all_items)} total news items")
return all_items
async def fetch_headlines(
self,
symbols: Optional[List[str]] = None,
max_age_hours: int = 24,
force_refresh: bool = False
) -> List[str]:
"""Fetch headlines as strings for sentiment analysis.
This is the main method to use with SentimentScanner.
Args:
symbols: Optional list of symbols to filter
max_age_hours: Only include headlines from the last N hours
force_refresh: Force a refresh even if cache is valid
Returns:
List of headline strings
"""
items = await self.fetch_news(symbols=symbols, force_refresh=force_refresh)
# Filter by age
cutoff = datetime.now() - timedelta(hours=max_age_hours)
recent_items = [item for item in items if item.published > cutoff]
# Extract just the titles
headlines = [item.title for item in recent_items if item.title]
self.logger.debug(f"Returning {len(headlines)} headlines for analysis")
return headlines
def get_cached_headlines(self, symbols: Optional[List[str]] = None) -> List[str]:
"""Get cached headlines synchronously (no fetch).
Args:
symbols: Optional list of symbols to filter
Returns:
List of cached headline strings
"""
items = self._cache
if symbols:
symbols_upper = [s.upper() for s in symbols]
items = [
item for item in items
if any(s in symbols_upper for s in item.symbols) or not item.symbols
]
return [item.title for item in items if item.title]
def get_status(self) -> Dict[str, Any]:
"""Get collector status information.
Returns:
Dictionary with status info
"""
return {
"sources": [s.value for s in self.sources],
"cryptopanic_enabled": self.cryptopanic_api_key is not None,
"feedparser_available": self._feedparser_available,
"cached_items": len(self._cache),
"last_fetch": self._last_fetch.isoformat() if self._last_fetch else None,
"cache_age_seconds": (
(datetime.now() - self._last_fetch).total_seconds()
if self._last_fetch else None
),
}
def clear_cache(self):
"""Clear the headline cache."""
self._cache = []
self._last_fetch = None
self.logger.info("News cache cleared")
# Global instance
_news_collector: Optional[NewsCollector] = None
def get_news_collector(**kwargs) -> NewsCollector:
"""Get or create the global NewsCollector instance.
Args:
**kwargs: Arguments passed to NewsCollector constructor
Returns:
NewsCollector instance
"""
global _news_collector
if _news_collector is None:
_news_collector = NewsCollector(**kwargs)
logger.info("Created global NewsCollector instance")
return _news_collector
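The entry point intended for sentiment analysis is `fetch_headlines`; a sketch, requiring `feedparser` and network access at runtime (the symbol filter is illustrative):

```python
import asyncio

from src.data.news_collector import get_news_collector

async def main():
    collector = get_news_collector(cache_duration=600, max_headlines=50)
    headlines = await collector.fetch_headlines(symbols=["BTC"], max_age_hours=24)
    for title in headlines[:5]:
        print("-", title)
    print(collector.get_status())

asyncio.run(main())
```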

src/data/pricing_service.py Normal file

@@ -0,0 +1,406 @@
"""Unified pricing data service with multi-provider support and automatic failover."""
import time
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from decimal import Decimal
from .providers.base_provider import BasePricingProvider
from .providers.ccxt_provider import CCXTProvider
from .providers.coingecko_provider import CoinGeckoProvider
from .cache_manager import CacheManager
from .health_monitor import HealthMonitor, HealthStatus
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class PricingService:
"""Unified pricing data service with multi-provider support.
Manages multiple pricing providers with automatic failover, caching,
and health monitoring. Provides a single consistent API for accessing
market data regardless of the underlying provider.
"""
def __init__(self):
"""Initialize pricing service."""
self.config = get_config()
self.logger = get_logger(__name__)
# Initialize components
self.cache = CacheManager(
default_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
max_size=self.config.get("data_providers.caching.max_cache_size", 1000),
ticker_ttl=self.config.get("data_providers.caching.ticker_ttl", 2),
ohlcv_ttl=self.config.get("data_providers.caching.ohlcv_ttl", 60),
)
self.health_monitor = HealthMonitor()
# Provider instances
self._providers: Dict[str, BasePricingProvider] = {}
self._active_provider: Optional[str] = None
self._provider_priority: List[str] = []
# Subscriptions
self._subscriptions: Dict[str, List[Callable]] = {}
# Initialize providers
self._initialize_providers()
def _initialize_providers(self):
"""Initialize providers from configuration."""
# Get primary providers from config
primary_config = self.config.get("data_providers.primary", [])
if not primary_config:
# Default configuration
primary_config = [
{'name': 'kraken', 'enabled': True, 'priority': 1},
{'name': 'coinbase', 'enabled': True, 'priority': 2},
{'name': 'binance', 'enabled': True, 'priority': 3},
]
# Sort by priority
primary_config = sorted(
[p for p in primary_config if p.get('enabled', True)],
key=lambda x: x.get('priority', 999)
)
# Create CCXT providers for each exchange
for provider_config in primary_config:
exchange_name = provider_config.get('name')
try:
provider = CCXTProvider(exchange_name=exchange_name)
provider_name = provider.name
if provider.connect():
self._providers[provider_name] = provider
self._provider_priority.append(provider_name)
self.logger.info(f"Initialized provider: {provider_name}")
else:
self.logger.warning(f"Failed to connect provider: {provider_name}")
except Exception as e:
self.logger.error(f"Error initializing provider {exchange_name}: {e}")
# Add fallback provider (CoinGecko)
fallback_config = self.config.get("data_providers.fallback", {})
if fallback_config.get('enabled', True):
try:
coingecko = CoinGeckoProvider(api_key=fallback_config.get('api_key'))
if coingecko.connect():
self._providers[coingecko.name] = coingecko
self._provider_priority.append(coingecko.name)
self.logger.info(f"Initialized fallback provider: {coingecko.name}")
else:
self.logger.warning("Failed to connect CoinGecko fallback provider")
except Exception as e:
self.logger.error(f"Error initializing CoinGecko provider: {e}")
# Select initial active provider
self._select_active_provider()
def _select_active_provider(self) -> Optional[str]:
"""Select the best available provider.
Returns:
Name of selected provider or None
"""
# Filter to healthy providers
healthy_providers = [
name for name in self._provider_priority
if name in self._providers
and self.health_monitor.is_healthy(name)
]
if not healthy_providers:
# Fall back to any available provider if none are healthy
healthy_providers = list(self._providers.keys())
if not healthy_providers:
self.logger.error("No providers available")
self._active_provider = None
return None
# Select first healthy provider (already sorted by priority)
self._active_provider = healthy_providers[0]
self.logger.info(f"Selected active provider: {self._active_provider}")
return self._active_provider
def _get_provider(self, provider_name: Optional[str] = None) -> Optional[BasePricingProvider]:
"""Get a provider instance.
Args:
provider_name: Name of provider, or None to use active provider
Returns:
Provider instance or None
"""
if provider_name:
return self._providers.get(provider_name)
# Use active provider, or select one if none active
if not self._active_provider:
self._select_active_provider()
return self._providers.get(self._active_provider) if self._active_provider else None
def _execute_with_failover(
self,
operation: Callable[[BasePricingProvider], Any],
operation_name: str
) -> Any:
"""Execute an operation with automatic failover.
Args:
operation: Function that takes a provider and returns a result
operation_name: Name of operation for logging
Returns:
Operation result or None if all providers fail
"""
# Try active provider first
providers_to_try = [self._active_provider] if self._active_provider else []
# Add other providers in priority order
for provider_name in self._provider_priority:
if provider_name != self._active_provider and provider_name in self._providers:
providers_to_try.append(provider_name)
last_error = None
for provider_name in providers_to_try:
provider = self._providers.get(provider_name)
if not provider:
continue
# Check health
if not self.health_monitor.is_healthy(provider_name):
self.logger.debug(f"Skipping unhealthy provider: {provider_name}")
continue
try:
start_time = time.time()
result = operation(provider)
response_time = time.time() - start_time
# Record success
self.health_monitor.record_success(provider_name, response_time)
# Update active provider if we used a different one
if provider_name != self._active_provider:
self.logger.info(f"Switched to provider: {provider_name}")
self._active_provider = provider_name
return result
except Exception as e:
last_error = e
self.logger.warning(f"{operation_name} failed on {provider_name}: {e}")
self.health_monitor.record_failure(provider_name)
# Try next provider
continue
# All providers failed
self.logger.error(f"{operation_name} failed on all providers")
if last_error:
raise last_error
return None
def get_ticker(self, symbol: str, use_cache: bool = True) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
use_cache: Whether to use cache
Returns:
Ticker data dictionary
"""
cache_key = f"ticker:{symbol}"
# Check cache
if use_cache:
cached = self.cache.get(cache_key)
if cached:
return cached
# Fetch from provider
def fetch_ticker(provider: BasePricingProvider):
return provider.get_ticker(symbol)
ticker_data = self._execute_with_failover(fetch_ticker, f"get_ticker({symbol})")
if ticker_data:
# Cache the result
if use_cache:
self.cache.set(cache_key, ticker_data, cache_type='ticker')
return ticker_data
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100,
use_cache: bool = True
) -> List[List]:
"""Get OHLCV candlestick data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
use_cache: Whether to use cache
Returns:
List of [timestamp_ms, open, high, low, close, volume]
"""
cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"
# Check cache (only if no 'since' parameter, as it changes the result)
if use_cache and not since:
cached = self.cache.get(cache_key)
if cached:
return cached
# Fetch from provider
def fetch_ohlcv(provider: BasePricingProvider):
return provider.get_ohlcv(symbol, timeframe, since, limit)
ohlcv_data = self._execute_with_failover(
fetch_ohlcv,
f"get_ohlcv({symbol}, {timeframe})"
)
if ohlcv_data:
# Cache the result (only if no 'since' parameter)
if use_cache and not since:
self.cache.set(cache_key, ohlcv_data, cache_type='ohlcv')
return ohlcv_data
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Args:
symbol: Trading pair symbol
callback: Callback function(data) called on price updates
Returns:
True if subscription successful
"""
key = f"ticker:{symbol}"
# Add callback
if key not in self._subscriptions:
self._subscriptions[key] = []
if callback not in self._subscriptions[key]:
self._subscriptions[key].append(callback)
# Wrap callback to handle failover
def wrapped_callback(data):
for cb in self._subscriptions.get(key, []):
try:
cb(data)
except Exception as e:
self.logger.error(f"Callback error for {symbol}: {e}")
# Subscribe via active provider
provider = self._get_provider()
if provider:
try:
success = provider.subscribe_ticker(symbol, wrapped_callback)
if success:
self.logger.info(f"Subscribed to ticker updates for {symbol}")
return True
except Exception as e:
self.logger.error(f"Failed to subscribe to ticker for {symbol}: {e}")
return False
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates.
Args:
symbol: Trading pair symbol
callback: Specific callback to remove, or None to remove all
"""
key = f"ticker:{symbol}"
# Remove callback
if key in self._subscriptions:
if callback:
if callback in self._subscriptions[key]:
self._subscriptions[key].remove(callback)
if not self._subscriptions[key]:
del self._subscriptions[key]
else:
del self._subscriptions[key]
# Unsubscribe from all providers
for provider in self._providers.values():
try:
provider.unsubscribe_ticker(symbol, callback)
except Exception:
pass
self.logger.info(f"Unsubscribed from ticker updates for {symbol}")
def get_active_provider(self) -> Optional[str]:
"""Get name of active provider.
Returns:
Provider name or None
"""
return self._active_provider
def get_provider_health(self, provider_name: Optional[str] = None) -> Dict[str, Any]:
"""Get health status for a provider or all providers.
Args:
provider_name: Provider name, or None for all providers
Returns:
Health status dictionary
"""
if provider_name:
metrics = self.health_monitor.get_metrics(provider_name)
if metrics:
return metrics.to_dict()
return {}
return self.health_monitor.get_all_metrics()
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics.
Returns:
Cache statistics dictionary
"""
return self.cache.get_stats()
# Global pricing service instance
_pricing_service: Optional[PricingService] = None
def get_pricing_service() -> PricingService:
"""Get global pricing service instance.
Returns:
PricingService instance
"""
global _pricing_service
if _pricing_service is None:
_pricing_service = PricingService()
return _pricing_service
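End to end, the service hides provider selection and failover behind a single API; a sketch, assuming `ccxt` is installed and network access is available (the symbol is illustrative):

```python
from src.data.pricing_service import get_pricing_service

service = get_pricing_service()
print("active provider:", service.get_active_provider())

ticker = service.get_ticker("BTC/USD")  # cached for ~2s (ticker TTL)
candles = service.get_ohlcv("BTC/USD", timeframe="1h", limit=24)

print(ticker.get("last"), len(candles))
print(service.get_cache_stats())
print(service.get_provider_health())
```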

src/data/providers/__init__.py Normal file

@@ -0,0 +1,7 @@
"""Pricing data providers package."""
from .base_provider import BasePricingProvider
from .ccxt_provider import CCXTProvider
from .coingecko_provider import CoinGeckoProvider
__all__ = ['BasePricingProvider', 'CCXTProvider', 'CoinGeckoProvider']

src/data/providers/base_provider.py Normal file

@@ -0,0 +1,150 @@
"""Base pricing provider interface."""
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from src.core.logger import get_logger
logger = get_logger(__name__)
class BasePricingProvider(ABC):
"""Base class for pricing data providers.
Pricing providers are responsible for fetching market data (prices, OHLCV)
without requiring API keys. They differ from exchange adapters which handle
trading operations.
"""
def __init__(self):
"""Initialize pricing provider."""
self.logger = get_logger(f"provider.{self.__class__.__name__}")
self._connected = False
self._subscribers: Dict[str, List[Callable]] = {}
@property
@abstractmethod
def name(self) -> str:
"""Provider name."""
pass
@property
@abstractmethod
def supports_websocket(self) -> bool:
"""Whether this provider supports WebSocket connections."""
pass
@abstractmethod
def connect(self) -> bool:
"""Connect to the provider.
Returns:
True if connection successful
"""
pass
@abstractmethod
def disconnect(self):
"""Disconnect from the provider."""
pass
@abstractmethod
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
Returns:
Dictionary with ticker data:
- 'symbol': str
- 'bid': Decimal
- 'ask': Decimal
- 'last': Decimal (last price)
- 'high': Decimal (24h high)
- 'low': Decimal (24h low)
- 'volume': Decimal (24h volume)
- 'timestamp': int (Unix timestamp in milliseconds)
"""
pass
@abstractmethod
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV (candlestick) data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
Returns:
List of [timestamp_ms, open, high, low, close, volume]
"""
pass
@abstractmethod
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Args:
symbol: Trading pair symbol
callback: Callback function(data) called on price updates
Returns:
True if subscription successful
"""
pass
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates.
Args:
symbol: Trading pair symbol
callback: Specific callback to remove, or None to remove all
"""
key = f"ticker_{symbol}"
if key in self._subscribers:
if callback:
if callback in self._subscribers[key]:
self._subscribers[key].remove(callback)
if not self._subscribers[key]:
del self._subscribers[key]
else:
del self._subscribers[key]
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format for this provider.
Args:
symbol: Symbol to normalize
Returns:
Normalized symbol
"""
# Default: uppercase and replace dashes with slashes
return symbol.upper().replace('-', '/')
def is_connected(self) -> bool:
"""Check if provider is connected.
Returns:
True if connected
"""
return self._connected
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols.
Returns:
List of supported symbols, or empty list if not available
"""
# Default implementation - override in subclasses
return []
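# A minimal concrete provider, shown only to illustrate the contract above;
# it serves canned data and is hypothetical, not part of the package.
class _StaticDemoProvider(BasePricingProvider):
    """Demo provider returning a fixed price for any symbol."""
    @property
    def name(self) -> str:
        return "StaticDemo"
    @property
    def supports_websocket(self) -> bool:
        return False
    def connect(self) -> bool:
        self._connected = True
        return True
    def disconnect(self):
        self._connected = False
    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        price = Decimal("50000")
        return {'symbol': symbol, 'bid': price, 'ask': price, 'last': price,
                'high': price, 'low': price, 'volume': Decimal("0"),
                'timestamp': 0}
    def get_ohlcv(self, symbol, timeframe='1h', since=None, limit=100):
        return []  # no historical data in this demo
    def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
        self._subscribers.setdefault(f"ticker_{symbol}", []).append(callback)
        return True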

333
src/data/providers/ccxt_provider.py Normal file
View File

@@ -0,0 +1,333 @@
"""CCXT-based pricing provider with multi-exchange support and WebSocket capabilities."""
import ccxt
import threading
import time
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CCXTProvider(BasePricingProvider):
"""CCXT-based pricing provider with multi-exchange fallback support.
    This provider uses CCXT to connect to multiple exchanges (Kraken, Coinbase,
    Binance) as primary data sources. Ticker subscriptions currently use REST
    polling; native WebSocket streaming is a TODO where exchanges support it.
"""
def __init__(self, exchange_name: Optional[str] = None):
"""Initialize CCXT provider.
Args:
exchange_name: Specific exchange to use ('kraken', 'coinbase', 'binance'),
or None to try multiple exchanges
"""
super().__init__()
self.exchange_name = exchange_name
self.exchange = None
self._selected_exchange_id = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
# Exchange priority order (try first to last)
if exchange_name:
self._exchange_options = [(exchange_name.lower(), exchange_name.capitalize())]
else:
self._exchange_options = [
('kraken', 'Kraken'),
('coinbase', 'Coinbase'),
('binance', 'Binance'),
]
@property
def name(self) -> str:
"""Provider name."""
if self._selected_exchange_id:
return f"CCXT-{self._selected_exchange_id.capitalize()}"
return "CCXT Provider"
@property
def supports_websocket(self) -> bool:
"""Check if current exchange supports WebSocket."""
if not self.exchange:
return False
# Check if exchange has WebSocket support
# Most modern CCXT exchanges support WebSocket, but we check explicitly
exchange_id = getattr(self.exchange, 'id', '').lower()
# Known WebSocket-capable exchanges
ws_capable = ['kraken', 'coinbase', 'binance', 'binanceus', 'okx']
return exchange_id in ws_capable
def connect(self) -> bool:
"""Connect to an exchange via CCXT.
Tries multiple exchanges in order until one succeeds.
Returns:
True if connection successful
"""
for exchange_id, exchange_display_name in self._exchange_options:
try:
# Get exchange class from CCXT
exchange_class = getattr(ccxt, exchange_id, None)
if not exchange_class:
logger.warning(f"Exchange {exchange_id} not found in CCXT")
continue
# Create exchange instance
self.exchange = exchange_class({
'enableRateLimit': True,
'options': {
'defaultType': 'spot',
}
})
# Load markets to test connection
if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
try:
self.exchange.load_markets()
except Exception as e:
logger.warning(f"Failed to load markets for {exchange_id}: {e}")
continue
# Test connection with a common symbol
test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
ticker_result = None
for test_symbol in test_symbols:
try:
ticker_result = self.exchange.fetch_ticker(test_symbol)
if ticker_result:
break
except Exception:
continue
if not ticker_result:
logger.warning(f"Could not fetch ticker from {exchange_id}")
continue
# Success!
self._selected_exchange_id = exchange_id
self._connected = True
logger.info(f"Connected to {exchange_display_name} via CCXT")
return True
except Exception as e:
logger.warning(f"Failed to connect to {exchange_id}: {e}")
continue
# All exchanges failed
logger.error("Failed to connect to any CCXT exchange")
self._connected = False
self.exchange = None
return False
def disconnect(self):
"""Disconnect from exchange."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait a bit for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
self._connected = False
self.exchange = None
self._selected_exchange_id = None
logger.info(f"Disconnected from {self.name}")
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return {}
try:
normalized_symbol = self.normalize_symbol(symbol)
ticker = self.exchange.fetch_ticker(normalized_symbol)
return {
'symbol': symbol,
'bid': Decimal(str(ticker.get('bid', 0) or 0)),
'ask': Decimal(str(ticker.get('ask', 0) or 0)),
'last': Decimal(str(ticker.get('last', 0) or ticker.get('close', 0) or 0)),
'high': Decimal(str(ticker.get('high', 0) or 0)),
'low': Decimal(str(ticker.get('low', 0) or 0)),
'volume': Decimal(str(ticker.get('quoteVolume', ticker.get('volume', 0)) or 0)),
'timestamp': ticker.get('timestamp', int(time.time() * 1000)),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV candlestick data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return []
try:
since_timestamp = int(since.timestamp() * 1000) if since else None
normalized_symbol = self.normalize_symbol(symbol)
# Most exchanges support up to 1000 candles per request
max_limit = min(limit, 1000)
ohlcv = self.exchange.fetch_ohlcv(
normalized_symbol,
timeframe,
since_timestamp,
max_limit
)
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Uses WebSocket if available, otherwise falls back to polling.
"""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
        # WebSocket streaming is not wired up yet: CCXT WebSocket support
        # varies by exchange and requires exchange-specific handling.
        # TODO: Implement native WebSocket when CCXT adds better support
# Use polling as primary method (more reliable across exchanges)
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker every 2 seconds."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks for this symbol
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
time.sleep(2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(5) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
return True
return True
    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"
        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)
        # Stop polling only when no callbacks remain for this symbol, so that
        # removing one subscriber does not cut off the others
        if key not in self._subscribers:
            if key in self._stop_polling:
                self._stop_polling[key] = True
            if key in self._polling_threads:
                del self._polling_threads[key]
        logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol for the selected exchange."""
if not self.exchange:
return symbol.replace('-', '/').upper()
# Basic normalization
normalized = symbol.replace('-', '/').upper()
# Try to use exchange's markets to find correct symbol
try:
if hasattr(self.exchange, 'markets') and self.exchange.markets:
# Check if symbol exists
if normalized in self.exchange.markets:
return normalized
# Try alternative formats
if '/USD' in normalized:
alt_symbol = normalized.replace('/USD', '/USDT')
if alt_symbol in self.exchange.markets:
return alt_symbol
# For Kraken: try XBT instead of BTC
if normalized.startswith('BTC/'):
alt_symbol = normalized.replace('BTC/', 'XBT/')
if alt_symbol in self.exchange.markets:
return alt_symbol
except Exception:
pass
# Fallback: return normalized (let CCXT handle errors)
return normalized
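    # For example (actual result depends on the connected exchange's markets):
    #   'btc-usd'  -> 'BTC/USD'
    #   'BTC/USD'  -> 'BTC/USDT' when only the USDT-quoted market is listed
    #   'BTC/USD'  -> 'XBT/USD'  on listings that use Kraken's XBT code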
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols."""
if not self._connected or not self.exchange:
return []
try:
markets = self.exchange.load_markets()
return list(markets.keys())
except Exception as e:
logger.error(f"Failed to get supported symbols: {e}")
return []
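# Illustrative usage (not part of the committed file): connect using the
# default exchange priority, then pull a ticker and recent hourly candles.
if __name__ == "__main__":
    provider = CCXTProvider()  # or CCXTProvider('kraken') to pin an exchange
    if provider.connect():
        print(provider.name)  # e.g. 'CCXT-Kraken'
        print(provider.get_ticker('BTC/USD'))
        print(len(provider.get_ohlcv('BTC/USD', '1h', limit=24)))
        provider.disconnect()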

376
src/data/providers/coingecko_provider.py Normal file
View File

@@ -0,0 +1,376 @@
"""CoinGecko pricing provider for fallback market data."""
import httpx
import time
import threading
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CoinGeckoProvider(BasePricingProvider):
"""CoinGecko API pricing provider.
This provider uses CoinGecko's free API tier as a fallback when CCXT
providers are unavailable. It uses simple REST endpoints that don't
require authentication.
"""
BASE_URL = "https://api.coingecko.com/api/v3"
# CoinGecko coin ID mapping for common symbols
COIN_ID_MAP = {
'BTC': 'bitcoin',
'ETH': 'ethereum',
'BNB': 'binancecoin',
'SOL': 'solana',
'ADA': 'cardano',
'XRP': 'ripple',
'DOGE': 'dogecoin',
'DOT': 'polkadot',
'MATIC': 'matic-network',
'AVAX': 'avalanche-2',
'LINK': 'chainlink',
'USDT': 'tether',
'USDC': 'usd-coin',
'DAI': 'dai',
}
# Currency mapping
CURRENCY_MAP = {
'USD': 'usd',
'EUR': 'eur',
'GBP': 'gbp',
'JPY': 'jpy',
'USDT': 'usd', # CoinGecko uses USD for stablecoins
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize CoinGecko provider.
Args:
api_key: Optional API key for higher rate limits (free tier doesn't require it)
"""
super().__init__()
self.api_key = api_key
self._client = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
self._rate_limit_delay = 1.0 # Free tier: ~30-50 calls/minute
@property
def name(self) -> str:
"""Provider name."""
return "CoinGecko"
@property
def supports_websocket(self) -> bool:
"""CoinGecko free tier doesn't support WebSocket."""
return False
def connect(self) -> bool:
"""Connect to CoinGecko API.
Returns:
True if connection successful (just validates API access)
"""
try:
self._client = httpx.Client(timeout=10.0)
# Test connection by fetching Bitcoin price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': 'bitcoin',
'vs_currencies': 'usd',
}
)
if response.status_code == 200:
data = response.json()
if 'bitcoin' in data:
self._connected = True
logger.info("Connected to CoinGecko API")
return True
logger.warning("CoinGecko API test failed")
self._connected = False
return False
except Exception as e:
logger.error(f"Failed to connect to CoinGecko: {e}")
self._connected = False
return False
def disconnect(self):
"""Disconnect from CoinGecko API."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
if self._client:
self._client.close()
self._client = None
self._connected = False
logger.info("Disconnected from CoinGecko API")
def _parse_symbol(self, symbol: str) -> Tuple[Optional[str], Optional[str]]:
"""Parse symbol into coin_id and currency.
Args:
symbol: Trading pair like 'BTC/USD' or 'BTC/USDT'
Returns:
Tuple of (coin_id, currency) or (None, None) if not found
"""
parts = symbol.upper().replace('-', '/').split('/')
if len(parts) != 2:
return None, None
base, quote = parts
# Get coin ID
coin_id = self.COIN_ID_MAP.get(base)
if not coin_id:
# Try lowercase base as fallback
coin_id = base.lower()
# Get currency
currency = self.CURRENCY_MAP.get(quote, quote.lower())
return coin_id, currency
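    # For example:
    #   'BTC/USD'  -> ('bitcoin', 'usd')
    #   'ETH-USDT' -> ('ethereum', 'usd')  # USDT quotes are priced in USD
    #   'FOO/USD'  -> ('foo', 'usd')       # unknown bases fall through lowercased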
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data from CoinGecko."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return {}
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return {}
# Fetch current price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': coin_id,
'vs_currencies': currency,
'include_24hr_change': 'true',
'include_24hr_vol': 'true',
}
)
if response.status_code != 200:
logger.error(f"CoinGecko API error: {response.status_code}")
return {}
data = response.json()
if coin_id not in data:
logger.error(f"Coin {coin_id} not found in CoinGecko response")
return {}
coin_data = data[coin_id]
price_key = currency.lower()
if price_key not in coin_data:
logger.error(f"Currency {currency} not found for {coin_id}")
return {}
price = Decimal(str(coin_data[price_key]))
# Calculate high/low from 24h change if available
change_key = f"{price_key}_24h_change"
vol_key = f"{price_key}_24h_vol"
change_24h = coin_data.get(change_key, 0) or 0
volume_24h = coin_data.get(vol_key, 0) or 0
# Estimate high/low from current price and 24h change
# This is approximate since CoinGecko free tier doesn't provide exact high/low
current_price = float(price)
if change_24h:
estimated_high = current_price * (1 + abs(change_24h / 100) / 2)
estimated_low = current_price * (1 - abs(change_24h / 100) / 2)
else:
estimated_high = current_price
estimated_low = current_price
return {
'symbol': symbol,
'bid': price, # CoinGecko doesn't provide bid/ask, use last price
'ask': price,
'last': price,
'high': Decimal(str(estimated_high)),
'low': Decimal(str(estimated_low)),
'volume': Decimal(str(volume_24h)),
'timestamp': int(time.time() * 1000),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
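    # Worked example of the estimate above: with a last price of 100 and a 24h
    # change of -4%, the band is +/-2% around the last price, giving an
    # estimated high of 100 * 1.02 = 102 and an estimated low of 100 * 0.98 = 98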
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data from CoinGecko.
Note: CoinGecko free tier has limited historical data access.
This method may return empty data for some timeframes.
"""
if not self._connected or not self._client:
logger.error("Provider not connected")
return []
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return []
# CoinGecko uses days parameter instead of timeframe
# Map timeframe to days
timeframe_days_map = {
'1m': 1, # Last 24 hours, minute data (not available in free tier)
'5m': 1,
'15m': 1,
'30m': 1,
'1h': 1, # Last 24 hours, hourly data
'4h': 7, # Last 7 days
'1d': 30, # Last 30 days
'1w': 90, # Last 90 days
}
days = timeframe_days_map.get(timeframe, 7)
# Fetch OHLC data
response = self._client.get(
f"{self.BASE_URL}/coins/{coin_id}/ohlc",
params={
'vs_currency': currency,
'days': days,
}
)
if response.status_code != 200:
logger.warning(f"CoinGecko OHLC API returned {response.status_code}")
return []
data = response.json()
# CoinGecko returns: [timestamp_ms, open, high, low, close]
# We need to convert to: [timestamp_ms, open, high, low, close, volume]
# Note: CoinGecko OHLC endpoint doesn't include volume
# Filter by since if provided
if since:
since_timestamp = int(since.timestamp() * 1000)
data = [candle for candle in data if candle[0] >= since_timestamp]
# Limit results
data = data[-limit:] if limit else data
# Add volume as 0 (CoinGecko doesn't provide it in OHLC endpoint)
ohlcv = [candle + [0] for candle in data]
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} from CoinGecko")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates via polling."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
# Start polling thread if not already running
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker respecting rate limits."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
# Rate limit: wait between requests (free tier: ~30-50 calls/min)
time.sleep(self._rate_limit_delay * 2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(10) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (CoinGecko polling)")
return True
return True
    def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
        """Unsubscribe from ticker updates."""
        key = f"ticker_{symbol}"
        # Remove from subscribers first
        super().unsubscribe_ticker(symbol, callback)
        # Stop polling only when no callbacks remain for this symbol
        if key not in self._subscribers:
            if key in self._stop_polling:
                self._stop_polling[key] = True
            if key in self._polling_threads:
                del self._polling_threads[key]
        logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format."""
# CoinGecko uses coin IDs, so we just normalize the format
return symbol.replace('-', '/').upper()
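# Illustrative usage (not part of the committed file): the free tier needs no
# API key, so connect() only validates that the API is reachable.
if __name__ == "__main__":
    gecko = CoinGeckoProvider()
    if gecko.connect():
        print(gecko.get_ticker('ETH/USD'))
        print(len(gecko.get_ohlcv('BTC/USD', '1d', limit=30)))
        gecko.disconnect()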

116
src/data/quality.py Normal file
View File

@@ -0,0 +1,116 @@
"""Data quality validation, gap filling, and retention policies."""
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from src.core.database import get_database, MarketData
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class DataQualityManager:
"""Manages data quality and retention."""
def __init__(self):
"""Initialize data quality manager."""
self.db = get_database()
self.config = get_config()
self.logger = get_logger(__name__)
def validate_data_quality(
self,
exchange: str,
symbol: str,
timeframe: str,
start_date: datetime,
end_date: datetime
) -> Dict[str, Any]:
"""Validate data quality.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
start_date: Start date
end_date: End date
Returns:
Quality report
"""
session = self.db.get_session()
try:
data = session.query(MarketData).filter(
MarketData.exchange == exchange,
MarketData.symbol == symbol,
MarketData.timeframe == timeframe,
MarketData.timestamp >= start_date,
MarketData.timestamp <= end_date
).order_by(MarketData.timestamp).all()
if len(data) == 0:
return {"valid": False, "reason": "No data"}
# Check for gaps
gaps = self._detect_gaps(data, timeframe)
# Check for anomalies
anomalies = self._detect_anomalies(data)
return {
"valid": len(gaps) == 0 and len(anomalies) == 0,
"total_records": len(data),
"gaps": len(gaps),
"anomalies": len(anomalies),
}
finally:
session.close()
    def _detect_gaps(self, data: List[MarketData], timeframe: str) -> List[datetime]:
        """Detect gaps in data.
        Args:
            data: List of market data
            timeframe: Timeframe
        Returns:
            List of gap timestamps
        """
        gaps = []
        # Simplified gap detection: flag spacing wider than the candle interval
        intervals = {'1m': timedelta(minutes=1), '5m': timedelta(minutes=5),
                     '15m': timedelta(minutes=15), '1h': timedelta(hours=1),
                     '1d': timedelta(days=1)}
        expected = intervals.get(timeframe)
        if expected is None:
            return gaps
        for prev, curr in zip(data, data[1:]):
            if curr.timestamp - prev.timestamp > expected:
                gaps.append(prev.timestamp + expected)
        return gaps
    def _detect_anomalies(self, data: List[MarketData]) -> List[int]:
        """Detect data anomalies.
        Args:
            data: List of market data
        Returns:
            List of anomaly indices
        """
        anomalies = []
        # Simplified anomaly detection: flag structurally impossible candles
        for i, record in enumerate(data):
            if (record.low > record.high or record.close <= 0
                    or record.volume < 0):
                anomalies.append(i)
        return anomalies
def cleanup_old_data(self, days_to_keep: int = 365):
"""Clean up old data based on retention policy.
Args:
days_to_keep: Days of data to keep
"""
session = self.db.get_session()
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep)
deleted = session.query(MarketData).filter(
MarketData.timestamp < cutoff_date
).delete()
session.commit()
logger.info(f"Cleaned up {deleted} old data records")
except Exception as e:
session.rollback()
logger.error(f"Failed to cleanup old data: {e}")
finally:
session.close()
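# Illustrative usage (not part of the committed file): run a quality pass over
# the last week of hourly BTC data, then apply the retention policy.
if __name__ == "__main__":
    manager = DataQualityManager()
    report = manager.validate_data_quality(
        exchange="kraken",
        symbol="BTC/USD",
        timeframe="1h",
        start_date=datetime.utcnow() - timedelta(days=7),
        end_date=datetime.utcnow(),
    )
    print(report)  # e.g. {'valid': True, 'total_records': 168, ...}
    manager.cleanup_old_data(days_to_keep=365)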

225
src/data/redis_cache.py Normal file
View File

@@ -0,0 +1,225 @@
"""Redis-based caching for market data and API responses."""
from typing import Any, Optional
import json
from src.core.redis import get_redis_client
from src.core.logger import get_logger
logger = get_logger(__name__)
class RedisCache:
"""Redis-based cache for market data and API responses."""
# Default TTL values (seconds)
TTL_TICKER = 5 # Ticker prices are very volatile
TTL_OHLCV = 60 # OHLCV can be cached longer
TTL_ORDERBOOK = 2 # Order books change rapidly
TTL_API_RESPONSE = 30 # General API response cache
def __init__(self):
"""Initialize Redis cache."""
self.redis = get_redis_client()
async def get_ticker(self, symbol: str) -> Optional[dict]:
"""Get cached ticker data.
Args:
symbol: Trading symbol (e.g., 'BTC/USD')
Returns:
Cached ticker data or None
"""
key = f"cache:ticker:{symbol.replace('/', '_')}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for ticker:{symbol}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
    async def set_ticker(self, symbol: str, data: dict, ttl: Optional[int] = None) -> bool:
"""Cache ticker data.
Args:
symbol: Trading symbol
data: Ticker data
ttl: Time-to-live in seconds (default: TTL_TICKER)
Returns:
True if cached successfully
"""
key = f"cache:ticker:{symbol.replace('/', '_')}"
ttl = ttl or self.TTL_TICKER
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached ticker:{symbol} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def get_ohlcv(self, symbol: str, timeframe: str, limit: int = 100) -> Optional[list]:
"""Get cached OHLCV data.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
limit: Number of candles
Returns:
Cached OHLCV data or None
"""
key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for ohlcv:{symbol}:{timeframe}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
    async def set_ohlcv(self, symbol: str, timeframe: str, data: list, limit: int = 100, ttl: Optional[int] = None) -> bool:
"""Cache OHLCV data.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
data: OHLCV data
limit: Number of candles
ttl: Time-to-live in seconds
Returns:
True if cached successfully
"""
key = f"cache:ohlcv:{symbol.replace('/', '_')}:{timeframe}:{limit}"
ttl = ttl or self.TTL_OHLCV
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached ohlcv:{symbol}:{timeframe} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def get_api_response(self, cache_key: str) -> Optional[dict]:
"""Get cached API response.
Args:
cache_key: Unique cache key
Returns:
Cached response or None
"""
key = f"cache:api:{cache_key}"
try:
client = self.redis.get_client()
data = await client.get(key)
if data:
logger.debug(f"Cache hit for api:{cache_key}")
return json.loads(data)
return None
except Exception as e:
logger.warning(f"Redis cache get failed: {e}")
return None
    async def set_api_response(self, cache_key: str, data: dict, ttl: Optional[int] = None) -> bool:
"""Cache API response.
Args:
cache_key: Unique cache key
data: Response data
ttl: Time-to-live in seconds
Returns:
True if cached successfully
"""
key = f"cache:api:{cache_key}"
ttl = ttl or self.TTL_API_RESPONSE
try:
client = self.redis.get_client()
await client.setex(key, ttl, json.dumps(data))
logger.debug(f"Cached api:{cache_key} for {ttl}s")
return True
except Exception as e:
logger.warning(f"Redis cache set failed: {e}")
return False
async def invalidate(self, pattern: str) -> int:
"""Invalidate cache entries matching pattern.
Args:
pattern: Redis key pattern (e.g., 'cache:ticker:*')
Returns:
Number of keys deleted
"""
try:
client = self.redis.get_client()
keys = []
async for key in client.scan_iter(match=pattern):
keys.append(key)
if keys:
deleted = await client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
return deleted
return 0
except Exception as e:
logger.warning(f"Redis cache invalidation failed: {e}")
return 0
async def get_stats(self) -> dict:
"""Get cache statistics.
Returns:
Cache statistics
"""
try:
client = self.redis.get_client()
info = await client.info('memory')
# Count cached items by type
ticker_count = 0
ohlcv_count = 0
api_count = 0
async for key in client.scan_iter(match='cache:ticker:*'):
ticker_count += 1
async for key in client.scan_iter(match='cache:ohlcv:*'):
ohlcv_count += 1
async for key in client.scan_iter(match='cache:api:*'):
api_count += 1
return {
"memory_used": info.get('used_memory_human', 'N/A'),
"ticker_entries": ticker_count,
"ohlcv_entries": ohlcv_count,
"api_entries": api_count,
"total_entries": ticker_count + ohlcv_count + api_count
}
except Exception as e:
logger.warning(f"Failed to get cache stats: {e}")
return {"error": str(e)}
# Global cache instance
_redis_cache: Optional[RedisCache] = None
def get_redis_cache() -> RedisCache:
"""Get global Redis cache instance."""
global _redis_cache
if _redis_cache is None:
_redis_cache = RedisCache()
return _redis_cache
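# Illustrative usage (not part of the committed file): the cache API is async,
# so callers drive it from an event loop.
if __name__ == "__main__":
    import asyncio

    async def demo():
        cache = get_redis_cache()
        await cache.set_ticker("BTC/USD", {"last": "50000"})  # 5s default TTL
        print(await cache.get_ticker("BTC/USD"))
        print(await cache.get_stats())
        await cache.invalidate("cache:ticker:*")

    asyncio.run(demo())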

75
src/data/storage.py Normal file
View File

@@ -0,0 +1,75 @@
"""Data persistence."""
from decimal import Decimal
from datetime import datetime
from src.core.database import get_database, MarketData
from src.core.logger import get_logger
logger = get_logger(__name__)
class DataStorage:
"""Manages data storage and persistence."""
def __init__(self):
"""Initialize data storage."""
self.db = get_database()
self.logger = get_logger(__name__)
def store_ohlcv(
self,
exchange: str,
symbol: str,
timeframe: str,
timestamp: datetime,
open: Decimal,
high: Decimal,
low: Decimal,
close: Decimal,
volume: Decimal
):
"""Store OHLCV data.
Args:
exchange: Exchange name
symbol: Trading symbol
timeframe: Timeframe
timestamp: Timestamp
open: Open price
high: High price
low: Low price
close: Close price
volume: Volume
"""
session = self.db.get_session()
try:
# Check if exists
existing = session.query(MarketData).filter_by(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp
).first()
if not existing:
market_data = MarketData(
exchange=exchange,
symbol=symbol,
timeframe=timeframe,
timestamp=timestamp,
open=open,
high=high,
low=low,
close=close,
volume=volume
)
session.add(market_data)
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Failed to store OHLCV data: {e}")
finally:
session.close()
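# Illustrative usage (not part of the committed file): persist a single hourly
# candle; duplicate (exchange, symbol, timeframe, timestamp) rows are skipped
# by the existence check above.
if __name__ == "__main__":
    storage = DataStorage()
    storage.store_ohlcv(
        exchange="kraken",
        symbol="BTC/USD",
        timeframe="1h",
        timestamp=datetime(2025, 12, 25, 12, 0),
        open=Decimal("49500"),
        high=Decimal("50200"),
        low=Decimal("49300"),
        close=Decimal("50000"),
        volume=Decimal("123.45"),
    )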