"""Tests for technical indicators."""

import pytest
import pandas as pd
import numpy as np

from src.data.indicators import get_indicators, TechnicalIndicators


class TestTechnicalIndicators:
    """Tests for TechnicalIndicators."""

    # Number of bars in the sample fixture; comfortably longer than every
    # indicator period exercised below (max is MACD slow=26 + signal=9).
    N_BARS = 100

    @pytest.fixture
    def indicators(self):
        """Create indicators instance."""
        return get_indicators()

    @pytest.fixture
    def sample_data(self):
        """Create deterministic sample OHLCV price data.

        Prices follow a gentle upward trend (+0.1 per bar). ``close`` adds
        Gaussian noise (sigma 0.5) drawn from a fixed-seed generator, so every
        run sees identical data, and is clipped into ``[low, high]`` so the
        frame is valid OHLC data.
        """
        n = self.N_BARS
        trend = 100 + 0.1 * np.arange(n)
        high = trend + 1
        low = trend - 1
        # Fixed seed: the previous unseeded np.random made the fixture (and any
        # failure depending on its values) non-reproducible between runs.
        rng = np.random.default_rng(42)
        close = np.clip(trend + rng.normal(0.0, 0.5, n), low, high)
        return pd.DataFrame({
            'close': close,
            'high': high,
            'low': low,
            'open': trend,
            'volume': np.full(n, 1000.0),
        })

    def test_sma(self, indicators, sample_data):
        """Test Simple Moving Average."""
        close = sample_data['close']
        sma = indicators.sma(close, period=20)
        assert len(sma) == len(sample_data)
        assert not sma.isna().all()  # Should have some valid values
        # By definition, the final SMA is the mean of the last 20 closes.
        assert sma.iloc[-1] == pytest.approx(close.iloc[-20:].mean())

    def test_ema(self, indicators, sample_data):
        """Test Exponential Moving Average."""
        ema = indicators.ema(sample_data['close'], period=20)
        assert len(ema) == len(sample_data)
        assert not ema.isna().all()  # Should have some valid values

    def test_rsi(self, indicators, sample_data):
        """Test Relative Strength Index."""
        rsi = indicators.rsi(sample_data['close'], period=14)
        assert len(rsi) == len(sample_data)
        valid_rsi = rsi.dropna()
        # 100 bars is ample for a 14-period RSI: an all-NaN result is a
        # failure, not something to silently skip the range check for.
        assert len(valid_rsi) > 0
        # RSI should be between 0 and 100
        assert (valid_rsi >= 0).all()
        assert (valid_rsi <= 100).all()

    def test_macd(self, indicators, sample_data):
        """Test MACD."""
        macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
        assert 'macd' in macd_result
        assert 'signal' in macd_result
        assert 'histogram' in macd_result
        # The histogram is defined as MACD line minus signal line
        # (assert_allclose treats NaN positions as equal by default).
        np.testing.assert_allclose(
            np.asarray(macd_result['histogram'], dtype=float),
            np.asarray(macd_result['macd'] - macd_result['signal'], dtype=float),
        )

    def test_bollinger_bands(self, indicators, sample_data):
        """Test Bollinger Bands."""
        bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
        assert 'upper' in bb
        assert 'middle' in bb
        assert 'lower' in bb
        valid_data = bb.dropna()
        # A 20-period band over 100 bars must yield valid rows; otherwise the
        # ordering checks below would pass vacuously.
        assert len(valid_data) > 0
        # Upper should be above middle, middle above lower
        assert (valid_data['upper'] >= valid_data['middle']).all()
        assert (valid_data['middle'] >= valid_data['lower']).all()