Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
1
tests/unit/data/__init__.py
Normal file
1
tests/unit/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
139
tests/unit/data/providers/test_ccxt_provider.py
Normal file
139
tests/unit/data/providers/test_ccxt_provider.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Unit tests for CCXT pricing provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
from src.data.providers.ccxt_provider import CCXTProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ccxt_exchange():
|
||||
"""Create a mock CCXT exchange."""
|
||||
exchange = Mock()
|
||||
exchange.markets = {
|
||||
'BTC/USDT': {},
|
||||
'ETH/USDT': {},
|
||||
'BTC/USD': {},
|
||||
}
|
||||
exchange.id = 'kraken'
|
||||
exchange.fetch_ticker = Mock(return_value={
|
||||
'bid': 50000.0,
|
||||
'ask': 50001.0,
|
||||
'last': 50000.5,
|
||||
'high': 51000.0,
|
||||
'low': 49000.0,
|
||||
'quoteVolume': 1000000.0,
|
||||
'timestamp': 1609459200000,
|
||||
})
|
||||
exchange.fetch_ohlcv = Mock(return_value=[
|
||||
[1609459200000, 50000, 51000, 49000, 50000, 1000],
|
||||
])
|
||||
exchange.load_markets = Mock(return_value=exchange.markets)
|
||||
return exchange
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider():
|
||||
"""Create a CCXT provider instance."""
|
||||
return CCXTProvider(exchange_name='kraken')
|
||||
|
||||
|
||||
class TestCCXTProvider:
|
||||
"""Tests for CCXTProvider."""
|
||||
|
||||
def test_init(self, provider):
|
||||
"""Test provider initialization."""
|
||||
assert provider.name == "CCXT Provider"
|
||||
assert not provider._connected
|
||||
assert provider.exchange is None
|
||||
|
||||
def test_name_property(self, provider):
|
||||
"""Test name property."""
|
||||
provider._selected_exchange_id = 'kraken'
|
||||
assert provider.name == "CCXT-Kraken"
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_connect_success(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test successful connection."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is True
|
||||
assert provider._connected is True
|
||||
assert provider.exchange == mock_ccxt_exchange
|
||||
assert provider._selected_exchange_id == 'kraken'
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_connect_failure(self, mock_ccxt, provider):
|
||||
"""Test connection failure."""
|
||||
mock_ccxt.kraken = Mock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is False
|
||||
assert not provider._connected
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_get_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test getting ticker data."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
ticker = provider.get_ticker('BTC/USDT')
|
||||
|
||||
assert ticker['symbol'] == 'BTC/USDT'
|
||||
assert isinstance(ticker['bid'], Decimal)
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
assert ticker['last'] > 0
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_get_ohlcv(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
ohlcv = provider.get_ohlcv('BTC/USDT', '1h', limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
assert len(ohlcv[0]) == 6 # timestamp, open, high, low, close, volume
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_subscribe_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test subscribing to ticker updates."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
callback = Mock()
|
||||
result = provider.subscribe_ticker('BTC/USDT', callback)
|
||||
|
||||
assert result is True
|
||||
assert 'ticker_BTC/USDT' in provider._subscribers
|
||||
|
||||
def test_normalize_symbol(self, provider):
|
||||
"""Test symbol normalization."""
|
||||
# Test with exchange
|
||||
with patch.object(provider, 'exchange') as mock_exchange:
|
||||
mock_exchange.markets = {'BTC/USDT': {}}
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
|
||||
# Test without exchange
|
||||
provider.exchange = None
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_disconnect(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test disconnection."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
provider.subscribe_ticker('BTC/USDT', Mock())
|
||||
|
||||
provider.disconnect()
|
||||
|
||||
assert not provider._connected
|
||||
assert provider.exchange is None
|
||||
assert len(provider._subscribers) == 0
|
||||
113
tests/unit/data/providers/test_coingecko_provider.py
Normal file
113
tests/unit/data/providers/test_coingecko_provider.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for CoinGecko pricing provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
import httpx
|
||||
|
||||
from src.data.providers.coingecko_provider import CoinGeckoProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider():
|
||||
"""Create a CoinGecko provider instance."""
|
||||
return CoinGeckoProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
"""Create a mock HTTP response."""
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.json = Mock(return_value={
|
||||
'bitcoin': {
|
||||
'usd': 50000.0,
|
||||
'usd_24h_change': 2.5,
|
||||
'usd_24h_vol': 1000000.0,
|
||||
}
|
||||
})
|
||||
return response
|
||||
|
||||
|
||||
class TestCoinGeckoProvider:
|
||||
"""Tests for CoinGeckoProvider."""
|
||||
|
||||
def test_init(self, provider):
|
||||
"""Test provider initialization."""
|
||||
assert provider.name == "CoinGecko"
|
||||
assert not provider.supports_websocket
|
||||
assert not provider._connected
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_connect_success(self, mock_client_class, provider, mock_response):
|
||||
"""Test successful connection."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is True
|
||||
assert provider._connected is True
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_connect_failure(self, mock_client_class, provider):
|
||||
"""Test connection failure."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(side_effect=Exception("Connection failed"))
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is False
|
||||
assert not provider._connected
|
||||
|
||||
def test_parse_symbol(self, provider):
|
||||
"""Test symbol parsing."""
|
||||
coin_id, currency = provider._parse_symbol('BTC/USD')
|
||||
assert coin_id == 'bitcoin'
|
||||
assert currency == 'usd'
|
||||
|
||||
coin_id, currency = provider._parse_symbol('ETH/USDT')
|
||||
assert coin_id == 'ethereum'
|
||||
assert currency == 'usd' # USDT maps to USD
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_get_ticker(self, mock_client_class, provider, mock_response):
|
||||
"""Test getting ticker data."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
provider.connect()
|
||||
|
||||
ticker = provider.get_ticker('BTC/USD')
|
||||
|
||||
assert ticker['symbol'] == 'BTC/USD'
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
assert ticker['last'] > 0
|
||||
assert 'timestamp' in ticker
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_get_ohlcv(self, mock_client_class, provider):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=[
|
||||
[1609459200000, 50000, 51000, 49000, 50000],
|
||||
])
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
provider.connect()
|
||||
|
||||
ohlcv = provider.get_ohlcv('BTC/USD', '1h', limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
# CoinGecko returns 5 elements, we add volume as 0
|
||||
assert len(ohlcv[0]) == 6
|
||||
|
||||
def test_normalize_symbol(self, provider):
|
||||
"""Test symbol normalization."""
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
120
tests/unit/data/test_cache_manager.py
Normal file
120
tests/unit/data/test_cache_manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Unit tests for cache manager."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from src.data.cache_manager import CacheManager, CacheEntry
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test cache entry initialization."""
|
||||
entry = CacheEntry("test_data", 60.0)
|
||||
assert entry.data == "test_data"
|
||||
assert entry.expires_at > time.time()
|
||||
assert entry.access_count == 0
|
||||
|
||||
def test_is_expired(self):
|
||||
"""Test expiration checking."""
|
||||
entry = CacheEntry("test_data", 0.01) # Very short TTL
|
||||
assert not entry.is_expired()
|
||||
time.sleep(0.02)
|
||||
assert entry.is_expired()
|
||||
|
||||
def test_touch(self):
|
||||
"""Test access tracking."""
|
||||
entry = CacheEntry("test_data", 60.0)
|
||||
initial_count = entry.access_count
|
||||
entry.touch()
|
||||
assert entry.access_count == initial_count + 1
|
||||
|
||||
|
||||
class TestCacheManager:
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self):
|
||||
"""Create a cache manager instance."""
|
||||
return CacheManager(default_ttl=1.0, max_size=10)
|
||||
|
||||
def test_get_set(self, cache):
|
||||
"""Test basic get and set operations."""
|
||||
cache.set("key1", "value1")
|
||||
assert cache.get("key1") == "value1"
|
||||
|
||||
def test_get_missing(self, cache):
|
||||
"""Test getting non-existent key."""
|
||||
assert cache.get("missing") is None
|
||||
|
||||
def test_expiration(self, cache):
|
||||
"""Test cache entry expiration."""
|
||||
cache.set("key1", "value1", ttl=0.1)
|
||||
assert cache.get("key1") == "value1"
|
||||
time.sleep(0.2)
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_lru_eviction(self, cache):
|
||||
"""Test LRU eviction when max size reached."""
|
||||
# Fill cache to max size
|
||||
for i in range(10):
|
||||
cache.set(f"key{i}", f"value{i}")
|
||||
|
||||
# Add one more - should evict oldest
|
||||
cache.set("key10", "value10")
|
||||
|
||||
# Oldest key should be evicted
|
||||
assert cache.get("key0") is None
|
||||
assert cache.get("key10") == "value10"
|
||||
|
||||
def test_type_specific_ttl(self, cache):
|
||||
"""Test type-specific TTL."""
|
||||
cache.set("ticker1", {"price": 100}, cache_type='ticker')
|
||||
cache.set("ohlcv1", [[1, 2, 3, 4, 5, 6]], cache_type='ohlcv')
|
||||
|
||||
# Both should be cached
|
||||
assert cache.get("ticker1") is not None
|
||||
assert cache.get("ohlcv1") is not None
|
||||
|
||||
def test_delete(self, cache):
|
||||
"""Test cache entry deletion."""
|
||||
cache.set("key1", "value1")
|
||||
assert cache.get("key1") == "value1"
|
||||
|
||||
cache.delete("key1")
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_clear(self, cache):
|
||||
"""Test cache clearing."""
|
||||
cache.set("key1", "value1")
|
||||
cache.set("key2", "value2")
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is None
|
||||
|
||||
def test_stats(self, cache):
|
||||
"""Test cache statistics."""
|
||||
cache.set("key1", "value1")
|
||||
cache.get("key1") # Hit
|
||||
cache.get("missing") # Miss
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats['hits'] >= 1
|
||||
assert stats['misses'] >= 1
|
||||
assert stats['size'] == 1
|
||||
assert 'hit_rate' in stats
|
||||
|
||||
def test_invalidate_pattern(self, cache):
|
||||
"""Test pattern-based invalidation."""
|
||||
cache.set("ticker:BTC/USD", "value1")
|
||||
cache.set("ticker:ETH/USD", "value2")
|
||||
cache.set("ohlcv:BTC/USD", "value3")
|
||||
|
||||
cache.invalidate_pattern("ticker:")
|
||||
|
||||
assert cache.get("ticker:BTC/USD") is None
|
||||
assert cache.get("ticker:ETH/USD") is None
|
||||
assert cache.get("ohlcv:BTC/USD") is not None
|
||||
145
tests/unit/data/test_health_monitor.py
Normal file
145
tests/unit/data/test_health_monitor.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Unit tests for health monitor."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.data.health_monitor import HealthMonitor, HealthMetrics, HealthStatus
|
||||
|
||||
|
||||
class TestHealthMetrics:
|
||||
"""Tests for HealthMetrics."""
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful operation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.record_success(0.5)
|
||||
|
||||
assert metrics.status == HealthStatus.HEALTHY
|
||||
assert metrics.success_count == 1
|
||||
assert metrics.consecutive_failures == 0
|
||||
assert len(metrics.response_times) == 1
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed operation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.failure_count == 1
|
||||
assert metrics.consecutive_failures == 1
|
||||
|
||||
def test_circuit_breaker(self):
|
||||
"""Test circuit breaker opening."""
|
||||
metrics = HealthMetrics()
|
||||
|
||||
# Record 5 failures
|
||||
for _ in range(5):
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.circuit_breaker_open is True
|
||||
assert metrics.consecutive_failures == 5
|
||||
|
||||
def test_should_attempt(self):
|
||||
"""Test should_attempt logic."""
|
||||
metrics = HealthMetrics()
|
||||
|
||||
# Should attempt if circuit breaker not open
|
||||
assert metrics.should_attempt() is True
|
||||
|
||||
# Open circuit breaker
|
||||
for _ in range(5):
|
||||
metrics.record_failure()
|
||||
|
||||
# Should not attempt immediately
|
||||
assert metrics.should_attempt(circuit_breaker_timeout=60) is False
|
||||
|
||||
def test_get_avg_response_time(self):
|
||||
"""Test average response time calculation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.response_times.extend([0.1, 0.2, 0.3])
|
||||
|
||||
avg = metrics.get_avg_response_time()
|
||||
assert avg == 0.2
|
||||
|
||||
|
||||
class TestHealthMonitor:
|
||||
"""Tests for HealthMonitor."""
|
||||
|
||||
@pytest.fixture
|
||||
def monitor(self):
|
||||
"""Create a health monitor instance."""
|
||||
return HealthMonitor()
|
||||
|
||||
def test_record_success(self, monitor):
|
||||
"""Test recording success."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics is not None
|
||||
assert metrics.status == HealthStatus.HEALTHY
|
||||
assert metrics.success_count == 1
|
||||
|
||||
def test_record_failure(self, monitor):
|
||||
"""Test recording failure."""
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics is not None
|
||||
assert metrics.failure_count == 1
|
||||
assert metrics.consecutive_failures == 1
|
||||
|
||||
def test_is_healthy(self, monitor):
|
||||
"""Test health checking."""
|
||||
# No metrics yet - assume healthy
|
||||
assert monitor.is_healthy("provider1") is True
|
||||
|
||||
# Record success
|
||||
monitor.record_success("provider1", 0.5)
|
||||
assert monitor.is_healthy("provider1") is True
|
||||
|
||||
# Record many failures
|
||||
for _ in range(10):
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
assert monitor.is_healthy("provider1") is False
|
||||
|
||||
def test_get_health_status(self, monitor):
|
||||
"""Test getting health status."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
status = monitor.get_health_status("provider1")
|
||||
assert status == HealthStatus.HEALTHY
|
||||
|
||||
def test_select_healthiest(self, monitor):
|
||||
"""Test selecting healthiest provider."""
|
||||
# Make provider1 healthy
|
||||
monitor.record_success("provider1", 0.1)
|
||||
monitor.record_success("provider1", 0.2)
|
||||
|
||||
# Make provider2 unhealthy
|
||||
monitor.record_failure("provider2")
|
||||
monitor.record_failure("provider2")
|
||||
monitor.record_failure("provider2")
|
||||
|
||||
healthiest = monitor.select_healthiest(["provider1", "provider2"])
|
||||
assert healthiest == "provider1"
|
||||
|
||||
def test_reset_circuit_breaker(self, monitor):
|
||||
"""Test resetting circuit breaker."""
|
||||
# Open circuit breaker
|
||||
for _ in range(5):
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
assert monitor.get_metrics("provider1").circuit_breaker_open is True
|
||||
|
||||
monitor.reset_circuit_breaker("provider1")
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics.circuit_breaker_open is False
|
||||
assert metrics.consecutive_failures == 0
|
||||
|
||||
def test_reset_metrics(self, monitor):
|
||||
"""Test resetting metrics."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
assert monitor.get_metrics("provider1") is not None
|
||||
|
||||
monitor.reset_metrics("provider1")
|
||||
assert monitor.get_metrics("provider1") is None
|
||||
68
tests/unit/data/test_indicators.py
Normal file
68
tests/unit/data/test_indicators.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for technical indicators."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.data.indicators import get_indicators, TechnicalIndicators
|
||||
|
||||
|
||||
class TestTechnicalIndicators:
|
||||
"""Tests for TechnicalIndicators."""
|
||||
|
||||
@pytest.fixture
|
||||
def indicators(self):
|
||||
"""Create indicators instance."""
|
||||
return get_indicators()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample price data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
|
||||
return pd.DataFrame({
|
||||
'close': [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)],
|
||||
'high': [101 + i * 0.1 for i in range(100)],
|
||||
'low': [99 + i * 0.1 for i in range(100)],
|
||||
'open': [100 + i * 0.1 for i in range(100)],
|
||||
'volume': [1000.0] * 100
|
||||
})
|
||||
|
||||
def test_sma(self, indicators, sample_data):
|
||||
"""Test Simple Moving Average."""
|
||||
sma = indicators.sma(sample_data['close'], period=20)
|
||||
assert len(sma) == len(sample_data)
|
||||
assert not sma.isna().all() # Should have some valid values
|
||||
|
||||
def test_ema(self, indicators, sample_data):
|
||||
"""Test Exponential Moving Average."""
|
||||
ema = indicators.ema(sample_data['close'], period=20)
|
||||
assert len(ema) == len(sample_data)
|
||||
|
||||
def test_rsi(self, indicators, sample_data):
|
||||
"""Test Relative Strength Index."""
|
||||
rsi = indicators.rsi(sample_data['close'], period=14)
|
||||
assert len(rsi) == len(sample_data)
|
||||
# RSI should be between 0 and 100
|
||||
valid_rsi = rsi.dropna()
|
||||
if len(valid_rsi) > 0:
|
||||
assert (valid_rsi >= 0).all()
|
||||
assert (valid_rsi <= 100).all()
|
||||
|
||||
def test_macd(self, indicators, sample_data):
|
||||
"""Test MACD."""
|
||||
macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
|
||||
assert 'macd' in macd_result
|
||||
assert 'signal' in macd_result
|
||||
assert 'histogram' in macd_result
|
||||
|
||||
def test_bollinger_bands(self, indicators, sample_data):
|
||||
"""Test Bollinger Bands."""
|
||||
bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
|
||||
assert 'upper' in bb
|
||||
assert 'middle' in bb
|
||||
assert 'lower' in bb
|
||||
# Upper should be above middle, middle above lower
|
||||
valid_data = bb.dropna()
|
||||
if len(valid_data) > 0:
|
||||
assert (valid_data['upper'] >= valid_data['middle']).all()
|
||||
assert (valid_data['middle'] >= valid_data['lower']).all()
|
||||
|
||||
80
tests/unit/data/test_indicators_divergence.py
Normal file
80
tests/unit/data/test_indicators_divergence.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Tests for divergence detection in indicators."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
|
||||
class TestDivergenceDetection:
|
||||
"""Tests for divergence detection."""
|
||||
|
||||
@pytest.fixture
|
||||
def indicators(self):
|
||||
"""Create indicators instance."""
|
||||
return get_indicators()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample price data with clear trend."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
|
||||
# Create price data with trend
|
||||
prices = [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)]
|
||||
return pd.Series(prices, index=dates)
|
||||
|
||||
def test_detect_divergence_insufficient_data(self, indicators):
|
||||
"""Test divergence detection with insufficient data."""
|
||||
prices = pd.Series([100, 101, 102])
|
||||
indicator = pd.Series([50, 51, 52])
|
||||
|
||||
result = indicators.detect_divergence(prices, indicator, lookback=20)
|
||||
|
||||
assert result['type'] is None
|
||||
assert result['confidence'] == 0.0
|
||||
|
||||
def test_detect_divergence_structure(self, indicators, sample_data):
|
||||
"""Test divergence detection returns correct structure."""
|
||||
# Create indicator data
|
||||
indicator = pd.Series([50 + i * 0.1 for i in range(100)], index=sample_data.index)
|
||||
|
||||
result = indicators.detect_divergence(sample_data, indicator, lookback=20)
|
||||
|
||||
# Check structure
|
||||
assert 'type' in result
|
||||
assert 'confidence' in result
|
||||
assert 'price_swing_high' in result
|
||||
assert 'price_swing_low' in result
|
||||
assert 'indicator_swing_high' in result
|
||||
assert 'indicator_swing_low' in result
|
||||
|
||||
# Type should be None, 'bullish', or 'bearish'
|
||||
assert result['type'] in [None, 'bullish', 'bearish']
|
||||
|
||||
# Confidence should be 0.0 to 1.0
|
||||
assert 0.0 <= result['confidence'] <= 1.0
|
||||
|
||||
def test_detect_divergence_with_trend(self, indicators):
|
||||
"""Test divergence detection with clear trend data."""
|
||||
# Create price making lower lows
|
||||
prices = pd.Series([100, 95, 90, 85, 80])
|
||||
|
||||
# Create indicator making higher lows (bullish divergence)
|
||||
indicator = pd.Series([30, 32, 34, 36, 38])
|
||||
|
||||
# Need more data for lookback
|
||||
prices_long = pd.concat([pd.Series([110] * 30), prices])
|
||||
indicator_long = pd.concat([pd.Series([25] * 30), indicator])
|
||||
|
||||
result = indicators.detect_divergence(
|
||||
prices_long,
|
||||
indicator_long,
|
||||
lookback=5,
|
||||
min_swings=2
|
||||
)
|
||||
|
||||
# Should detect bullish divergence (price down, indicator up)
|
||||
# Note: This may not always detect due to swing detection logic
|
||||
assert result is not None
|
||||
assert 'type' in result
|
||||
assert 'confidence' in result
|
||||
|
||||
135
tests/unit/data/test_pricing_service.py
Normal file
135
tests/unit/data/test_pricing_service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Unit tests for pricing service."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
|
||||
from src.data.pricing_service import PricingService, get_pricing_service
|
||||
from src.data.providers.base_provider import BasePricingProvider
|
||||
|
||||
|
||||
class MockProvider(BasePricingProvider):
|
||||
"""Mock provider for testing."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "MockProvider"
|
||||
|
||||
@property
|
||||
def supports_websocket(self) -> bool:
|
||||
return False
|
||||
|
||||
def connect(self) -> bool:
|
||||
self._connected = True
|
||||
return True
|
||||
|
||||
def disconnect(self):
|
||||
self._connected = False
|
||||
|
||||
def get_ticker(self, symbol: str):
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'bid': Decimal('50000'),
|
||||
'ask': Decimal('50001'),
|
||||
'last': Decimal('50000.5'),
|
||||
'high': Decimal('51000'),
|
||||
'low': Decimal('49000'),
|
||||
'volume': Decimal('1000000'),
|
||||
'timestamp': 1609459200000,
|
||||
}
|
||||
|
||||
def get_ohlcv(self, symbol: str, timeframe: str = '1h', since=None, limit: int = 100):
|
||||
return [[1609459200000, 50000, 51000, 49000, 50000, 1000]]
|
||||
|
||||
def subscribe_ticker(self, symbol: str, callback) -> bool:
|
||||
if symbol not in self._subscribers:
|
||||
self._subscribers[symbol] = []
|
||||
self._subscribers[symbol].append(callback)
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock configuration."""
|
||||
config = Mock()
|
||||
config.get = Mock(side_effect=lambda key, default=None: {
|
||||
"data_providers.primary": [
|
||||
{"name": "mock", "enabled": True, "priority": 1}
|
||||
],
|
||||
"data_providers.fallback": {"enabled": True, "api_key": ""},
|
||||
"data_providers.caching.ticker_ttl": 2,
|
||||
"data_providers.caching.ohlcv_ttl": 60,
|
||||
"data_providers.caching.max_cache_size": 1000,
|
||||
}.get(key, default))
|
||||
return config
|
||||
|
||||
|
||||
class TestPricingService:
|
||||
"""Tests for PricingService."""
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
@patch('src.data.providers.coingecko_provider.CoinGeckoProvider')
|
||||
def test_init(self, mock_coingecko, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test service initialization."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_ccxt_instance = MockProvider()
|
||||
mock_ccxt.return_value = mock_ccxt_instance
|
||||
mock_coingecko_instance = MockProvider()
|
||||
mock_coingecko.return_value = mock_coingecko_instance
|
||||
|
||||
service = PricingService()
|
||||
|
||||
assert service.cache is not None
|
||||
assert service.health_monitor is not None
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_get_ticker(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test getting ticker data."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
ticker = service.get_ticker("BTC/USD")
|
||||
|
||||
assert ticker['symbol'] == "BTC/USD"
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_get_ohlcv(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
ohlcv = service.get_ohlcv("BTC/USD", "1h", limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
assert len(ohlcv[0]) == 6
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_subscribe_ticker(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test subscribing to ticker updates."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
callback = Mock()
|
||||
result = service.subscribe_ticker("BTC/USD", callback)
|
||||
|
||||
assert result is True
|
||||
118
tests/unit/data/test_redis_cache.py
Normal file
118
tests/unit/data/test_redis_cache.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tests for Redis cache."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
|
||||
class TestRedisCache:
|
||||
"""Tests for RedisCache class."""
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ticker_cache_hit(self, mock_get_client):
|
||||
"""Test getting cached ticker data."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = '{"price": 45000.0, "symbol": "BTC/USD"}'
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ticker("BTC/USD")
|
||||
|
||||
assert result is not None
|
||||
assert result["price"] == 45000.0
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ticker_cache_miss(self, mock_get_client):
|
||||
"""Test ticker cache miss."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = None
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ticker("BTC/USD")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_ticker(self, mock_get_client):
|
||||
"""Test setting ticker cache."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.set_ticker("BTC/USD", {"price": 45000.0})
|
||||
|
||||
assert result is True
|
||||
mock_client.setex.assert_called_once()
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ohlcv(self, mock_get_client):
|
||||
"""Test getting cached OHLCV data."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = '[[1700000000, 45000, 45500, 44500, 45200, 1000]]'
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ohlcv("BTC/USD", "1h", 100)
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == 1700000000
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_ohlcv(self, mock_get_client):
|
||||
"""Test setting OHLCV cache."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
ohlcv_data = [[1700000000, 45000, 45500, 44500, 45200, 1000]]
|
||||
result = await cache.set_ohlcv("BTC/USD", "1h", ohlcv_data)
|
||||
|
||||
assert result is True
|
||||
mock_client.setex.assert_called_once()
|
||||
|
||||
|
||||
class TestGetRedisCache:
|
||||
"""Tests for get_redis_cache singleton."""
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
def test_returns_singleton(self, mock_get_client):
|
||||
"""Test that get_redis_cache returns same instance."""
|
||||
mock_get_client.return_value = Mock()
|
||||
|
||||
# Reset the global
|
||||
import src.data.redis_cache as cache_module
|
||||
cache_module._redis_cache = None
|
||||
|
||||
from src.data.redis_cache import get_redis_cache
|
||||
|
||||
cache1 = get_redis_cache()
|
||||
cache2 = get_redis_cache()
|
||||
|
||||
assert cache1 is cache2
|
||||
Reference in New Issue
Block a user