Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""Pricing data providers package."""
from .base_provider import BasePricingProvider
from .ccxt_provider import CCXTProvider
from .coingecko_provider import CoinGeckoProvider
__all__ = ['BasePricingProvider', 'CCXTProvider', 'CoinGeckoProvider']

View File

@@ -0,0 +1,150 @@
"""Base pricing provider interface."""
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from src.core.logger import get_logger
logger = get_logger(__name__)
class BasePricingProvider(ABC):
"""Base class for pricing data providers.
Pricing providers are responsible for fetching market data (prices, OHLCV)
without requiring API keys. They differ from exchange adapters which handle
trading operations.
"""
def __init__(self):
"""Initialize pricing provider."""
self.logger = get_logger(f"provider.{self.__class__.__name__}")
self._connected = False
self._subscribers: Dict[str, List[Callable]] = {}
@property
@abstractmethod
def name(self) -> str:
"""Provider name."""
pass
@property
@abstractmethod
def supports_websocket(self) -> bool:
"""Whether this provider supports WebSocket connections."""
pass
@abstractmethod
def connect(self) -> bool:
"""Connect to the provider.
Returns:
True if connection successful
"""
pass
@abstractmethod
def disconnect(self):
"""Disconnect from the provider."""
pass
@abstractmethod
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading pair symbol (e.g., 'BTC/USD')
Returns:
Dictionary with ticker data:
- 'symbol': str
- 'bid': Decimal
- 'ask': Decimal
- 'last': Decimal (last price)
- 'high': Decimal (24h high)
- 'low': Decimal (24h low)
- 'volume': Decimal (24h volume)
- 'timestamp': int (Unix timestamp in milliseconds)
"""
pass
@abstractmethod
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV (candlestick) data.
Args:
symbol: Trading pair symbol
timeframe: Timeframe (1m, 5m, 15m, 1h, 1d, etc.)
since: Start datetime
limit: Number of candles
Returns:
List of [timestamp_ms, open, high, low, close, volume]
"""
pass
@abstractmethod
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Args:
symbol: Trading pair symbol
callback: Callback function(data) called on price updates
Returns:
True if subscription successful
"""
pass
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates.
Args:
symbol: Trading pair symbol
callback: Specific callback to remove, or None to remove all
"""
key = f"ticker_{symbol}"
if key in self._subscribers:
if callback:
if callback in self._subscribers[key]:
self._subscribers[key].remove(callback)
if not self._subscribers[key]:
del self._subscribers[key]
else:
del self._subscribers[key]
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format for this provider.
Args:
symbol: Symbol to normalize
Returns:
Normalized symbol
"""
# Default: uppercase and replace dashes with slashes
return symbol.upper().replace('-', '/')
def is_connected(self) -> bool:
"""Check if provider is connected.
Returns:
True if connected
"""
return self._connected
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols.
Returns:
List of supported symbols, or empty list if not available
"""
# Default implementation - override in subclasses
return []

View File

@@ -0,0 +1,333 @@
"""CCXT-based pricing provider with multi-exchange support and WebSocket capabilities."""
import ccxt
import threading
import time
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CCXTProvider(BasePricingProvider):
"""CCXT-based pricing provider with multi-exchange fallback support.
This provider uses CCXT to connect to multiple exchanges (Kraken, Coinbase, Binance)
as primary data sources. It supports WebSocket where available and falls back to
polling if WebSocket is not available.
"""
def __init__(self, exchange_name: Optional[str] = None):
"""Initialize CCXT provider.
Args:
exchange_name: Specific exchange to use ('kraken', 'coinbase', 'binance'),
or None to try multiple exchanges
"""
super().__init__()
self.exchange_name = exchange_name
self.exchange = None
self._selected_exchange_id = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
# Exchange priority order (try first to last)
if exchange_name:
self._exchange_options = [(exchange_name.lower(), exchange_name.capitalize())]
else:
self._exchange_options = [
('kraken', 'Kraken'),
('coinbase', 'Coinbase'),
('binance', 'Binance'),
]
@property
def name(self) -> str:
"""Provider name."""
if self._selected_exchange_id:
return f"CCXT-{self._selected_exchange_id.capitalize()}"
return "CCXT Provider"
@property
def supports_websocket(self) -> bool:
"""Check if current exchange supports WebSocket."""
if not self.exchange:
return False
# Check if exchange has WebSocket support
# Most modern CCXT exchanges support WebSocket, but we check explicitly
exchange_id = getattr(self.exchange, 'id', '').lower()
# Known WebSocket-capable exchanges
ws_capable = ['kraken', 'coinbase', 'binance', 'binanceus', 'okx']
return exchange_id in ws_capable
def connect(self) -> bool:
"""Connect to an exchange via CCXT.
Tries multiple exchanges in order until one succeeds.
Returns:
True if connection successful
"""
for exchange_id, exchange_display_name in self._exchange_options:
try:
# Get exchange class from CCXT
exchange_class = getattr(ccxt, exchange_id, None)
if not exchange_class:
logger.warning(f"Exchange {exchange_id} not found in CCXT")
continue
# Create exchange instance
self.exchange = exchange_class({
'enableRateLimit': True,
'options': {
'defaultType': 'spot',
}
})
# Load markets to test connection
if not hasattr(self.exchange, 'markets') or not self.exchange.markets:
try:
self.exchange.load_markets()
except Exception as e:
logger.warning(f"Failed to load markets for {exchange_id}: {e}")
continue
# Test connection with a common symbol
test_symbols = ['BTC/USDT', 'BTC/USD', 'BTC/EUR']
ticker_result = None
for test_symbol in test_symbols:
try:
ticker_result = self.exchange.fetch_ticker(test_symbol)
if ticker_result:
break
except Exception:
continue
if not ticker_result:
logger.warning(f"Could not fetch ticker from {exchange_id}")
continue
# Success!
self._selected_exchange_id = exchange_id
self._connected = True
logger.info(f"Connected to {exchange_display_name} via CCXT")
return True
except Exception as e:
logger.warning(f"Failed to connect to {exchange_id}: {e}")
continue
# All exchanges failed
logger.error("Failed to connect to any CCXT exchange")
self._connected = False
self.exchange = None
return False
def disconnect(self):
"""Disconnect from exchange."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait a bit for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
self._connected = False
self.exchange = None
self._selected_exchange_id = None
logger.info(f"Disconnected from {self.name}")
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return {}
try:
normalized_symbol = self.normalize_symbol(symbol)
ticker = self.exchange.fetch_ticker(normalized_symbol)
return {
'symbol': symbol,
'bid': Decimal(str(ticker.get('bid', 0) or 0)),
'ask': Decimal(str(ticker.get('ask', 0) or 0)),
'last': Decimal(str(ticker.get('last', 0) or ticker.get('close', 0) or 0)),
'high': Decimal(str(ticker.get('high', 0) or 0)),
'low': Decimal(str(ticker.get('low', 0) or 0)),
'volume': Decimal(str(ticker.get('quoteVolume', ticker.get('volume', 0)) or 0)),
'timestamp': ticker.get('timestamp', int(time.time() * 1000)),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV candlestick data."""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return []
try:
since_timestamp = int(since.timestamp() * 1000) if since else None
normalized_symbol = self.normalize_symbol(symbol)
# Most exchanges support up to 1000 candles per request
max_limit = min(limit, 1000)
ohlcv = self.exchange.fetch_ohlcv(
normalized_symbol,
timeframe,
since_timestamp,
max_limit
)
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} ({timeframe})")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates.
Uses WebSocket if available, otherwise falls back to polling.
"""
if not self._connected or not self.exchange:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
# Try WebSocket first if supported
if self.supports_websocket:
try:
# CCXT WebSocket support varies by exchange
# For now, use polling as fallback since WebSocket implementation
# in CCXT requires exchange-specific handling
# TODO: Implement native WebSocket when CCXT adds better support
pass
except Exception as e:
logger.warning(f"WebSocket subscription failed, falling back to polling: {e}")
# Use polling as primary method (more reliable across exchanges)
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker every 2 seconds."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks for this symbol
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
time.sleep(2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(5) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (polling mode)")
return True
return True
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates."""
key = f"ticker_{symbol}"
# Stop polling thread
if key in self._stop_polling:
self._stop_polling[key] = True
# Remove from subscribers
super().unsubscribe_ticker(symbol, callback)
# Clean up thread reference
if key in self._polling_threads:
del self._polling_threads[key]
logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol for the selected exchange."""
if not self.exchange:
return symbol.replace('-', '/').upper()
# Basic normalization
normalized = symbol.replace('-', '/').upper()
# Try to use exchange's markets to find correct symbol
try:
if hasattr(self.exchange, 'markets') and self.exchange.markets:
# Check if symbol exists
if normalized in self.exchange.markets:
return normalized
# Try alternative formats
if '/USD' in normalized:
alt_symbol = normalized.replace('/USD', '/USDT')
if alt_symbol in self.exchange.markets:
return alt_symbol
# For Kraken: try XBT instead of BTC
if normalized.startswith('BTC/'):
alt_symbol = normalized.replace('BTC/', 'XBT/')
if alt_symbol in self.exchange.markets:
return alt_symbol
except Exception:
pass
# Fallback: return normalized (let CCXT handle errors)
return normalized
def get_supported_symbols(self) -> List[str]:
"""Get list of supported trading symbols."""
if not self._connected or not self.exchange:
return []
try:
markets = self.exchange.load_markets()
return list(markets.keys())
except Exception as e:
logger.error(f"Failed to get supported symbols: {e}")
return []

View File

@@ -0,0 +1,376 @@
"""CoinGecko pricing provider for fallback market data."""
import httpx
import time
import threading
from decimal import Decimal
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from .base_provider import BasePricingProvider
from src.core.logger import get_logger
logger = get_logger(__name__)
class CoinGeckoProvider(BasePricingProvider):
"""CoinGecko API pricing provider.
This provider uses CoinGecko's free API tier as a fallback when CCXT
providers are unavailable. It uses simple REST endpoints that don't
require authentication.
"""
BASE_URL = "https://api.coingecko.com/api/v3"
# CoinGecko coin ID mapping for common symbols
COIN_ID_MAP = {
'BTC': 'bitcoin',
'ETH': 'ethereum',
'BNB': 'binancecoin',
'SOL': 'solana',
'ADA': 'cardano',
'XRP': 'ripple',
'DOGE': 'dogecoin',
'DOT': 'polkadot',
'MATIC': 'matic-network',
'AVAX': 'avalanche-2',
'LINK': 'chainlink',
'USDT': 'tether',
'USDC': 'usd-coin',
'DAI': 'dai',
}
# Currency mapping
CURRENCY_MAP = {
'USD': 'usd',
'EUR': 'eur',
'GBP': 'gbp',
'JPY': 'jpy',
'USDT': 'usd', # CoinGecko uses USD for stablecoins
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize CoinGecko provider.
Args:
api_key: Optional API key for higher rate limits (free tier doesn't require it)
"""
super().__init__()
self.api_key = api_key
self._client = None
self._polling_threads: Dict[str, threading.Thread] = {}
self._stop_polling: Dict[str, bool] = {}
self._rate_limit_delay = 1.0 # Free tier: ~30-50 calls/minute
@property
def name(self) -> str:
"""Provider name."""
return "CoinGecko"
@property
def supports_websocket(self) -> bool:
"""CoinGecko free tier doesn't support WebSocket."""
return False
def connect(self) -> bool:
"""Connect to CoinGecko API.
Returns:
True if connection successful (just validates API access)
"""
try:
self._client = httpx.Client(timeout=10.0)
# Test connection by fetching Bitcoin price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': 'bitcoin',
'vs_currencies': 'usd',
}
)
if response.status_code == 200:
data = response.json()
if 'bitcoin' in data:
self._connected = True
logger.info("Connected to CoinGecko API")
return True
logger.warning("CoinGecko API test failed")
self._connected = False
return False
except Exception as e:
logger.error(f"Failed to connect to CoinGecko: {e}")
self._connected = False
return False
def disconnect(self):
"""Disconnect from CoinGecko API."""
# Stop all polling threads
for symbol in list(self._stop_polling.keys()):
self._stop_polling[symbol] = True
# Wait for threads to stop
for symbol, thread in self._polling_threads.items():
if thread.is_alive():
thread.join(timeout=2.0)
self._polling_threads.clear()
self._stop_polling.clear()
self._subscribers.clear()
if self._client:
self._client.close()
self._client = None
self._connected = False
logger.info("Disconnected from CoinGecko API")
def _parse_symbol(self, symbol: str) -> Tuple[Optional[str], Optional[str]]:
"""Parse symbol into coin_id and currency.
Args:
symbol: Trading pair like 'BTC/USD' or 'BTC/USDT'
Returns:
Tuple of (coin_id, currency) or (None, None) if not found
"""
parts = symbol.upper().replace('-', '/').split('/')
if len(parts) != 2:
return None, None
base, quote = parts
# Get coin ID
coin_id = self.COIN_ID_MAP.get(base)
if not coin_id:
# Try lowercase base as fallback
coin_id = base.lower()
# Get currency
currency = self.CURRENCY_MAP.get(quote, quote.lower())
return coin_id, currency
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data from CoinGecko."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return {}
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return {}
# Fetch current price
response = self._client.get(
f"{self.BASE_URL}/simple/price",
params={
'ids': coin_id,
'vs_currencies': currency,
'include_24hr_change': 'true',
'include_24hr_vol': 'true',
}
)
if response.status_code != 200:
logger.error(f"CoinGecko API error: {response.status_code}")
return {}
data = response.json()
if coin_id not in data:
logger.error(f"Coin {coin_id} not found in CoinGecko response")
return {}
coin_data = data[coin_id]
price_key = currency.lower()
if price_key not in coin_data:
logger.error(f"Currency {currency} not found for {coin_id}")
return {}
price = Decimal(str(coin_data[price_key]))
# Calculate high/low from 24h change if available
change_key = f"{price_key}_24h_change"
vol_key = f"{price_key}_24h_vol"
change_24h = coin_data.get(change_key, 0) or 0
volume_24h = coin_data.get(vol_key, 0) or 0
# Estimate high/low from current price and 24h change
# This is approximate since CoinGecko free tier doesn't provide exact high/low
current_price = float(price)
if change_24h:
estimated_high = current_price * (1 + abs(change_24h / 100) / 2)
estimated_low = current_price * (1 - abs(change_24h / 100) / 2)
else:
estimated_high = current_price
estimated_low = current_price
return {
'symbol': symbol,
'bid': price, # CoinGecko doesn't provide bid/ask, use last price
'ask': price,
'last': price,
'high': Decimal(str(estimated_high)),
'low': Decimal(str(estimated_low)),
'volume': Decimal(str(volume_24h)),
'timestamp': int(time.time() * 1000),
}
except Exception as e:
logger.error(f"Failed to get ticker for {symbol}: {e}")
return {}
def get_ohlcv(
self,
symbol: str,
timeframe: str = '1h',
since: Optional[datetime] = None,
limit: int = 100
) -> List[List]:
"""Get OHLCV data from CoinGecko.
Note: CoinGecko free tier has limited historical data access.
This method may return empty data for some timeframes.
"""
if not self._connected or not self._client:
logger.error("Provider not connected")
return []
try:
coin_id, currency = self._parse_symbol(symbol)
if not coin_id:
logger.error(f"Unknown symbol format: {symbol}")
return []
# CoinGecko uses days parameter instead of timeframe
# Map timeframe to days
timeframe_days_map = {
'1m': 1, # Last 24 hours, minute data (not available in free tier)
'5m': 1,
'15m': 1,
'30m': 1,
'1h': 1, # Last 24 hours, hourly data
'4h': 7, # Last 7 days
'1d': 30, # Last 30 days
'1w': 90, # Last 90 days
}
days = timeframe_days_map.get(timeframe, 7)
# Fetch OHLC data
response = self._client.get(
f"{self.BASE_URL}/coins/{coin_id}/ohlc",
params={
'vs_currency': currency,
'days': days,
}
)
if response.status_code != 200:
logger.warning(f"CoinGecko OHLC API returned {response.status_code}")
return []
data = response.json()
# CoinGecko returns: [timestamp_ms, open, high, low, close]
# We need to convert to: [timestamp_ms, open, high, low, close, volume]
# Note: CoinGecko OHLC endpoint doesn't include volume
# Filter by since if provided
if since:
since_timestamp = int(since.timestamp() * 1000)
data = [candle for candle in data if candle[0] >= since_timestamp]
# Limit results
data = data[-limit:] if limit else data
# Add volume as 0 (CoinGecko doesn't provide it in OHLC endpoint)
ohlcv = [candle + [0] for candle in data]
logger.debug(f"Fetched {len(ohlcv)} candles for {symbol} from CoinGecko")
return ohlcv
except Exception as e:
logger.error(f"Failed to get OHLCV for {symbol}: {e}")
return []
def subscribe_ticker(self, symbol: str, callback: Callable) -> bool:
"""Subscribe to ticker updates via polling."""
if not self._connected or not self._client:
logger.error("Provider not connected")
return False
key = f"ticker_{symbol}"
# Add callback to subscribers
if key not in self._subscribers:
self._subscribers[key] = []
if callback not in self._subscribers[key]:
self._subscribers[key].append(callback)
# Start polling thread if not already running
if key not in self._polling_threads or not self._polling_threads[key].is_alive():
self._stop_polling[key] = False
def poll_ticker():
"""Poll ticker respecting rate limits."""
while not self._stop_polling.get(key, False):
try:
ticker_data = self.get_ticker(symbol)
if ticker_data and 'last' in ticker_data:
# Call all callbacks
for cb in self._subscribers.get(key, []):
try:
cb({
'symbol': symbol,
'price': ticker_data['last'],
'bid': ticker_data.get('bid', 0),
'ask': ticker_data.get('ask', 0),
'volume': ticker_data.get('volume', 0),
'timestamp': ticker_data.get('timestamp'),
})
except Exception as e:
logger.error(f"Callback error for {symbol}: {e}")
# Rate limit: wait between requests (free tier: ~30-50 calls/min)
time.sleep(self._rate_limit_delay * 2) # Poll every 2 seconds
except Exception as e:
logger.error(f"Ticker polling error for {symbol}: {e}")
time.sleep(10) # Wait longer on error
thread = threading.Thread(target=poll_ticker, daemon=True)
thread.start()
self._polling_threads[key] = thread
logger.info(f"Subscribed to ticker updates for {symbol} (CoinGecko polling)")
return True
return True
def unsubscribe_ticker(self, symbol: str, callback: Optional[Callable] = None):
"""Unsubscribe from ticker updates."""
key = f"ticker_{symbol}"
# Stop polling
if key in self._stop_polling:
self._stop_polling[key] = True
# Remove from subscribers
super().unsubscribe_ticker(symbol, callback)
# Clean up thread
if key in self._polling_threads:
del self._polling_threads[key]
logger.info(f"Unsubscribed from ticker updates for {symbol}")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format."""
# CoinGecko uses coin IDs, so we just normalize the format
return symbol.replace('-', '/').upper()