Files
crypto_trader/tests/unit/data/test_indicators.py

69 lines
2.5 KiB
Python
Raw Permalink Normal View History

"""Tests for technical indicators."""
import pytest
import pandas as pd
import numpy as np
from src.data.indicators import get_indicators, TechnicalIndicators
class TestTechnicalIndicators:
    """Tests for TechnicalIndicators.

    Every test runs a synthetic, deterministic OHLCV frame through the object
    returned by ``get_indicators()``. It checks the output length plus basic
    mathematical invariants of the indicator. Range checks assert that valid
    (non-NaN) output exists, so an all-NaN result fails instead of passing
    vacuously.
    """

    @pytest.fixture
    def indicators(self):
        """Create indicators instance via the ``get_indicators`` factory."""
        return get_indicators()

    @pytest.fixture
    def sample_data(self):
        """Create 100 hourly bars of sample OHLCV price data.

        Prices follow a gentle uptrend (+0.1 per bar). The close carries
        Gaussian noise (std 0.5), clipped into the bar's [low, high] range.
        A fixed RNG seed keeps the data, and so every test in this class,
        reproducible.
        """
        n_bars = 100
        # Seeded generator: unseeded global noise made failures non-reproducible.
        rng = np.random.default_rng(42)
        base = 100 + np.arange(n_bars) * 0.1
        high = base + 1
        low = base - 1
        # Keep close inside [low, high] so every bar is a valid OHLC bar.
        close = np.clip(base + rng.normal(0.0, 0.5, n_bars), low, high)
        return pd.DataFrame(
            {
                'close': close,
                'high': high,
                'low': low,
                'open': base,
                'volume': np.full(n_bars, 1000.0),
            },
            # Lowercase 'h': the uppercase 'H' alias is deprecated since pandas 2.2.
            index=pd.date_range(start='2025-01-01', periods=n_bars, freq='h'),
        )

    def test_sma(self, indicators, sample_data):
        """Test Simple Moving Average."""
        period = 20
        close = sample_data['close']
        sma = indicators.sma(close, period=period)
        assert len(sma) == len(sample_data)
        assert not sma.isna().all()  # Should have some valid values
        # By definition, the latest SMA is the mean of the last `period` closes.
        assert sma.iloc[-1] == pytest.approx(close.iloc[-period:].mean())

    def test_ema(self, indicators, sample_data):
        """Test Exponential Moving Average."""
        close = sample_data['close']
        ema = indicators.ema(close, period=20)
        assert len(ema) == len(sample_data)
        valid_ema = ema.dropna()
        assert len(valid_ema) > 0
        # An EMA is a weighted average of past closes, so it cannot leave the
        # range spanned by the input (small tolerance for float rounding).
        tol = 1e-9
        assert (valid_ema >= close.min() - tol).all()
        assert (valid_ema <= close.max() + tol).all()

    def test_rsi(self, indicators, sample_data):
        """Test Relative Strength Index."""
        rsi = indicators.rsi(sample_data['close'], period=14)
        assert len(rsi) == len(sample_data)
        valid_rsi = rsi.dropna()
        # An all-NaN RSI must fail here rather than silently skip the bounds check.
        assert len(valid_rsi) > 0
        # RSI should be between 0 and 100
        assert valid_rsi.between(0, 100).all()

    def test_macd(self, indicators, sample_data):
        """Test MACD."""
        macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
        assert 'macd' in macd_result
        assert 'signal' in macd_result
        assert 'histogram' in macd_result
        # np.asarray works whether the result is a DataFrame or a dict of columns.
        macd_line = np.asarray(macd_result['macd'], dtype=float)
        signal_line = np.asarray(macd_result['signal'], dtype=float)
        histogram = np.asarray(macd_result['histogram'], dtype=float)
        assert len(histogram) == len(sample_data)
        valid = ~np.isnan(histogram)
        assert valid.any()
        # The histogram is by definition the MACD line minus the signal line.
        np.testing.assert_allclose(histogram[valid], (macd_line - signal_line)[valid])

    def test_bollinger_bands(self, indicators, sample_data):
        """Test Bollinger Bands."""
        bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
        assert 'upper' in bb
        assert 'middle' in bb
        assert 'lower' in bb
        # pd.DataFrame() accepts a DataFrame or a dict of equal-length columns,
        # so the ordering checks work whichever container the indicator returns.
        valid_data = pd.DataFrame(bb)[['upper', 'middle', 'lower']].dropna()
        # An all-NaN result must fail here rather than skip the ordering checks.
        assert len(valid_data) > 0
        # Upper should be above middle, middle above lower
        assert (valid_data['upper'] >= valid_data['middle']).all()
        assert (valid_data['middle'] >= valid_data['lower']).all()