Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Test init file."""

View File

@@ -0,0 +1,139 @@
"""Unit tests for CCXT pricing provider."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
from datetime import datetime
from src.data.providers.ccxt_provider import CCXTProvider
@pytest.fixture
def mock_ccxt_exchange():
"""Create a mock CCXT exchange."""
exchange = Mock()
exchange.markets = {
'BTC/USDT': {},
'ETH/USDT': {},
'BTC/USD': {},
}
exchange.id = 'kraken'
exchange.fetch_ticker = Mock(return_value={
'bid': 50000.0,
'ask': 50001.0,
'last': 50000.5,
'high': 51000.0,
'low': 49000.0,
'quoteVolume': 1000000.0,
'timestamp': 1609459200000,
})
exchange.fetch_ohlcv = Mock(return_value=[
[1609459200000, 50000, 51000, 49000, 50000, 1000],
])
exchange.load_markets = Mock(return_value=exchange.markets)
return exchange
@pytest.fixture
def provider():
"""Create a CCXT provider instance."""
return CCXTProvider(exchange_name='kraken')
class TestCCXTProvider:
"""Tests for CCXTProvider."""
def test_init(self, provider):
"""Test provider initialization."""
assert provider.name == "CCXT Provider"
assert not provider._connected
assert provider.exchange is None
def test_name_property(self, provider):
"""Test name property."""
provider._selected_exchange_id = 'kraken'
assert provider.name == "CCXT-Kraken"
@patch('src.data.providers.ccxt_provider.ccxt')
def test_connect_success(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test successful connection."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
result = provider.connect()
assert result is True
assert provider._connected is True
assert provider.exchange == mock_ccxt_exchange
assert provider._selected_exchange_id == 'kraken'
@patch('src.data.providers.ccxt_provider.ccxt')
def test_connect_failure(self, mock_ccxt, provider):
"""Test connection failure."""
mock_ccxt.kraken = Mock(side_effect=Exception("Connection failed"))
result = provider.connect()
assert result is False
assert not provider._connected
@patch('src.data.providers.ccxt_provider.ccxt')
def test_get_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test getting ticker data."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
ticker = provider.get_ticker('BTC/USDT')
assert ticker['symbol'] == 'BTC/USDT'
assert isinstance(ticker['bid'], Decimal)
assert isinstance(ticker['last'], Decimal)
assert ticker['last'] > 0
@patch('src.data.providers.ccxt_provider.ccxt')
def test_get_ohlcv(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test getting OHLCV data."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
ohlcv = provider.get_ohlcv('BTC/USDT', '1h', limit=10)
assert len(ohlcv) > 0
assert len(ohlcv[0]) == 6 # timestamp, open, high, low, close, volume
@patch('src.data.providers.ccxt_provider.ccxt')
def test_subscribe_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test subscribing to ticker updates."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
callback = Mock()
result = provider.subscribe_ticker('BTC/USDT', callback)
assert result is True
assert 'ticker_BTC/USDT' in provider._subscribers
def test_normalize_symbol(self, provider):
"""Test symbol normalization."""
# Test with exchange
with patch.object(provider, 'exchange') as mock_exchange:
mock_exchange.markets = {'BTC/USDT': {}}
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'
# Test without exchange
provider.exchange = None
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'
@patch('src.data.providers.ccxt_provider.ccxt')
def test_disconnect(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test disconnection."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
provider.subscribe_ticker('BTC/USDT', Mock())
provider.disconnect()
assert not provider._connected
assert provider.exchange is None
assert len(provider._subscribers) == 0

View File

@@ -0,0 +1,113 @@
"""Unit tests for CoinGecko pricing provider."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
import httpx
from src.data.providers.coingecko_provider import CoinGeckoProvider
@pytest.fixture
def provider():
"""Create a CoinGecko provider instance."""
return CoinGeckoProvider()
@pytest.fixture
def mock_response():
"""Create a mock HTTP response."""
response = Mock()
response.status_code = 200
response.json = Mock(return_value={
'bitcoin': {
'usd': 50000.0,
'usd_24h_change': 2.5,
'usd_24h_vol': 1000000.0,
}
})
return response
class TestCoinGeckoProvider:
"""Tests for CoinGeckoProvider."""
def test_init(self, provider):
"""Test provider initialization."""
assert provider.name == "CoinGecko"
assert not provider.supports_websocket
assert not provider._connected
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_connect_success(self, mock_client_class, provider, mock_response):
"""Test successful connection."""
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
result = provider.connect()
assert result is True
assert provider._connected is True
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_connect_failure(self, mock_client_class, provider):
"""Test connection failure."""
mock_client = Mock()
mock_client.get = Mock(side_effect=Exception("Connection failed"))
mock_client_class.return_value = mock_client
result = provider.connect()
assert result is False
assert not provider._connected
def test_parse_symbol(self, provider):
"""Test symbol parsing."""
coin_id, currency = provider._parse_symbol('BTC/USD')
assert coin_id == 'bitcoin'
assert currency == 'usd'
coin_id, currency = provider._parse_symbol('ETH/USDT')
assert coin_id == 'ethereum'
assert currency == 'usd' # USDT maps to USD
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_get_ticker(self, mock_client_class, provider, mock_response):
"""Test getting ticker data."""
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
provider.connect()
ticker = provider.get_ticker('BTC/USD')
assert ticker['symbol'] == 'BTC/USD'
assert isinstance(ticker['last'], Decimal)
assert ticker['last'] > 0
assert 'timestamp' in ticker
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_get_ohlcv(self, mock_client_class, provider):
"""Test getting OHLCV data."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=[
[1609459200000, 50000, 51000, 49000, 50000],
])
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
provider.connect()
ohlcv = provider.get_ohlcv('BTC/USD', '1h', limit=10)
assert len(ohlcv) > 0
# CoinGecko returns 5 elements, we add volume as 0
assert len(ohlcv[0]) == 6
def test_normalize_symbol(self, provider):
"""Test symbol normalization."""
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'

View File

@@ -0,0 +1,120 @@
"""Unit tests for cache manager."""
import pytest
import time
from src.data.cache_manager import CacheManager, CacheEntry
class TestCacheEntry:
"""Tests for CacheEntry."""
def test_init(self):
"""Test cache entry initialization."""
entry = CacheEntry("test_data", 60.0)
assert entry.data == "test_data"
assert entry.expires_at > time.time()
assert entry.access_count == 0
def test_is_expired(self):
"""Test expiration checking."""
entry = CacheEntry("test_data", 0.01) # Very short TTL
assert not entry.is_expired()
time.sleep(0.02)
assert entry.is_expired()
def test_touch(self):
"""Test access tracking."""
entry = CacheEntry("test_data", 60.0)
initial_count = entry.access_count
entry.touch()
assert entry.access_count == initial_count + 1
class TestCacheManager:
"""Tests for CacheManager."""
@pytest.fixture
def cache(self):
"""Create a cache manager instance."""
return CacheManager(default_ttl=1.0, max_size=10)
def test_get_set(self, cache):
"""Test basic get and set operations."""
cache.set("key1", "value1")
assert cache.get("key1") == "value1"
def test_get_missing(self, cache):
"""Test getting non-existent key."""
assert cache.get("missing") is None
def test_expiration(self, cache):
"""Test cache entry expiration."""
cache.set("key1", "value1", ttl=0.1)
assert cache.get("key1") == "value1"
time.sleep(0.2)
assert cache.get("key1") is None
def test_lru_eviction(self, cache):
"""Test LRU eviction when max size reached."""
# Fill cache to max size
for i in range(10):
cache.set(f"key{i}", f"value{i}")
# Add one more - should evict oldest
cache.set("key10", "value10")
# Oldest key should be evicted
assert cache.get("key0") is None
assert cache.get("key10") == "value10"
def test_type_specific_ttl(self, cache):
"""Test type-specific TTL."""
cache.set("ticker1", {"price": 100}, cache_type='ticker')
cache.set("ohlcv1", [[1, 2, 3, 4, 5, 6]], cache_type='ohlcv')
# Both should be cached
assert cache.get("ticker1") is not None
assert cache.get("ohlcv1") is not None
def test_delete(self, cache):
"""Test cache entry deletion."""
cache.set("key1", "value1")
assert cache.get("key1") == "value1"
cache.delete("key1")
assert cache.get("key1") is None
def test_clear(self, cache):
"""Test cache clearing."""
cache.set("key1", "value1")
cache.set("key2", "value2")
cache.clear()
assert cache.get("key1") is None
assert cache.get("key2") is None
def test_stats(self, cache):
"""Test cache statistics."""
cache.set("key1", "value1")
cache.get("key1") # Hit
cache.get("missing") # Miss
stats = cache.get_stats()
assert stats['hits'] >= 1
assert stats['misses'] >= 1
assert stats['size'] == 1
assert 'hit_rate' in stats
def test_invalidate_pattern(self, cache):
"""Test pattern-based invalidation."""
cache.set("ticker:BTC/USD", "value1")
cache.set("ticker:ETH/USD", "value2")
cache.set("ohlcv:BTC/USD", "value3")
cache.invalidate_pattern("ticker:")
assert cache.get("ticker:BTC/USD") is None
assert cache.get("ticker:ETH/USD") is None
assert cache.get("ohlcv:BTC/USD") is not None

View File

@@ -0,0 +1,145 @@
"""Unit tests for health monitor."""
import pytest
from datetime import datetime, timedelta
from src.data.health_monitor import HealthMonitor, HealthMetrics, HealthStatus
class TestHealthMetrics:
"""Tests for HealthMetrics."""
def test_record_success(self):
"""Test recording successful operation."""
metrics = HealthMetrics()
metrics.record_success(0.5)
assert metrics.status == HealthStatus.HEALTHY
assert metrics.success_count == 1
assert metrics.consecutive_failures == 0
assert len(metrics.response_times) == 1
def test_record_failure(self):
"""Test recording failed operation."""
metrics = HealthMetrics()
metrics.record_failure()
assert metrics.failure_count == 1
assert metrics.consecutive_failures == 1
def test_circuit_breaker(self):
"""Test circuit breaker opening."""
metrics = HealthMetrics()
# Record 5 failures
for _ in range(5):
metrics.record_failure()
assert metrics.circuit_breaker_open is True
assert metrics.consecutive_failures == 5
def test_should_attempt(self):
"""Test should_attempt logic."""
metrics = HealthMetrics()
# Should attempt if circuit breaker not open
assert metrics.should_attempt() is True
# Open circuit breaker
for _ in range(5):
metrics.record_failure()
# Should not attempt immediately
assert metrics.should_attempt(circuit_breaker_timeout=60) is False
def test_get_avg_response_time(self):
"""Test average response time calculation."""
metrics = HealthMetrics()
metrics.response_times.extend([0.1, 0.2, 0.3])
avg = metrics.get_avg_response_time()
assert avg == 0.2
class TestHealthMonitor:
"""Tests for HealthMonitor."""
@pytest.fixture
def monitor(self):
"""Create a health monitor instance."""
return HealthMonitor()
def test_record_success(self, monitor):
"""Test recording success."""
monitor.record_success("provider1", 0.5)
metrics = monitor.get_metrics("provider1")
assert metrics is not None
assert metrics.status == HealthStatus.HEALTHY
assert metrics.success_count == 1
def test_record_failure(self, monitor):
"""Test recording failure."""
monitor.record_failure("provider1")
metrics = monitor.get_metrics("provider1")
assert metrics is not None
assert metrics.failure_count == 1
assert metrics.consecutive_failures == 1
def test_is_healthy(self, monitor):
"""Test health checking."""
# No metrics yet - assume healthy
assert monitor.is_healthy("provider1") is True
# Record success
monitor.record_success("provider1", 0.5)
assert monitor.is_healthy("provider1") is True
# Record many failures
for _ in range(10):
monitor.record_failure("provider1")
assert monitor.is_healthy("provider1") is False
def test_get_health_status(self, monitor):
"""Test getting health status."""
monitor.record_success("provider1", 0.5)
status = monitor.get_health_status("provider1")
assert status == HealthStatus.HEALTHY
def test_select_healthiest(self, monitor):
"""Test selecting healthiest provider."""
# Make provider1 healthy
monitor.record_success("provider1", 0.1)
monitor.record_success("provider1", 0.2)
# Make provider2 unhealthy
monitor.record_failure("provider2")
monitor.record_failure("provider2")
monitor.record_failure("provider2")
healthiest = monitor.select_healthiest(["provider1", "provider2"])
assert healthiest == "provider1"
def test_reset_circuit_breaker(self, monitor):
"""Test resetting circuit breaker."""
# Open circuit breaker
for _ in range(5):
monitor.record_failure("provider1")
assert monitor.get_metrics("provider1").circuit_breaker_open is True
monitor.reset_circuit_breaker("provider1")
metrics = monitor.get_metrics("provider1")
assert metrics.circuit_breaker_open is False
assert metrics.consecutive_failures == 0
def test_reset_metrics(self, monitor):
"""Test resetting metrics."""
monitor.record_success("provider1", 0.5)
assert monitor.get_metrics("provider1") is not None
monitor.reset_metrics("provider1")
assert monitor.get_metrics("provider1") is None

View File

@@ -0,0 +1,68 @@
"""Tests for technical indicators."""
import pytest
import pandas as pd
import numpy as np
from src.data.indicators import get_indicators, TechnicalIndicators
class TestTechnicalIndicators:
"""Tests for TechnicalIndicators."""
@pytest.fixture
def indicators(self):
"""Create indicators instance."""
return get_indicators()
@pytest.fixture
def sample_data(self):
"""Create sample price data."""
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
return pd.DataFrame({
'close': [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)],
'high': [101 + i * 0.1 for i in range(100)],
'low': [99 + i * 0.1 for i in range(100)],
'open': [100 + i * 0.1 for i in range(100)],
'volume': [1000.0] * 100
})
def test_sma(self, indicators, sample_data):
"""Test Simple Moving Average."""
sma = indicators.sma(sample_data['close'], period=20)
assert len(sma) == len(sample_data)
assert not sma.isna().all() # Should have some valid values
def test_ema(self, indicators, sample_data):
"""Test Exponential Moving Average."""
ema = indicators.ema(sample_data['close'], period=20)
assert len(ema) == len(sample_data)
def test_rsi(self, indicators, sample_data):
"""Test Relative Strength Index."""
rsi = indicators.rsi(sample_data['close'], period=14)
assert len(rsi) == len(sample_data)
# RSI should be between 0 and 100
valid_rsi = rsi.dropna()
if len(valid_rsi) > 0:
assert (valid_rsi >= 0).all()
assert (valid_rsi <= 100).all()
def test_macd(self, indicators, sample_data):
"""Test MACD."""
macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
assert 'macd' in macd_result
assert 'signal' in macd_result
assert 'histogram' in macd_result
def test_bollinger_bands(self, indicators, sample_data):
"""Test Bollinger Bands."""
bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
assert 'upper' in bb
assert 'middle' in bb
assert 'lower' in bb
# Upper should be above middle, middle above lower
valid_data = bb.dropna()
if len(valid_data) > 0:
assert (valid_data['upper'] >= valid_data['middle']).all()
assert (valid_data['middle'] >= valid_data['lower']).all()

View File

@@ -0,0 +1,80 @@
"""Tests for divergence detection in indicators."""
import pytest
import pandas as pd
import numpy as np
from src.data.indicators import get_indicators
class TestDivergenceDetection:
"""Tests for divergence detection."""
@pytest.fixture
def indicators(self):
"""Create indicators instance."""
return get_indicators()
@pytest.fixture
def sample_data(self):
"""Create sample price data with clear trend."""
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
# Create price data with trend
prices = [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)]
return pd.Series(prices, index=dates)
def test_detect_divergence_insufficient_data(self, indicators):
"""Test divergence detection with insufficient data."""
prices = pd.Series([100, 101, 102])
indicator = pd.Series([50, 51, 52])
result = indicators.detect_divergence(prices, indicator, lookback=20)
assert result['type'] is None
assert result['confidence'] == 0.0
def test_detect_divergence_structure(self, indicators, sample_data):
"""Test divergence detection returns correct structure."""
# Create indicator data
indicator = pd.Series([50 + i * 0.1 for i in range(100)], index=sample_data.index)
result = indicators.detect_divergence(sample_data, indicator, lookback=20)
# Check structure
assert 'type' in result
assert 'confidence' in result
assert 'price_swing_high' in result
assert 'price_swing_low' in result
assert 'indicator_swing_high' in result
assert 'indicator_swing_low' in result
# Type should be None, 'bullish', or 'bearish'
assert result['type'] in [None, 'bullish', 'bearish']
# Confidence should be 0.0 to 1.0
assert 0.0 <= result['confidence'] <= 1.0
def test_detect_divergence_with_trend(self, indicators):
"""Test divergence detection with clear trend data."""
# Create price making lower lows
prices = pd.Series([100, 95, 90, 85, 80])
# Create indicator making higher lows (bullish divergence)
indicator = pd.Series([30, 32, 34, 36, 38])
# Need more data for lookback
prices_long = pd.concat([pd.Series([110] * 30), prices])
indicator_long = pd.concat([pd.Series([25] * 30), indicator])
result = indicators.detect_divergence(
prices_long,
indicator_long,
lookback=5,
min_swings=2
)
# Should detect bullish divergence (price down, indicator up)
# Note: This may not always detect due to swing detection logic
assert result is not None
assert 'type' in result
assert 'confidence' in result

View File

@@ -0,0 +1,135 @@
"""Unit tests for pricing service."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
from src.data.pricing_service import PricingService, get_pricing_service
from src.data.providers.base_provider import BasePricingProvider
class MockProvider(BasePricingProvider):
"""Mock provider for testing."""
@property
def name(self) -> str:
return "MockProvider"
@property
def supports_websocket(self) -> bool:
return False
def connect(self) -> bool:
self._connected = True
return True
def disconnect(self):
self._connected = False
def get_ticker(self, symbol: str):
return {
'symbol': symbol,
'bid': Decimal('50000'),
'ask': Decimal('50001'),
'last': Decimal('50000.5'),
'high': Decimal('51000'),
'low': Decimal('49000'),
'volume': Decimal('1000000'),
'timestamp': 1609459200000,
}
def get_ohlcv(self, symbol: str, timeframe: str = '1h', since=None, limit: int = 100):
return [[1609459200000, 50000, 51000, 49000, 50000, 1000]]
def subscribe_ticker(self, symbol: str, callback) -> bool:
if symbol not in self._subscribers:
self._subscribers[symbol] = []
self._subscribers[symbol].append(callback)
return True
@pytest.fixture
def mock_config():
"""Create a mock configuration."""
config = Mock()
config.get = Mock(side_effect=lambda key, default=None: {
"data_providers.primary": [
{"name": "mock", "enabled": True, "priority": 1}
],
"data_providers.fallback": {"enabled": True, "api_key": ""},
"data_providers.caching.ticker_ttl": 2,
"data_providers.caching.ohlcv_ttl": 60,
"data_providers.caching.max_cache_size": 1000,
}.get(key, default))
return config
class TestPricingService:
"""Tests for PricingService."""
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
@patch('src.data.providers.coingecko_provider.CoinGeckoProvider')
def test_init(self, mock_coingecko, mock_ccxt, mock_get_config, mock_config):
"""Test service initialization."""
mock_get_config.return_value = mock_config
mock_ccxt_instance = MockProvider()
mock_ccxt.return_value = mock_ccxt_instance
mock_coingecko_instance = MockProvider()
mock_coingecko.return_value = mock_coingecko_instance
service = PricingService()
assert service.cache is not None
assert service.health_monitor is not None
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_get_ticker(self, mock_ccxt, mock_get_config, mock_config):
"""Test getting ticker data."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
ticker = service.get_ticker("BTC/USD")
assert ticker['symbol'] == "BTC/USD"
assert isinstance(ticker['last'], Decimal)
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_get_ohlcv(self, mock_ccxt, mock_get_config, mock_config):
"""Test getting OHLCV data."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
ohlcv = service.get_ohlcv("BTC/USD", "1h", limit=10)
assert len(ohlcv) > 0
assert len(ohlcv[0]) == 6
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_subscribe_ticker(self, mock_ccxt, mock_get_config, mock_config):
"""Test subscribing to ticker updates."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
callback = Mock()
result = service.subscribe_ticker("BTC/USD", callback)
assert result is True

View File

@@ -0,0 +1,118 @@
"""Tests for Redis cache."""
import pytest
from unittest.mock import Mock, patch, AsyncMock
class TestRedisCache:
"""Tests for RedisCache class."""
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ticker_cache_hit(self, mock_get_client):
"""Test getting cached ticker data."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = '{"price": 45000.0, "symbol": "BTC/USD"}'
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ticker("BTC/USD")
assert result is not None
assert result["price"] == 45000.0
mock_client.get.assert_called_once()
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ticker_cache_miss(self, mock_get_client):
"""Test ticker cache miss."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = None
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ticker("BTC/USD")
assert result is None
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_set_ticker(self, mock_get_client):
"""Test setting ticker cache."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.set_ticker("BTC/USD", {"price": 45000.0})
assert result is True
mock_client.setex.assert_called_once()
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ohlcv(self, mock_get_client):
"""Test getting cached OHLCV data."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = '[[1700000000, 45000, 45500, 44500, 45200, 1000]]'
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ohlcv("BTC/USD", "1h", 100)
assert result is not None
assert len(result) == 1
assert result[0][0] == 1700000000
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_set_ohlcv(self, mock_get_client):
"""Test setting OHLCV cache."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
ohlcv_data = [[1700000000, 45000, 45500, 44500, 45200, 1000]]
result = await cache.set_ohlcv("BTC/USD", "1h", ohlcv_data)
assert result is True
mock_client.setex.assert_called_once()
class TestGetRedisCache:
"""Tests for get_redis_cache singleton."""
@patch('src.data.redis_cache.get_redis_client')
def test_returns_singleton(self, mock_get_client):
"""Test that get_redis_cache returns same instance."""
mock_get_client.return_value = Mock()
# Reset the global
import src.data.redis_cache as cache_module
cache_module._redis_cache = None
from src.data.redis_cache import get_redis_cache
cache1 = get_redis_cache()
cache2 = get_redis_cache()
assert cache1 is cache2