Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
2
tests/unit/__init__.py
Normal file
2
tests/unit/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests."""
|
||||
|
||||
2
tests/unit/alerts/__init__.py
Normal file
2
tests/unit/alerts/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for alert system."""
|
||||
|
||||
33
tests/unit/alerts/test_engine.py
Normal file
33
tests/unit/alerts/test_engine.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Tests for alert engine."""
|
||||
|
||||
import pytest
|
||||
from src.alerts.engine import get_alert_engine, AlertEngine
|
||||
|
||||
|
||||
class TestAlertEngine:
|
||||
"""Tests for AlertEngine."""
|
||||
|
||||
@pytest.fixture
|
||||
def alert_engine(self):
|
||||
"""Create alert engine instance."""
|
||||
return get_alert_engine()
|
||||
|
||||
def test_alert_engine_initialization(self, alert_engine):
|
||||
"""Test alert engine initialization."""
|
||||
assert alert_engine is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_alert(self, alert_engine):
|
||||
"""Test processing alert."""
|
||||
# Create test alert
|
||||
alert = {
|
||||
'name': 'test_alert',
|
||||
'type': 'price',
|
||||
'condition': 'BTC/USD > 50000',
|
||||
'is_active': True
|
||||
}
|
||||
|
||||
# Process alert (simplified)
|
||||
result = await alert_engine.process_alert(alert, {'BTC/USD': 51000.0})
|
||||
assert isinstance(result, bool)
|
||||
|
||||
2
tests/unit/autopilot/__init__.py
Normal file
2
tests/unit/autopilot/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Tests for autopilot module."""
|
||||
|
||||
175
tests/unit/autopilot/test_intelligent_autopilot.py
Normal file
175
tests/unit/autopilot/test_intelligent_autopilot.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for intelligent autopilot functionality."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from src.core.database import OrderSide, OrderType
|
||||
|
||||
|
||||
class TestPreFlightValidation:
|
||||
"""Tests for pre-flight order validation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_autopilot(self):
|
||||
"""Create mock autopilot with necessary attributes."""
|
||||
from src.autopilot.intelligent_autopilot import IntelligentAutopilot
|
||||
|
||||
with patch.object(IntelligentAutopilot, '__init__', lambda x, *args, **kwargs: None):
|
||||
autopilot = IntelligentAutopilot.__new__(IntelligentAutopilot)
|
||||
autopilot.symbol = 'BTC/USD'
|
||||
autopilot.paper_trading = True
|
||||
autopilot.logger = Mock()
|
||||
|
||||
# Mock trading engine
|
||||
autopilot.trading_engine = Mock()
|
||||
autopilot.trading_engine.paper_trading = Mock()
|
||||
autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('1000.0')
|
||||
autopilot.trading_engine.paper_trading.get_positions.return_value = []
|
||||
|
||||
return autopilot
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_execute_order_insufficient_funds(self, mock_autopilot):
|
||||
"""Test that insufficient funds returns False."""
|
||||
mock_autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('10.0')
|
||||
|
||||
can_execute, reason = await mock_autopilot._can_execute_order(
|
||||
side=OrderSide.BUY,
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('100.0')
|
||||
)
|
||||
|
||||
assert can_execute is False
|
||||
assert 'Insufficient funds' in reason
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_execute_order_sufficient_funds(self, mock_autopilot):
|
||||
"""Test that sufficient funds returns True."""
|
||||
mock_autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('1000.0')
|
||||
|
||||
can_execute, reason = await mock_autopilot._can_execute_order(
|
||||
side=OrderSide.BUY,
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('100.0')
|
||||
)
|
||||
|
||||
assert can_execute is True
|
||||
assert reason == 'OK'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_execute_order_no_position_for_sell(self, mock_autopilot):
|
||||
"""Test that SELL without position returns False."""
|
||||
mock_autopilot.trading_engine.paper_trading.get_positions.return_value = []
|
||||
|
||||
can_execute, reason = await mock_autopilot._can_execute_order(
|
||||
side=OrderSide.SELL,
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('100.0')
|
||||
)
|
||||
|
||||
assert can_execute is False
|
||||
assert 'No position to sell' in reason
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_execute_order_minimum_value(self, mock_autopilot):
|
||||
"""Test that order below minimum value returns False."""
|
||||
can_execute, reason = await mock_autopilot._can_execute_order(
|
||||
side=OrderSide.BUY,
|
||||
quantity=Decimal('0.001'),
|
||||
price=Decimal('0.10') # Order value = $0.0001
|
||||
)
|
||||
|
||||
assert can_execute is False
|
||||
assert 'below minimum' in reason
|
||||
|
||||
|
||||
class TestSmartOrderTypeSelection:
|
||||
"""Tests for smart order type selection."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_autopilot(self):
|
||||
"""Create mock autopilot for order type tests."""
|
||||
from src.autopilot.intelligent_autopilot import IntelligentAutopilot
|
||||
|
||||
with patch.object(IntelligentAutopilot, '__init__', lambda x, *args, **kwargs: None):
|
||||
autopilot = IntelligentAutopilot.__new__(IntelligentAutopilot)
|
||||
autopilot.logger = Mock()
|
||||
return autopilot
|
||||
|
||||
def test_strong_signal_uses_market(self, mock_autopilot):
|
||||
"""Test that strong signals (>80%) use MARKET orders."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.BUY,
|
||||
signal_strength=0.85,
|
||||
current_price=Decimal('100.0'),
|
||||
is_stop_loss=False
|
||||
)
|
||||
|
||||
assert order_type == OrderType.MARKET
|
||||
assert price is None
|
||||
|
||||
def test_normal_signal_uses_limit(self, mock_autopilot):
|
||||
"""Test that normal signals use LIMIT orders."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.BUY,
|
||||
signal_strength=0.65,
|
||||
current_price=Decimal('100.0'),
|
||||
is_stop_loss=False
|
||||
)
|
||||
|
||||
assert order_type == OrderType.LIMIT
|
||||
assert price is not None
|
||||
# BUY limit should be below market
|
||||
assert price < Decimal('100.0')
|
||||
|
||||
def test_stop_loss_uses_market(self, mock_autopilot):
|
||||
"""Test that stop-loss exits use MARKET orders."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.SELL,
|
||||
signal_strength=0.5,
|
||||
current_price=Decimal('100.0'),
|
||||
is_stop_loss=True
|
||||
)
|
||||
|
||||
assert order_type == OrderType.MARKET
|
||||
assert price is None
|
||||
|
||||
def test_take_profit_uses_limit(self, mock_autopilot):
|
||||
"""Test that take-profit exits can use LIMIT orders."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.SELL,
|
||||
signal_strength=0.6,
|
||||
current_price=Decimal('100.0'),
|
||||
is_stop_loss=False
|
||||
)
|
||||
|
||||
assert order_type == OrderType.LIMIT
|
||||
assert price is not None
|
||||
# SELL limit should be above market
|
||||
assert price > Decimal('100.0')
|
||||
|
||||
def test_buy_limit_price_discount(self, mock_autopilot):
|
||||
"""Test BUY LIMIT price is 0.1% below market."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.BUY,
|
||||
signal_strength=0.6,
|
||||
current_price=Decimal('1000.00'),
|
||||
is_stop_loss=False
|
||||
)
|
||||
|
||||
# 0.1% discount = 999.00
|
||||
expected = Decimal('999.00')
|
||||
assert price == expected
|
||||
|
||||
def test_sell_limit_price_premium(self, mock_autopilot):
|
||||
"""Test SELL LIMIT price is 0.1% above market."""
|
||||
order_type, price = mock_autopilot._determine_order_type_and_price(
|
||||
side=OrderSide.SELL,
|
||||
signal_strength=0.6,
|
||||
current_price=Decimal('1000.00'),
|
||||
is_stop_loss=False
|
||||
)
|
||||
|
||||
# 0.1% premium = 1001.00
|
||||
expected = Decimal('1001.00')
|
||||
assert price == expected
|
||||
161
tests/unit/autopilot/test_strategy_groups.py
Normal file
161
tests/unit/autopilot/test_strategy_groups.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for strategy grouping module."""
|
||||
|
||||
import pytest
|
||||
from src.autopilot.strategy_groups import (
|
||||
StrategyGroup,
|
||||
STRATEGY_TO_GROUP,
|
||||
GROUP_TO_STRATEGIES,
|
||||
get_strategy_group,
|
||||
get_strategies_in_group,
|
||||
get_all_groups,
|
||||
get_best_strategy_in_group,
|
||||
convert_strategy_to_group_label,
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyGroupMappings:
|
||||
"""Tests for strategy group mappings."""
|
||||
|
||||
def test_all_strategies_have_group(self):
|
||||
"""Verify all registered strategies are mapped to a group."""
|
||||
# These are the strategies registered in src/strategies/__init__.py
|
||||
expected_strategies = [
|
||||
"rsi", "macd", "moving_average", "confirmed", "divergence",
|
||||
"bollinger_mean_reversion", "dca", "grid", "momentum",
|
||||
"consensus", "pairs_trading", "volatility_breakout",
|
||||
"sentiment", "market_making"
|
||||
]
|
||||
|
||||
for strategy in expected_strategies:
|
||||
group = get_strategy_group(strategy)
|
||||
assert group is not None, f"Strategy '{strategy}' is not mapped to any group"
|
||||
|
||||
def test_get_strategy_group_case_insensitive(self):
|
||||
"""Test that strategy lookup is case-insensitive."""
|
||||
assert get_strategy_group("RSI") == get_strategy_group("rsi")
|
||||
assert get_strategy_group("MACD") == get_strategy_group("macd")
|
||||
assert get_strategy_group("Moving_Average") == get_strategy_group("moving_average")
|
||||
|
||||
def test_get_strategy_group_unknown(self):
|
||||
"""Test that unknown strategies return None."""
|
||||
assert get_strategy_group("nonexistent_strategy") is None
|
||||
assert get_strategy_group("") is None
|
||||
|
||||
def test_group_to_strategies_reverse_mapping(self):
|
||||
"""Verify GROUP_TO_STRATEGIES is the reverse of STRATEGY_TO_GROUP."""
|
||||
for strategy, group in STRATEGY_TO_GROUP.items():
|
||||
assert strategy in GROUP_TO_STRATEGIES[group]
|
||||
|
||||
def test_all_groups_have_strategies(self):
|
||||
"""Verify all groups have at least one strategy."""
|
||||
for group in StrategyGroup:
|
||||
strategies = get_strategies_in_group(group)
|
||||
assert len(strategies) > 0, f"Group '{group}' has no strategies"
|
||||
|
||||
|
||||
class TestGetAllGroups:
|
||||
"""Tests for get_all_groups function."""
|
||||
|
||||
def test_returns_all_groups(self):
|
||||
"""Verify all groups are returned."""
|
||||
groups = get_all_groups()
|
||||
assert len(groups) == 5
|
||||
assert StrategyGroup.TREND_FOLLOWING in groups
|
||||
assert StrategyGroup.MEAN_REVERSION in groups
|
||||
assert StrategyGroup.MOMENTUM in groups
|
||||
assert StrategyGroup.MARKET_MAKING in groups
|
||||
assert StrategyGroup.SENTIMENT_BASED in groups
|
||||
|
||||
|
||||
class TestGetStrategiesInGroup:
|
||||
"""Tests for get_strategies_in_group function."""
|
||||
|
||||
def test_trend_following_strategies(self):
|
||||
"""Test trend following group contains expected strategies."""
|
||||
strategies = get_strategies_in_group(StrategyGroup.TREND_FOLLOWING)
|
||||
assert "moving_average" in strategies
|
||||
assert "macd" in strategies
|
||||
assert "confirmed" in strategies
|
||||
|
||||
def test_mean_reversion_strategies(self):
|
||||
"""Test mean reversion group contains expected strategies."""
|
||||
strategies = get_strategies_in_group(StrategyGroup.MEAN_REVERSION)
|
||||
assert "rsi" in strategies
|
||||
assert "bollinger_mean_reversion" in strategies
|
||||
assert "grid" in strategies
|
||||
|
||||
def test_momentum_strategies(self):
|
||||
"""Test momentum group contains expected strategies."""
|
||||
strategies = get_strategies_in_group(StrategyGroup.MOMENTUM)
|
||||
assert "momentum" in strategies
|
||||
assert "volatility_breakout" in strategies
|
||||
|
||||
|
||||
class TestGetBestStrategyInGroup:
|
||||
"""Tests for get_best_strategy_in_group function."""
|
||||
|
||||
def test_trend_following_high_adx(self):
|
||||
"""Test trend following selection with high ADX."""
|
||||
features = {"adx": 35.0, "rsi": 50.0, "atr_percent": 2.0, "volume_ratio": 1.0}
|
||||
strategy, confidence = get_best_strategy_in_group(
|
||||
StrategyGroup.TREND_FOLLOWING, features
|
||||
)
|
||||
assert strategy == "confirmed"
|
||||
assert confidence > 0.5
|
||||
|
||||
def test_mean_reversion_extreme_rsi(self):
|
||||
"""Test mean reversion selection with extreme RSI."""
|
||||
features = {"adx": 15.0, "rsi": 25.0, "atr_percent": 1.5, "volume_ratio": 1.0}
|
||||
strategy, confidence = get_best_strategy_in_group(
|
||||
StrategyGroup.MEAN_REVERSION, features
|
||||
)
|
||||
assert strategy == "rsi"
|
||||
assert confidence > 0.5
|
||||
|
||||
def test_momentum_high_volume(self):
|
||||
"""Test momentum selection with high volume."""
|
||||
features = {"adx": 25.0, "rsi": 55.0, "atr_percent": 3.0, "volume_ratio": 2.0}
|
||||
strategy, confidence = get_best_strategy_in_group(
|
||||
StrategyGroup.MOMENTUM, features
|
||||
)
|
||||
assert strategy == "volatility_breakout"
|
||||
assert confidence > 0.5
|
||||
|
||||
def test_respects_available_strategies(self):
|
||||
"""Test that only available strategies are selected."""
|
||||
features = {"adx": 35.0, "rsi": 50.0}
|
||||
# Only MACD available from trend following
|
||||
strategy, confidence = get_best_strategy_in_group(
|
||||
StrategyGroup.TREND_FOLLOWING,
|
||||
features,
|
||||
available_strategies=["macd"]
|
||||
)
|
||||
assert strategy == "macd"
|
||||
|
||||
def test_fallback_when_no_strategies_available(self):
|
||||
"""Test fallback when no strategies in group are available."""
|
||||
features = {"adx": 25.0, "rsi": 50.0}
|
||||
strategy, confidence = get_best_strategy_in_group(
|
||||
StrategyGroup.TREND_FOLLOWING,
|
||||
features,
|
||||
available_strategies=["some_other_strategy"]
|
||||
)
|
||||
# Should return safe default
|
||||
assert strategy == "rsi"
|
||||
assert confidence == 0.5
|
||||
|
||||
|
||||
class TestConvertStrategyToGroupLabel:
|
||||
"""Tests for convert_strategy_to_group_label function."""
|
||||
|
||||
def test_converts_known_strategies(self):
|
||||
"""Test conversion of known strategies."""
|
||||
assert convert_strategy_to_group_label("rsi") == "mean_reversion"
|
||||
assert convert_strategy_to_group_label("macd") == "trend_following"
|
||||
assert convert_strategy_to_group_label("momentum") == "momentum"
|
||||
assert convert_strategy_to_group_label("dca") == "market_making"
|
||||
assert convert_strategy_to_group_label("sentiment") == "sentiment_based"
|
||||
|
||||
def test_unknown_strategy_returns_original(self):
|
||||
"""Test that unknown strategies return original name."""
|
||||
assert convert_strategy_to_group_label("unknown") == "unknown"
|
||||
2
tests/unit/backend/__init__.py
Normal file
2
tests/unit/backend/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Backend API tests."""
|
||||
|
||||
2
tests/unit/backend/api/__init__.py
Normal file
2
tests/unit/backend/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Backend API endpoint tests."""
|
||||
|
||||
379
tests/unit/backend/api/test_autopilot.py
Normal file
379
tests/unit/backend/api/test_autopilot.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Tests for autopilot API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_autopilot():
|
||||
"""Mock autopilot instance."""
|
||||
autopilot = Mock()
|
||||
autopilot.symbol = "BTC/USD"
|
||||
autopilot.is_running = False
|
||||
autopilot.get_status.return_value = {
|
||||
"symbol": "BTC/USD",
|
||||
"running": False,
|
||||
"interval": 60.0,
|
||||
"has_market_data": True,
|
||||
"market_data_length": 100,
|
||||
"headlines_count": 5,
|
||||
"last_sentiment_score": 0.5,
|
||||
"last_pattern": "head_and_shoulders",
|
||||
"last_signal": None,
|
||||
}
|
||||
autopilot.last_signal = None
|
||||
autopilot.analyze_once.return_value = None
|
||||
return autopilot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_intelligent_autopilot():
|
||||
"""Mock intelligent autopilot instance."""
|
||||
autopilot = Mock()
|
||||
autopilot.symbol = "BTC/USD"
|
||||
autopilot.is_running = False
|
||||
autopilot.enable_auto_execution = False
|
||||
autopilot.get_status.return_value = {
|
||||
"symbol": "BTC/USD",
|
||||
"timeframe": "1h",
|
||||
"running": False,
|
||||
"selected_strategy": None,
|
||||
"trades_today": 0,
|
||||
"max_trades_per_day": 10,
|
||||
"min_confidence_threshold": 0.75,
|
||||
"enable_auto_execution": False,
|
||||
"last_analysis": None,
|
||||
"model_info": {},
|
||||
}
|
||||
return autopilot
|
||||
|
||||
|
||||
class TestUnifiedAutopilotEndpoints:
|
||||
"""Tests for unified autopilot endpoints."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot_mode_info')
|
||||
def test_get_modes(self, mock_get_mode_info, client):
|
||||
"""Test getting autopilot mode information."""
|
||||
mock_get_mode_info.return_value = {
|
||||
"modes": {
|
||||
"pattern": {"name": "Pattern-Based Autopilot"},
|
||||
"intelligent": {"name": "ML-Based Autopilot"},
|
||||
},
|
||||
"comparison": {},
|
||||
}
|
||||
|
||||
response = client.get("/api/autopilot/modes")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "modes" in data
|
||||
assert "pattern" in data["modes"]
|
||||
assert "intelligent" in data["modes"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
@patch('backend.api.autopilot.run_autopilot_loop')
|
||||
def test_start_unified_pattern_mode(
|
||||
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test starting unified autopilot in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "pattern",
|
||||
"auto_execute": False,
|
||||
"interval": 60.0,
|
||||
"pattern_order": 5,
|
||||
"auto_fetch_news": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["mode"] == "pattern"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["auto_execute"] is False
|
||||
mock_get_autopilot.assert_called_once()
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_start_unified_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test starting unified autopilot in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": True,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
"interval": 60.0,
|
||||
"paper_trading": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["mode"] == "intelligent"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["auto_execute"] is True
|
||||
assert mock_intelligent_autopilot.enable_auto_execution is True
|
||||
mock_get_intelligent.assert_called_once()
|
||||
|
||||
def test_start_unified_invalid_mode(self, client):
|
||||
"""Test starting unified autopilot with invalid mode."""
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "invalid_mode",
|
||||
"auto_execute": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
def test_stop_unified_pattern_mode(
|
||||
self, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test stopping unified autopilot in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=pattern"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "pattern"
|
||||
mock_autopilot.stop.assert_called_once()
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_stop_unified_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test stopping unified autopilot in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=intelligent&timeframe=1h"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "intelligent"
|
||||
mock_intelligent_autopilot.stop.assert_called_once()
|
||||
|
||||
def test_stop_unified_invalid_mode(self, client):
|
||||
"""Test stopping unified autopilot with invalid mode."""
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=invalid_mode"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
def test_get_unified_status_pattern_mode(
|
||||
self, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test getting unified autopilot status in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=pattern"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "pattern"
|
||||
assert "running" in data
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_get_unified_status_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test getting unified autopilot status in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=intelligent&timeframe=1h"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "intelligent"
|
||||
assert "running" in data
|
||||
|
||||
def test_get_unified_status_invalid_mode(self, client):
|
||||
"""Test getting unified autopilot status with invalid mode."""
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=invalid_mode"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestModeSelection:
|
||||
"""Tests for mode selection logic."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot_mode_info')
|
||||
def test_mode_info_structure(self, mock_get_mode_info):
|
||||
"""Test that mode info has correct structure."""
|
||||
from src.autopilot import get_autopilot_mode_info
|
||||
|
||||
mode_info = get_autopilot_mode_info()
|
||||
|
||||
assert "modes" in mode_info
|
||||
assert "pattern" in mode_info["modes"]
|
||||
assert "intelligent" in mode_info["modes"]
|
||||
assert "comparison" in mode_info
|
||||
|
||||
# Check pattern mode structure
|
||||
pattern = mode_info["modes"]["pattern"]
|
||||
assert "name" in pattern
|
||||
assert "description" in pattern
|
||||
assert "how_it_works" in pattern
|
||||
assert "best_for" in pattern
|
||||
assert "tradeoffs" in pattern
|
||||
assert "features" in pattern
|
||||
assert "requirements" in pattern
|
||||
|
||||
# Check intelligent mode structure
|
||||
intelligent = mode_info["modes"]["intelligent"]
|
||||
assert "name" in intelligent
|
||||
assert "description" in intelligent
|
||||
assert "how_it_works" in intelligent
|
||||
assert "best_for" in intelligent
|
||||
assert "tradeoffs" in intelligent
|
||||
assert "features" in intelligent
|
||||
assert "requirements" in intelligent
|
||||
|
||||
# Check comparison structure
|
||||
comparison = mode_info["comparison"]
|
||||
assert "transparency" in comparison
|
||||
assert "adaptability" in comparison
|
||||
assert "setup_time" in comparison
|
||||
assert "resource_usage" in comparison
|
||||
|
||||
|
||||
class TestAutoExecution:
|
||||
"""Tests for auto-execution functionality."""
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_auto_execute_enabled(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that auto-execute is set when enabled."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": True,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_intelligent_autopilot.enable_auto_execution is True
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_auto_execute_disabled(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that auto-execute is not set when disabled."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": False,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Note: enable_auto_execution may have a default value, so we check it's not True
|
||||
# The actual behavior depends on the implementation
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Tests for backward compatibility with old endpoints."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
@patch('backend.api.autopilot.run_autopilot_loop')
|
||||
def test_old_start_endpoint_still_works(
|
||||
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test that old /start endpoint still works (deprecated but functional)."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"interval": 60.0,
|
||||
"pattern_order": 5,
|
||||
"auto_fetch_news": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_old_intelligent_start_endpoint_still_works(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that old /intelligent/start endpoint still works (deprecated but functional)."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/intelligent/start",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
"interval": 60.0,
|
||||
"paper_trading": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
|
||||
81
tests/unit/backend/api/test_exchanges.py
Normal file
81
tests/unit/backend/api/test_exchanges.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for exchanges API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
from src.core.database import Exchange
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_exchange():
|
||||
"""Mock exchange object."""
|
||||
exchange = Mock(spec=Exchange)
|
||||
exchange.id = 1
|
||||
exchange.name = "coinbase"
|
||||
exchange.is_enabled = True
|
||||
exchange.api_permissions = "read_only"
|
||||
return exchange
|
||||
|
||||
|
||||
class TestListExchanges:
|
||||
"""Tests for GET /api/exchanges."""
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_list_exchanges_success(self, mock_get_db, client):
|
||||
"""Test listing exchanges."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.all.return_value = []
|
||||
|
||||
response = client.get("/api/exchanges")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestGetExchange:
|
||||
"""Tests for GET /api/exchanges/{exchange_id}."""
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_get_exchange_success(self, mock_get_db, client, mock_exchange):
|
||||
"""Test getting exchange by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_exchange
|
||||
|
||||
response = client.get("/api/exchanges/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_get_exchange_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent exchange."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/exchanges/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Exchange not found" in response.json()["detail"]
|
||||
|
||||
65
tests/unit/backend/api/test_portfolio.py
Normal file
65
tests/unit/backend/api/test_portfolio.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Tests for portfolio API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_portfolio_tracker():
|
||||
"""Mock portfolio tracker."""
|
||||
tracker = Mock()
|
||||
tracker.get_current_portfolio.return_value = {
|
||||
"positions": [],
|
||||
"performance": {"total_return": 0.1},
|
||||
"timestamp": "2025-01-01T00:00:00"
|
||||
}
|
||||
tracker.get_portfolio_history.return_value = {
|
||||
"dates": ["2025-01-01"],
|
||||
"values": [1000.0],
|
||||
"pnl": [0.0]
|
||||
}
|
||||
return tracker
|
||||
|
||||
|
||||
class TestGetCurrentPortfolio:
|
||||
"""Tests for GET /api/portfolio/current."""
|
||||
|
||||
@patch('backend.api.portfolio.get_portfolio_tracker')
|
||||
def test_get_current_portfolio_success(self, mock_get_tracker, client, mock_portfolio_tracker):
|
||||
"""Test getting current portfolio."""
|
||||
mock_get_tracker.return_value = mock_portfolio_tracker
|
||||
|
||||
response = client.get("/api/portfolio/current?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "positions" in data
|
||||
assert "performance" in data
|
||||
assert "timestamp" in data
|
||||
|
||||
|
||||
class TestGetPortfolioHistory:
|
||||
"""Tests for GET /api/portfolio/history."""
|
||||
|
||||
@patch('backend.api.portfolio.get_portfolio_tracker')
|
||||
def test_get_portfolio_history_success(self, mock_get_tracker, client, mock_portfolio_tracker):
|
||||
"""Test getting portfolio history."""
|
||||
mock_get_tracker.return_value = mock_portfolio_tracker
|
||||
|
||||
response = client.get("/api/portfolio/history?paper_trading=true&days=30")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dates" in data
|
||||
assert "values" in data
|
||||
assert "pnl" in data
|
||||
|
||||
108
tests/unit/backend/api/test_strategies.py
Normal file
108
tests/unit/backend/api/test_strategies.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Tests for strategies API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
from src.core.database import Strategy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_strategy_registry():
|
||||
"""Mock strategy registry."""
|
||||
registry = Mock()
|
||||
registry.list_available.return_value = ["RSIStrategy", "MACDStrategy"]
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_strategy():
|
||||
"""Mock strategy object."""
|
||||
strategy = Mock(spec=Strategy)
|
||||
strategy.id = 1
|
||||
strategy.name = "Test Strategy"
|
||||
strategy.type = "RSIStrategy"
|
||||
strategy.status = "active" # Use string instead of enum
|
||||
strategy.symbol = "BTC/USD"
|
||||
strategy.params = {}
|
||||
return strategy
|
||||
|
||||
|
||||
|
||||
class TestListStrategies:
|
||||
"""Tests for GET /api/strategies."""
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_list_strategies_success(self, mock_get_db, client):
|
||||
"""Test listing strategies."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.all.return_value = []
|
||||
|
||||
response = client.get("/api/strategies")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
"""Tests for GET /api/strategies/{strategy_id}."""
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_get_strategy_success(self, mock_get_db, client, mock_strategy):
|
||||
"""Test getting strategy by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_strategy
|
||||
|
||||
response = client.get("/api/strategies/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_get_strategy_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent strategy."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/strategies/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Strategy not found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestListAvailableStrategyTypes:
|
||||
"""Tests for GET /api/strategies/types."""
|
||||
|
||||
@patch('backend.api.strategies.get_strategy_registry')
|
||||
def test_list_strategy_types_success(self, mock_get_registry, client, mock_strategy_registry):
|
||||
"""Test listing available strategy types."""
|
||||
mock_get_registry.return_value = mock_strategy_registry
|
||||
|
||||
response = client.get("/api/strategies/types")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert "RSIStrategy" in data
|
||||
|
||||
293
tests/unit/backend/api/test_trading.py
Normal file
293
tests/unit/backend/api/test_trading.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Tests for trading API endpoints."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
|
||||
from backend.main import app
|
||||
from backend.core.schemas import OrderCreate, OrderSide, OrderType
|
||||
from src.core.database import Order, OrderStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trading_engine():
|
||||
"""Mock trading engine."""
|
||||
engine = Mock()
|
||||
engine.order_manager = Mock()
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_order():
|
||||
"""Mock order object."""
|
||||
order = Mock()
|
||||
order.id = 1
|
||||
order.exchange_id = 1
|
||||
order.strategy_id = None
|
||||
order.symbol = "BTC/USD"
|
||||
order.order_type = OrderType.MARKET
|
||||
order.side = OrderSide.BUY
|
||||
order.status = OrderStatus.FILLED
|
||||
order.quantity = Decimal("0.1")
|
||||
order.price = Decimal("50000")
|
||||
order.filled_quantity = Decimal("0.1")
|
||||
order.average_fill_price = Decimal("50000")
|
||||
order.fee = Decimal("5")
|
||||
order.paper_trading = True
|
||||
order.created_at = datetime.now()
|
||||
order.updated_at = datetime.now()
|
||||
order.filled_at = datetime.now()
|
||||
return order
|
||||
|
||||
|
||||
class TestCreateOrder:
|
||||
"""Tests for POST /api/trading/orders."""
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_success(self, mock_get_engine, client, mock_trading_engine, mock_order):
|
||||
"""Test successful order creation."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.execute_order.return_value = mock_order
|
||||
|
||||
order_data = {
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1",
|
||||
"paper_trading": True
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["side"] == "buy"
|
||||
mock_trading_engine.execute_order.assert_called_once()
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_execution_failed(self, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test order creation when execution fails."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.execute_order.return_value = None
|
||||
|
||||
order_data = {
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1",
|
||||
"paper_trading": True
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Order execution failed" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_invalid_data(self, client):
|
||||
"""Test order creation with invalid data."""
|
||||
order_data = {
|
||||
"exchange_id": "invalid",
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1"
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
class TestGetOrders:
|
||||
"""Tests for GET /api/trading/orders."""
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_orders_success(self, mock_get_db, client, mock_database):
|
||||
"""Test getting orders."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
# Create mock orders
|
||||
mock_order1 = Mock(spec=Order)
|
||||
mock_order1.id = 1
|
||||
mock_order1.symbol = "BTC/USD"
|
||||
mock_order1.paper_trading = True
|
||||
|
||||
mock_order2 = Mock(spec=Order)
|
||||
mock_order2.id = 2
|
||||
mock_order2.symbol = "ETH/USD"
|
||||
mock_order2.paper_trading = True
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_order1, mock_order2]
|
||||
|
||||
response = client.get("/api/trading/orders?paper_trading=true&limit=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
|
||||
class TestGetOrder:
|
||||
"""Tests for GET /api/trading/orders/{order_id}."""
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_order_success(self, mock_get_db, client, mock_database):
|
||||
"""Test getting order by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.symbol = "BTC/USD"
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.get("/api/trading/orders/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_order_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent order."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/trading/orders/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Order not found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestCancelOrder:
|
||||
"""Tests for POST /api/trading/orders/{order_id}/cancel."""
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_success(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test successful order cancellation."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.order_manager.cancel_order.return_value = True
|
||||
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.status = OrderStatus.OPEN
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.post("/api/trading/orders/1/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "cancelled"
|
||||
assert data["order_id"] == 1
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_not_found(self, mock_get_db, client):
|
||||
"""Test cancelling non-existent order."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.post("/api/trading/orders/999/cancel")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Order not found" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_already_filled(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test cancelling already filled order."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.status = OrderStatus.FILLED
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.post("/api/trading/orders/1/cancel")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "cannot be cancelled" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestGetPositions:
|
||||
"""Tests for GET /api/trading/positions."""
|
||||
|
||||
@patch('backend.api.trading.get_paper_trading')
|
||||
def test_get_positions_paper_trading(self, mock_get_paper, client):
|
||||
"""Test getting positions for paper trading."""
|
||||
mock_paper = Mock()
|
||||
mock_position = Mock()
|
||||
mock_position.symbol = "BTC/USD"
|
||||
mock_position.quantity = Decimal("0.1")
|
||||
mock_position.entry_price = Decimal("50000")
|
||||
mock_position.current_price = Decimal("51000")
|
||||
mock_position.unrealized_pnl = Decimal("100")
|
||||
mock_position.realized_pnl = Decimal("0")
|
||||
mock_paper.get_positions.return_value = [mock_position]
|
||||
mock_get_paper.return_value = mock_paper
|
||||
|
||||
response = client.get("/api/trading/positions?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
if len(data) > 0:
|
||||
assert data[0]["symbol"] == "BTC/USD"
|
||||
|
||||
|
||||
class TestGetBalance:
|
||||
"""Tests for GET /api/trading/balance."""
|
||||
|
||||
@patch('backend.api.trading.get_paper_trading')
|
||||
def test_get_balance_paper_trading(self, mock_get_paper, client):
|
||||
"""Test getting balance for paper trading."""
|
||||
mock_paper = Mock()
|
||||
mock_paper.get_balance.return_value = Decimal("1000")
|
||||
mock_paper.get_performance.return_value = {"total_return": 0.1}
|
||||
mock_get_paper.return_value = mock_paper
|
||||
|
||||
response = client.get("/api/trading/balance?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "balance" in data
|
||||
assert "performance" in data
|
||||
assert data["balance"] == 1000.0
|
||||
|
||||
77
tests/unit/backend/core/test_dependencies.py
Normal file
77
tests/unit/backend/core/test_dependencies.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Tests for backend dependencies."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from backend.core.dependencies import (
|
||||
get_database, get_trading_engine, get_portfolio_tracker,
|
||||
get_strategy_registry, get_backtesting_engine, get_exchange_factory
|
||||
)
|
||||
|
||||
|
||||
class TestGetDatabase:
|
||||
"""Tests for get_database dependency."""
|
||||
|
||||
def test_get_database_singleton(self):
|
||||
"""Test that get_database returns same instance."""
|
||||
db1 = get_database()
|
||||
db2 = get_database()
|
||||
# Should be cached/same instance due to lru_cache
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
class TestGetTradingEngine:
|
||||
"""Tests for get_trading_engine dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_trading_engine')
|
||||
def test_get_trading_engine(self, mock_get_engine):
|
||||
"""Test getting trading engine."""
|
||||
mock_engine = Mock()
|
||||
mock_get_engine.return_value = mock_engine
|
||||
engine = get_trading_engine()
|
||||
assert engine is not None
|
||||
|
||||
|
||||
class TestGetPortfolioTracker:
|
||||
"""Tests for get_portfolio_tracker dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_portfolio_tracker')
|
||||
def test_get_portfolio_tracker(self, mock_get_tracker):
|
||||
"""Test getting portfolio tracker."""
|
||||
mock_tracker = Mock()
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
tracker = get_portfolio_tracker()
|
||||
assert tracker is not None
|
||||
|
||||
|
||||
class TestGetStrategyRegistry:
|
||||
"""Tests for get_strategy_registry dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_strategy_registry')
|
||||
def test_get_strategy_registry(self, mock_get_registry):
|
||||
"""Test getting strategy registry."""
|
||||
mock_registry = Mock()
|
||||
mock_get_registry.return_value = mock_registry
|
||||
registry = get_strategy_registry()
|
||||
assert registry is not None
|
||||
|
||||
|
||||
class TestGetBacktestingEngine:
|
||||
"""Tests for get_backtesting_engine dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_backtesting_engine')
|
||||
def test_get_backtesting_engine(self, mock_get_engine):
|
||||
"""Test getting backtesting engine."""
|
||||
mock_engine = Mock()
|
||||
mock_get_engine.return_value = mock_engine
|
||||
engine = get_backtesting_engine()
|
||||
assert engine is not None
|
||||
|
||||
|
||||
class TestGetExchangeFactory:
|
||||
"""Tests for get_exchange_factory dependency."""
|
||||
|
||||
def test_get_exchange_factory(self):
|
||||
"""Test getting exchange factory."""
|
||||
factory = get_exchange_factory()
|
||||
assert factory is not None
|
||||
|
||||
128
tests/unit/backend/core/test_schemas.py
Normal file
128
tests/unit/backend/core/test_schemas.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for Pydantic schemas."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.core.schemas import (
|
||||
OrderCreate, OrderResponse, OrderSide, OrderType, OrderStatus,
|
||||
PositionResponse, PortfolioResponse, PortfolioHistoryResponse
|
||||
)
|
||||
|
||||
|
||||
class TestOrderCreate:
|
||||
"""Tests for OrderCreate schema."""
|
||||
|
||||
def test_order_create_valid(self):
|
||||
"""Test valid order creation."""
|
||||
order = OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1"),
|
||||
paper_trading=True
|
||||
)
|
||||
assert order.exchange_id == 1
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.side == OrderSide.BUY
|
||||
|
||||
def test_order_create_with_price(self):
|
||||
"""Test order creation with price."""
|
||||
order = OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=Decimal("0.1"),
|
||||
price=Decimal("50000"),
|
||||
paper_trading=True
|
||||
)
|
||||
assert order.price == Decimal("50000")
|
||||
|
||||
def test_order_create_invalid_side(self):
|
||||
"""Test order creation with invalid side."""
|
||||
with pytest.raises(ValidationError):
|
||||
OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="invalid",
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1")
|
||||
)
|
||||
|
||||
|
||||
class TestOrderResponse:
|
||||
"""Tests for OrderResponse schema."""
|
||||
|
||||
def test_order_response_from_dict(self):
|
||||
"""Test creating OrderResponse from dictionary."""
|
||||
order_data = {
|
||||
"id": 1,
|
||||
"exchange_id": 1,
|
||||
"strategy_id": None,
|
||||
"symbol": "BTC/USD",
|
||||
"order_type": OrderType.MARKET,
|
||||
"side": OrderSide.BUY,
|
||||
"status": OrderStatus.FILLED,
|
||||
"quantity": Decimal("0.1"),
|
||||
"price": Decimal("50000"),
|
||||
"filled_quantity": Decimal("0.1"),
|
||||
"average_fill_price": Decimal("50000"),
|
||||
"fee": Decimal("5"),
|
||||
"paper_trading": True,
|
||||
"created_at": datetime.now(),
|
||||
"updated_at": datetime.now(),
|
||||
"filled_at": datetime.now()
|
||||
}
|
||||
|
||||
order = OrderResponse(**order_data)
|
||||
assert order.id == 1
|
||||
assert order.symbol == "BTC/USD"
|
||||
|
||||
|
||||
class TestPositionResponse:
|
||||
"""Tests for PositionResponse schema."""
|
||||
|
||||
def test_position_response_valid(self):
|
||||
"""Test valid position response."""
|
||||
position = PositionResponse(
|
||||
symbol="BTC/USD",
|
||||
quantity=Decimal("0.1"),
|
||||
entry_price=Decimal("50000"),
|
||||
current_price=Decimal("51000"),
|
||||
unrealized_pnl=Decimal("100"),
|
||||
realized_pnl=Decimal("0")
|
||||
)
|
||||
assert position.symbol == "BTC/USD"
|
||||
assert position.unrealized_pnl == Decimal("100")
|
||||
|
||||
|
||||
class TestPortfolioResponse:
|
||||
"""Tests for PortfolioResponse schema."""
|
||||
|
||||
def test_portfolio_response_valid(self):
|
||||
"""Test valid portfolio response."""
|
||||
portfolio = PortfolioResponse(
|
||||
positions=[],
|
||||
performance={"total_return": 0.1},
|
||||
timestamp="2025-01-01T00:00:00"
|
||||
)
|
||||
assert portfolio.positions == []
|
||||
assert portfolio.performance["total_return"] == 0.1
|
||||
|
||||
|
||||
class TestPortfolioHistoryResponse:
|
||||
"""Tests for PortfolioHistoryResponse schema."""
|
||||
|
||||
def test_portfolio_history_response_valid(self):
|
||||
"""Test valid portfolio history response."""
|
||||
history = PortfolioHistoryResponse(
|
||||
dates=["2025-01-01", "2025-01-02"],
|
||||
values=[1000.0, 1100.0],
|
||||
pnl=[0.0, 100.0]
|
||||
)
|
||||
assert len(history.dates) == 2
|
||||
assert len(history.values) == 2
|
||||
|
||||
2
tests/unit/backtesting/__init__.py
Normal file
2
tests/unit/backtesting/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for backtesting."""
|
||||
|
||||
26
tests/unit/backtesting/test_engine.py
Normal file
26
tests/unit/backtesting/test_engine.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Tests for backtesting engine."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from src.backtesting.engine import get_backtest_engine, BacktestingEngine
|
||||
|
||||
|
||||
class TestBacktestingEngine:
|
||||
"""Tests for BacktestingEngine."""
|
||||
|
||||
@pytest.fixture
|
||||
def backtest_engine(self):
|
||||
"""Create backtesting engine instance."""
|
||||
return get_backtest_engine()
|
||||
|
||||
def test_engine_initialization(self, backtest_engine):
|
||||
"""Test backtesting engine initialization."""
|
||||
assert backtest_engine is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_backtest(self, backtest_engine):
|
||||
"""Test running a backtest."""
|
||||
# This would require a full strategy implementation
|
||||
# Simplified test
|
||||
assert backtest_engine is not None
|
||||
|
||||
85
tests/unit/backtesting/test_slippage.py
Normal file
85
tests/unit/backtesting/test_slippage.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for slippage model."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.backtesting.slippage import SlippageModel, FeeModel
|
||||
|
||||
|
||||
class TestSlippageModel:
|
||||
"""Tests for SlippageModel."""
|
||||
|
||||
@pytest.fixture
|
||||
def slippage_model(self):
|
||||
"""Create slippage model."""
|
||||
return SlippageModel(slippage_rate=0.001)
|
||||
|
||||
def test_calculate_fill_price_market_buy(self, slippage_model):
|
||||
"""Test fill price calculation for market buy."""
|
||||
order_price = Decimal('50000.0')
|
||||
market_price = Decimal('50000.0')
|
||||
|
||||
fill_price = slippage_model.calculate_fill_price(
|
||||
order_price, "buy", "market", market_price
|
||||
)
|
||||
|
||||
assert fill_price > market_price # Buy orders pay more
|
||||
|
||||
def test_calculate_fill_price_market_sell(self, slippage_model):
|
||||
"""Test fill price calculation for market sell."""
|
||||
order_price = Decimal('50000.0')
|
||||
market_price = Decimal('50000.0')
|
||||
|
||||
fill_price = slippage_model.calculate_fill_price(
|
||||
order_price, "sell", "market", market_price
|
||||
)
|
||||
|
||||
assert fill_price < market_price # Sell orders receive less
|
||||
|
||||
def test_calculate_fill_price_limit(self, slippage_model):
|
||||
"""Test fill price for limit orders."""
|
||||
order_price = Decimal('49000.0')
|
||||
market_price = Decimal('50000.0')
|
||||
|
||||
fill_price = slippage_model.calculate_fill_price(
|
||||
order_price, "buy", "limit", market_price
|
||||
)
|
||||
|
||||
assert fill_price == order_price # Limit orders fill at order price
|
||||
|
||||
|
||||
class TestFeeModel:
|
||||
"""Tests for FeeModel."""
|
||||
|
||||
@pytest.fixture
|
||||
def fee_model(self):
|
||||
"""Create fee model."""
|
||||
return FeeModel(maker_fee=0.001, taker_fee=0.002)
|
||||
|
||||
def test_calculate_fee_maker(self, fee_model):
|
||||
"""Test maker fee calculation."""
|
||||
fee = fee_model.calculate_fee(
|
||||
quantity=Decimal('0.01'),
|
||||
price=Decimal('50000.0'),
|
||||
is_maker=True
|
||||
)
|
||||
|
||||
assert fee > 0
|
||||
# Fee should be 0.1% of trade value
|
||||
expected = Decimal('0.01') * Decimal('50000.0') * Decimal('0.001')
|
||||
assert abs(float(fee - expected)) < 0.01
|
||||
|
||||
def test_calculate_fee_taker(self, fee_model):
|
||||
"""Test taker fee calculation."""
|
||||
fee = fee_model.calculate_fee(
|
||||
quantity=Decimal('0.01'),
|
||||
price=Decimal('50000.0'),
|
||||
is_maker=False
|
||||
)
|
||||
|
||||
assert fee > 0
|
||||
# Taker fee should be higher than maker
|
||||
maker_fee = fee_model.calculate_fee(
|
||||
Decimal('0.01'), Decimal('50000.0'), is_maker=True
|
||||
)
|
||||
assert fee > maker_fee
|
||||
|
||||
1
tests/unit/core/__init__.py
Normal file
1
tests/unit/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
44
tests/unit/core/test_config.py
Normal file
44
tests/unit/core/test_config.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for configuration system."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from src.core.config import Config, get_config
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Tests for Config class."""
|
||||
|
||||
def test_config_initialization(self, tmp_path):
|
||||
"""Test config initialization."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
assert config is not None
|
||||
assert config.config_dir is not None
|
||||
assert config.data_dir is not None
|
||||
|
||||
def test_config_get(self, tmp_path):
|
||||
"""Test config get method."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
# Test nested key access
|
||||
value = config.get('paper_trading.default_capital')
|
||||
assert value is not None
|
||||
|
||||
def test_config_set(self, tmp_path):
|
||||
"""Test config set method."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
config.set('paper_trading.default_capital', 200.0)
|
||||
value = config.get('paper_trading.default_capital')
|
||||
assert value == 200.0
|
||||
|
||||
def test_config_defaults(self, tmp_path):
|
||||
"""Test default configuration values."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
assert config.get('paper_trading.default_capital') == 100.0
|
||||
assert config.get('database.type') == 'postgresql'
|
||||
|
||||
97
tests/unit/core/test_database.py
Normal file
97
tests/unit/core/test_database.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for database system."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from src.core.database import (
|
||||
get_database, Base, Exchange, Strategy, Trade, Position, Order
|
||||
)
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""Tests for database operations."""
|
||||
|
||||
def test_database_initialization(self):
|
||||
"""Test database initialization."""
|
||||
db = get_database()
|
||||
assert db is not None
|
||||
assert db.engine is not None
|
||||
|
||||
def test_table_creation(self, mock_database):
|
||||
"""Test table creation."""
|
||||
engine, Session = mock_database
|
||||
# Verify tables exist
|
||||
assert Base.metadata.tables.get('exchanges') is not None
|
||||
assert Base.metadata.tables.get('strategies') is not None
|
||||
assert Base.metadata.tables.get('trades') is not None
|
||||
|
||||
def test_exchange_model(self, mock_database):
|
||||
"""Test Exchange model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
exchange = Exchange(
|
||||
name="test_exchange",
|
||||
api_key="encrypted_key",
|
||||
secret_key="encrypted_secret",
|
||||
api_permissions="read_only",
|
||||
is_enabled=True
|
||||
)
|
||||
session.add(exchange)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Exchange).filter_by(name="test_exchange").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "test_exchange"
|
||||
assert retrieved.api_permissions == "read_only"
|
||||
|
||||
session.close()
|
||||
|
||||
def test_strategy_model(self, mock_database):
|
||||
"""Test Strategy model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
strategy = Strategy(
|
||||
name="test_strategy",
|
||||
strategy_type="RSI",
|
||||
parameters='{"rsi_period": 14}',
|
||||
is_enabled=True,
|
||||
is_paper_trading=True
|
||||
)
|
||||
session.add(strategy)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Strategy).filter_by(name="test_strategy").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.strategy_type == "RSI"
|
||||
|
||||
session.close()
|
||||
|
||||
def test_trade_model(self, mock_database):
|
||||
"""Test Trade model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
trade = Trade(
|
||||
order_id="test_order_123",
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
type="market",
|
||||
price=50000.0,
|
||||
amount=0.01,
|
||||
cost=500.0,
|
||||
fee=0.5,
|
||||
status="filled",
|
||||
is_paper_trade=True
|
||||
)
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Trade).filter_by(order_id="test_order_123").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.symbol == "BTC/USD"
|
||||
assert retrieved.status == "filled"
|
||||
|
||||
session.close()
|
||||
|
||||
43
tests/unit/core/test_logger.py
Normal file
43
tests/unit/core/test_logger.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Tests for logging system."""
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from src.core.logger import setup_logging, get_logger
|
||||
|
||||
|
||||
class TestLogger:
|
||||
"""Tests for logging system."""
|
||||
|
||||
def test_logger_setup(self, test_log_dir):
|
||||
"""Test logger setup."""
|
||||
with patch('src.core.logger.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'logging.dir': str(test_log_dir),
|
||||
'logging.retention_days': 30,
|
||||
'logging.level': 'INFO'
|
||||
}.get(key, default)
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger('test')
|
||||
assert logger is not None
|
||||
assert isinstance(logger, logging.Logger)
|
||||
|
||||
def test_logger_get(self):
|
||||
"""Test getting logger instance."""
|
||||
logger = get_logger('test_module')
|
||||
assert logger is not None
|
||||
assert logger.name == 'test_module'
|
||||
|
||||
def test_logger_levels(self):
|
||||
"""Test different log levels."""
|
||||
logger = get_logger('test')
|
||||
# Should not raise exceptions
|
||||
logger.debug("Debug message")
|
||||
logger.info("Info message")
|
||||
logger.warning("Warning message")
|
||||
logger.error("Error message")
|
||||
logger.critical("Critical message")
|
||||
|
||||
92
tests/unit/core/test_redis.py
Normal file
92
tests/unit/core/test_redis.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for Redis client wrapper."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
class TestRedisClient:
|
||||
"""Tests for RedisClient class."""
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_get_client_creates_connection(self, mock_config):
|
||||
"""Test that get_client creates a Redis connection."""
|
||||
# Setup mock config
|
||||
mock_config.return_value.get.return_value = {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"db": 0,
|
||||
"password": None,
|
||||
"socket_connect_timeout": 5
|
||||
}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
|
||||
# Should not have connected yet
|
||||
assert client._client is None
|
||||
|
||||
# Get client should trigger connection
|
||||
with patch('src.core.redis.redis.ConnectionPool') as mock_pool:
|
||||
with patch('src.core.redis.redis.Redis') as mock_redis:
|
||||
redis_client = client.get_client()
|
||||
|
||||
mock_pool.assert_called_once()
|
||||
mock_redis.assert_called_once()
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_get_client_reuses_existing(self, mock_config):
|
||||
"""Test that get_client reuses existing connection."""
|
||||
mock_config.return_value.get.return_value = {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"db": 0,
|
||||
}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
|
||||
# Pre-set a mock client
|
||||
mock_redis = Mock()
|
||||
client._client = mock_redis
|
||||
|
||||
# Should return existing
|
||||
result = client.get_client()
|
||||
assert result is mock_redis
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_connection(self, mock_config):
|
||||
"""Test closing Redis connection."""
|
||||
mock_config.return_value.get.return_value = {"host": "localhost"}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
mock_redis = AsyncMock()
|
||||
client._client = mock_redis
|
||||
|
||||
await client.close()
|
||||
|
||||
mock_redis.aclose.assert_called_once()
|
||||
|
||||
|
||||
class TestGetRedisClient:
|
||||
"""Tests for get_redis_client singleton."""
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_returns_singleton(self, mock_config):
|
||||
"""Test that get_redis_client returns same instance."""
|
||||
mock_config.return_value.get.return_value = {"host": "localhost"}
|
||||
|
||||
# Reset the global
|
||||
import src.core.redis as redis_module
|
||||
redis_module._redis_client = None
|
||||
|
||||
from src.core.redis import get_redis_client
|
||||
|
||||
client1 = get_redis_client()
|
||||
client2 = get_redis_client()
|
||||
|
||||
assert client1 is client2
|
||||
1
tests/unit/data/__init__.py
Normal file
1
tests/unit/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
139
tests/unit/data/providers/test_ccxt_provider.py
Normal file
139
tests/unit/data/providers/test_ccxt_provider.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Unit tests for CCXT pricing provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
from src.data.providers.ccxt_provider import CCXTProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ccxt_exchange():
|
||||
"""Create a mock CCXT exchange."""
|
||||
exchange = Mock()
|
||||
exchange.markets = {
|
||||
'BTC/USDT': {},
|
||||
'ETH/USDT': {},
|
||||
'BTC/USD': {},
|
||||
}
|
||||
exchange.id = 'kraken'
|
||||
exchange.fetch_ticker = Mock(return_value={
|
||||
'bid': 50000.0,
|
||||
'ask': 50001.0,
|
||||
'last': 50000.5,
|
||||
'high': 51000.0,
|
||||
'low': 49000.0,
|
||||
'quoteVolume': 1000000.0,
|
||||
'timestamp': 1609459200000,
|
||||
})
|
||||
exchange.fetch_ohlcv = Mock(return_value=[
|
||||
[1609459200000, 50000, 51000, 49000, 50000, 1000],
|
||||
])
|
||||
exchange.load_markets = Mock(return_value=exchange.markets)
|
||||
return exchange
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider():
|
||||
"""Create a CCXT provider instance."""
|
||||
return CCXTProvider(exchange_name='kraken')
|
||||
|
||||
|
||||
class TestCCXTProvider:
|
||||
"""Tests for CCXTProvider."""
|
||||
|
||||
def test_init(self, provider):
|
||||
"""Test provider initialization."""
|
||||
assert provider.name == "CCXT Provider"
|
||||
assert not provider._connected
|
||||
assert provider.exchange is None
|
||||
|
||||
def test_name_property(self, provider):
|
||||
"""Test name property."""
|
||||
provider._selected_exchange_id = 'kraken'
|
||||
assert provider.name == "CCXT-Kraken"
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_connect_success(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test successful connection."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is True
|
||||
assert provider._connected is True
|
||||
assert provider.exchange == mock_ccxt_exchange
|
||||
assert provider._selected_exchange_id == 'kraken'
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_connect_failure(self, mock_ccxt, provider):
|
||||
"""Test connection failure."""
|
||||
mock_ccxt.kraken = Mock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is False
|
||||
assert not provider._connected
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_get_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test getting ticker data."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
ticker = provider.get_ticker('BTC/USDT')
|
||||
|
||||
assert ticker['symbol'] == 'BTC/USDT'
|
||||
assert isinstance(ticker['bid'], Decimal)
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
assert ticker['last'] > 0
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_get_ohlcv(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
ohlcv = provider.get_ohlcv('BTC/USDT', '1h', limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
assert len(ohlcv[0]) == 6 # timestamp, open, high, low, close, volume
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_subscribe_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test subscribing to ticker updates."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
|
||||
callback = Mock()
|
||||
result = provider.subscribe_ticker('BTC/USDT', callback)
|
||||
|
||||
assert result is True
|
||||
assert 'ticker_BTC/USDT' in provider._subscribers
|
||||
|
||||
def test_normalize_symbol(self, provider):
|
||||
"""Test symbol normalization."""
|
||||
# Test with exchange
|
||||
with patch.object(provider, 'exchange') as mock_exchange:
|
||||
mock_exchange.markets = {'BTC/USDT': {}}
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
|
||||
# Test without exchange
|
||||
provider.exchange = None
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
|
||||
@patch('src.data.providers.ccxt_provider.ccxt')
|
||||
def test_disconnect(self, mock_ccxt, provider, mock_ccxt_exchange):
|
||||
"""Test disconnection."""
|
||||
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
|
||||
provider.connect()
|
||||
provider.subscribe_ticker('BTC/USDT', Mock())
|
||||
|
||||
provider.disconnect()
|
||||
|
||||
assert not provider._connected
|
||||
assert provider.exchange is None
|
||||
assert len(provider._subscribers) == 0
|
||||
113
tests/unit/data/providers/test_coingecko_provider.py
Normal file
113
tests/unit/data/providers/test_coingecko_provider.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for CoinGecko pricing provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
import httpx
|
||||
|
||||
from src.data.providers.coingecko_provider import CoinGeckoProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider():
|
||||
"""Create a CoinGecko provider instance."""
|
||||
return CoinGeckoProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
"""Create a mock HTTP response."""
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.json = Mock(return_value={
|
||||
'bitcoin': {
|
||||
'usd': 50000.0,
|
||||
'usd_24h_change': 2.5,
|
||||
'usd_24h_vol': 1000000.0,
|
||||
}
|
||||
})
|
||||
return response
|
||||
|
||||
|
||||
class TestCoinGeckoProvider:
|
||||
"""Tests for CoinGeckoProvider."""
|
||||
|
||||
def test_init(self, provider):
|
||||
"""Test provider initialization."""
|
||||
assert provider.name == "CoinGecko"
|
||||
assert not provider.supports_websocket
|
||||
assert not provider._connected
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_connect_success(self, mock_client_class, provider, mock_response):
|
||||
"""Test successful connection."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is True
|
||||
assert provider._connected is True
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_connect_failure(self, mock_client_class, provider):
|
||||
"""Test connection failure."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(side_effect=Exception("Connection failed"))
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
result = provider.connect()
|
||||
|
||||
assert result is False
|
||||
assert not provider._connected
|
||||
|
||||
def test_parse_symbol(self, provider):
|
||||
"""Test symbol parsing."""
|
||||
coin_id, currency = provider._parse_symbol('BTC/USD')
|
||||
assert coin_id == 'bitcoin'
|
||||
assert currency == 'usd'
|
||||
|
||||
coin_id, currency = provider._parse_symbol('ETH/USDT')
|
||||
assert coin_id == 'ethereum'
|
||||
assert currency == 'usd' # USDT maps to USD
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_get_ticker(self, mock_client_class, provider, mock_response):
|
||||
"""Test getting ticker data."""
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
provider.connect()
|
||||
|
||||
ticker = provider.get_ticker('BTC/USD')
|
||||
|
||||
assert ticker['symbol'] == 'BTC/USD'
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
assert ticker['last'] > 0
|
||||
assert 'timestamp' in ticker
|
||||
|
||||
@patch('src.data.providers.coingecko_provider.httpx.Client')
|
||||
def test_get_ohlcv(self, mock_client_class, provider):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=[
|
||||
[1609459200000, 50000, 51000, 49000, 50000],
|
||||
])
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.get = Mock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client
|
||||
provider.connect()
|
||||
|
||||
ohlcv = provider.get_ohlcv('BTC/USD', '1h', limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
# CoinGecko returns 5 elements, we add volume as 0
|
||||
assert len(ohlcv[0]) == 6
|
||||
|
||||
def test_normalize_symbol(self, provider):
|
||||
"""Test symbol normalization."""
|
||||
normalized = provider.normalize_symbol('btc-usdt')
|
||||
assert normalized == 'BTC/USDT'
|
||||
120
tests/unit/data/test_cache_manager.py
Normal file
120
tests/unit/data/test_cache_manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Unit tests for cache manager."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from src.data.cache_manager import CacheManager, CacheEntry
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test cache entry initialization."""
|
||||
entry = CacheEntry("test_data", 60.0)
|
||||
assert entry.data == "test_data"
|
||||
assert entry.expires_at > time.time()
|
||||
assert entry.access_count == 0
|
||||
|
||||
def test_is_expired(self):
|
||||
"""Test expiration checking."""
|
||||
entry = CacheEntry("test_data", 0.01) # Very short TTL
|
||||
assert not entry.is_expired()
|
||||
time.sleep(0.02)
|
||||
assert entry.is_expired()
|
||||
|
||||
def test_touch(self):
|
||||
"""Test access tracking."""
|
||||
entry = CacheEntry("test_data", 60.0)
|
||||
initial_count = entry.access_count
|
||||
entry.touch()
|
||||
assert entry.access_count == initial_count + 1
|
||||
|
||||
|
||||
class TestCacheManager:
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self):
|
||||
"""Create a cache manager instance."""
|
||||
return CacheManager(default_ttl=1.0, max_size=10)
|
||||
|
||||
def test_get_set(self, cache):
|
||||
"""Test basic get and set operations."""
|
||||
cache.set("key1", "value1")
|
||||
assert cache.get("key1") == "value1"
|
||||
|
||||
def test_get_missing(self, cache):
|
||||
"""Test getting non-existent key."""
|
||||
assert cache.get("missing") is None
|
||||
|
||||
def test_expiration(self, cache):
|
||||
"""Test cache entry expiration."""
|
||||
cache.set("key1", "value1", ttl=0.1)
|
||||
assert cache.get("key1") == "value1"
|
||||
time.sleep(0.2)
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_lru_eviction(self, cache):
|
||||
"""Test LRU eviction when max size reached."""
|
||||
# Fill cache to max size
|
||||
for i in range(10):
|
||||
cache.set(f"key{i}", f"value{i}")
|
||||
|
||||
# Add one more - should evict oldest
|
||||
cache.set("key10", "value10")
|
||||
|
||||
# Oldest key should be evicted
|
||||
assert cache.get("key0") is None
|
||||
assert cache.get("key10") == "value10"
|
||||
|
||||
def test_type_specific_ttl(self, cache):
|
||||
"""Test type-specific TTL."""
|
||||
cache.set("ticker1", {"price": 100}, cache_type='ticker')
|
||||
cache.set("ohlcv1", [[1, 2, 3, 4, 5, 6]], cache_type='ohlcv')
|
||||
|
||||
# Both should be cached
|
||||
assert cache.get("ticker1") is not None
|
||||
assert cache.get("ohlcv1") is not None
|
||||
|
||||
def test_delete(self, cache):
|
||||
"""Test cache entry deletion."""
|
||||
cache.set("key1", "value1")
|
||||
assert cache.get("key1") == "value1"
|
||||
|
||||
cache.delete("key1")
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_clear(self, cache):
|
||||
"""Test cache clearing."""
|
||||
cache.set("key1", "value1")
|
||||
cache.set("key2", "value2")
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is None
|
||||
|
||||
def test_stats(self, cache):
|
||||
"""Test cache statistics."""
|
||||
cache.set("key1", "value1")
|
||||
cache.get("key1") # Hit
|
||||
cache.get("missing") # Miss
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats['hits'] >= 1
|
||||
assert stats['misses'] >= 1
|
||||
assert stats['size'] == 1
|
||||
assert 'hit_rate' in stats
|
||||
|
||||
def test_invalidate_pattern(self, cache):
|
||||
"""Test pattern-based invalidation."""
|
||||
cache.set("ticker:BTC/USD", "value1")
|
||||
cache.set("ticker:ETH/USD", "value2")
|
||||
cache.set("ohlcv:BTC/USD", "value3")
|
||||
|
||||
cache.invalidate_pattern("ticker:")
|
||||
|
||||
assert cache.get("ticker:BTC/USD") is None
|
||||
assert cache.get("ticker:ETH/USD") is None
|
||||
assert cache.get("ohlcv:BTC/USD") is not None
|
||||
145
tests/unit/data/test_health_monitor.py
Normal file
145
tests/unit/data/test_health_monitor.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Unit tests for health monitor."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.data.health_monitor import HealthMonitor, HealthMetrics, HealthStatus
|
||||
|
||||
|
||||
class TestHealthMetrics:
|
||||
"""Tests for HealthMetrics."""
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful operation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.record_success(0.5)
|
||||
|
||||
assert metrics.status == HealthStatus.HEALTHY
|
||||
assert metrics.success_count == 1
|
||||
assert metrics.consecutive_failures == 0
|
||||
assert len(metrics.response_times) == 1
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed operation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.failure_count == 1
|
||||
assert metrics.consecutive_failures == 1
|
||||
|
||||
def test_circuit_breaker(self):
|
||||
"""Test circuit breaker opening."""
|
||||
metrics = HealthMetrics()
|
||||
|
||||
# Record 5 failures
|
||||
for _ in range(5):
|
||||
metrics.record_failure()
|
||||
|
||||
assert metrics.circuit_breaker_open is True
|
||||
assert metrics.consecutive_failures == 5
|
||||
|
||||
def test_should_attempt(self):
|
||||
"""Test should_attempt logic."""
|
||||
metrics = HealthMetrics()
|
||||
|
||||
# Should attempt if circuit breaker not open
|
||||
assert metrics.should_attempt() is True
|
||||
|
||||
# Open circuit breaker
|
||||
for _ in range(5):
|
||||
metrics.record_failure()
|
||||
|
||||
# Should not attempt immediately
|
||||
assert metrics.should_attempt(circuit_breaker_timeout=60) is False
|
||||
|
||||
def test_get_avg_response_time(self):
|
||||
"""Test average response time calculation."""
|
||||
metrics = HealthMetrics()
|
||||
metrics.response_times.extend([0.1, 0.2, 0.3])
|
||||
|
||||
avg = metrics.get_avg_response_time()
|
||||
assert avg == 0.2
|
||||
|
||||
|
||||
class TestHealthMonitor:
|
||||
"""Tests for HealthMonitor."""
|
||||
|
||||
@pytest.fixture
|
||||
def monitor(self):
|
||||
"""Create a health monitor instance."""
|
||||
return HealthMonitor()
|
||||
|
||||
def test_record_success(self, monitor):
|
||||
"""Test recording success."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics is not None
|
||||
assert metrics.status == HealthStatus.HEALTHY
|
||||
assert metrics.success_count == 1
|
||||
|
||||
def test_record_failure(self, monitor):
|
||||
"""Test recording failure."""
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics is not None
|
||||
assert metrics.failure_count == 1
|
||||
assert metrics.consecutive_failures == 1
|
||||
|
||||
def test_is_healthy(self, monitor):
|
||||
"""Test health checking."""
|
||||
# No metrics yet - assume healthy
|
||||
assert monitor.is_healthy("provider1") is True
|
||||
|
||||
# Record success
|
||||
monitor.record_success("provider1", 0.5)
|
||||
assert monitor.is_healthy("provider1") is True
|
||||
|
||||
# Record many failures
|
||||
for _ in range(10):
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
assert monitor.is_healthy("provider1") is False
|
||||
|
||||
def test_get_health_status(self, monitor):
|
||||
"""Test getting health status."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
status = monitor.get_health_status("provider1")
|
||||
assert status == HealthStatus.HEALTHY
|
||||
|
||||
def test_select_healthiest(self, monitor):
|
||||
"""Test selecting healthiest provider."""
|
||||
# Make provider1 healthy
|
||||
monitor.record_success("provider1", 0.1)
|
||||
monitor.record_success("provider1", 0.2)
|
||||
|
||||
# Make provider2 unhealthy
|
||||
monitor.record_failure("provider2")
|
||||
monitor.record_failure("provider2")
|
||||
monitor.record_failure("provider2")
|
||||
|
||||
healthiest = monitor.select_healthiest(["provider1", "provider2"])
|
||||
assert healthiest == "provider1"
|
||||
|
||||
def test_reset_circuit_breaker(self, monitor):
|
||||
"""Test resetting circuit breaker."""
|
||||
# Open circuit breaker
|
||||
for _ in range(5):
|
||||
monitor.record_failure("provider1")
|
||||
|
||||
assert monitor.get_metrics("provider1").circuit_breaker_open is True
|
||||
|
||||
monitor.reset_circuit_breaker("provider1")
|
||||
|
||||
metrics = monitor.get_metrics("provider1")
|
||||
assert metrics.circuit_breaker_open is False
|
||||
assert metrics.consecutive_failures == 0
|
||||
|
||||
def test_reset_metrics(self, monitor):
|
||||
"""Test resetting metrics."""
|
||||
monitor.record_success("provider1", 0.5)
|
||||
assert monitor.get_metrics("provider1") is not None
|
||||
|
||||
monitor.reset_metrics("provider1")
|
||||
assert monitor.get_metrics("provider1") is None
|
||||
68
tests/unit/data/test_indicators.py
Normal file
68
tests/unit/data/test_indicators.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for technical indicators."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.data.indicators import get_indicators, TechnicalIndicators
|
||||
|
||||
|
||||
class TestTechnicalIndicators:
|
||||
"""Tests for TechnicalIndicators."""
|
||||
|
||||
@pytest.fixture
|
||||
def indicators(self):
|
||||
"""Create indicators instance."""
|
||||
return get_indicators()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample price data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
|
||||
return pd.DataFrame({
|
||||
'close': [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)],
|
||||
'high': [101 + i * 0.1 for i in range(100)],
|
||||
'low': [99 + i * 0.1 for i in range(100)],
|
||||
'open': [100 + i * 0.1 for i in range(100)],
|
||||
'volume': [1000.0] * 100
|
||||
})
|
||||
|
||||
def test_sma(self, indicators, sample_data):
|
||||
"""Test Simple Moving Average."""
|
||||
sma = indicators.sma(sample_data['close'], period=20)
|
||||
assert len(sma) == len(sample_data)
|
||||
assert not sma.isna().all() # Should have some valid values
|
||||
|
||||
def test_ema(self, indicators, sample_data):
|
||||
"""Test Exponential Moving Average."""
|
||||
ema = indicators.ema(sample_data['close'], period=20)
|
||||
assert len(ema) == len(sample_data)
|
||||
|
||||
def test_rsi(self, indicators, sample_data):
|
||||
"""Test Relative Strength Index."""
|
||||
rsi = indicators.rsi(sample_data['close'], period=14)
|
||||
assert len(rsi) == len(sample_data)
|
||||
# RSI should be between 0 and 100
|
||||
valid_rsi = rsi.dropna()
|
||||
if len(valid_rsi) > 0:
|
||||
assert (valid_rsi >= 0).all()
|
||||
assert (valid_rsi <= 100).all()
|
||||
|
||||
def test_macd(self, indicators, sample_data):
|
||||
"""Test MACD."""
|
||||
macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
|
||||
assert 'macd' in macd_result
|
||||
assert 'signal' in macd_result
|
||||
assert 'histogram' in macd_result
|
||||
|
||||
def test_bollinger_bands(self, indicators, sample_data):
|
||||
"""Test Bollinger Bands."""
|
||||
bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
|
||||
assert 'upper' in bb
|
||||
assert 'middle' in bb
|
||||
assert 'lower' in bb
|
||||
# Upper should be above middle, middle above lower
|
||||
valid_data = bb.dropna()
|
||||
if len(valid_data) > 0:
|
||||
assert (valid_data['upper'] >= valid_data['middle']).all()
|
||||
assert (valid_data['middle'] >= valid_data['lower']).all()
|
||||
|
||||
80
tests/unit/data/test_indicators_divergence.py
Normal file
80
tests/unit/data/test_indicators_divergence.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Tests for divergence detection in indicators."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.data.indicators import get_indicators
|
||||
|
||||
|
||||
class TestDivergenceDetection:
|
||||
"""Tests for divergence detection."""
|
||||
|
||||
@pytest.fixture
|
||||
def indicators(self):
|
||||
"""Create indicators instance."""
|
||||
return get_indicators()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample price data with clear trend."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
|
||||
# Create price data with trend
|
||||
prices = [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)]
|
||||
return pd.Series(prices, index=dates)
|
||||
|
||||
def test_detect_divergence_insufficient_data(self, indicators):
|
||||
"""Test divergence detection with insufficient data."""
|
||||
prices = pd.Series([100, 101, 102])
|
||||
indicator = pd.Series([50, 51, 52])
|
||||
|
||||
result = indicators.detect_divergence(prices, indicator, lookback=20)
|
||||
|
||||
assert result['type'] is None
|
||||
assert result['confidence'] == 0.0
|
||||
|
||||
def test_detect_divergence_structure(self, indicators, sample_data):
|
||||
"""Test divergence detection returns correct structure."""
|
||||
# Create indicator data
|
||||
indicator = pd.Series([50 + i * 0.1 for i in range(100)], index=sample_data.index)
|
||||
|
||||
result = indicators.detect_divergence(sample_data, indicator, lookback=20)
|
||||
|
||||
# Check structure
|
||||
assert 'type' in result
|
||||
assert 'confidence' in result
|
||||
assert 'price_swing_high' in result
|
||||
assert 'price_swing_low' in result
|
||||
assert 'indicator_swing_high' in result
|
||||
assert 'indicator_swing_low' in result
|
||||
|
||||
# Type should be None, 'bullish', or 'bearish'
|
||||
assert result['type'] in [None, 'bullish', 'bearish']
|
||||
|
||||
# Confidence should be 0.0 to 1.0
|
||||
assert 0.0 <= result['confidence'] <= 1.0
|
||||
|
||||
def test_detect_divergence_with_trend(self, indicators):
|
||||
"""Test divergence detection with clear trend data."""
|
||||
# Create price making lower lows
|
||||
prices = pd.Series([100, 95, 90, 85, 80])
|
||||
|
||||
# Create indicator making higher lows (bullish divergence)
|
||||
indicator = pd.Series([30, 32, 34, 36, 38])
|
||||
|
||||
# Need more data for lookback
|
||||
prices_long = pd.concat([pd.Series([110] * 30), prices])
|
||||
indicator_long = pd.concat([pd.Series([25] * 30), indicator])
|
||||
|
||||
result = indicators.detect_divergence(
|
||||
prices_long,
|
||||
indicator_long,
|
||||
lookback=5,
|
||||
min_swings=2
|
||||
)
|
||||
|
||||
# Should detect bullish divergence (price down, indicator up)
|
||||
# Note: This may not always detect due to swing detection logic
|
||||
assert result is not None
|
||||
assert 'type' in result
|
||||
assert 'confidence' in result
|
||||
|
||||
135
tests/unit/data/test_pricing_service.py
Normal file
135
tests/unit/data/test_pricing_service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Unit tests for pricing service."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from decimal import Decimal
|
||||
|
||||
from src.data.pricing_service import PricingService, get_pricing_service
|
||||
from src.data.providers.base_provider import BasePricingProvider
|
||||
|
||||
|
||||
class MockProvider(BasePricingProvider):
|
||||
"""Mock provider for testing."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "MockProvider"
|
||||
|
||||
@property
|
||||
def supports_websocket(self) -> bool:
|
||||
return False
|
||||
|
||||
def connect(self) -> bool:
|
||||
self._connected = True
|
||||
return True
|
||||
|
||||
def disconnect(self):
|
||||
self._connected = False
|
||||
|
||||
def get_ticker(self, symbol: str):
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'bid': Decimal('50000'),
|
||||
'ask': Decimal('50001'),
|
||||
'last': Decimal('50000.5'),
|
||||
'high': Decimal('51000'),
|
||||
'low': Decimal('49000'),
|
||||
'volume': Decimal('1000000'),
|
||||
'timestamp': 1609459200000,
|
||||
}
|
||||
|
||||
def get_ohlcv(self, symbol: str, timeframe: str = '1h', since=None, limit: int = 100):
|
||||
return [[1609459200000, 50000, 51000, 49000, 50000, 1000]]
|
||||
|
||||
def subscribe_ticker(self, symbol: str, callback) -> bool:
|
||||
if symbol not in self._subscribers:
|
||||
self._subscribers[symbol] = []
|
||||
self._subscribers[symbol].append(callback)
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock configuration."""
|
||||
config = Mock()
|
||||
config.get = Mock(side_effect=lambda key, default=None: {
|
||||
"data_providers.primary": [
|
||||
{"name": "mock", "enabled": True, "priority": 1}
|
||||
],
|
||||
"data_providers.fallback": {"enabled": True, "api_key": ""},
|
||||
"data_providers.caching.ticker_ttl": 2,
|
||||
"data_providers.caching.ohlcv_ttl": 60,
|
||||
"data_providers.caching.max_cache_size": 1000,
|
||||
}.get(key, default))
|
||||
return config
|
||||
|
||||
|
||||
class TestPricingService:
|
||||
"""Tests for PricingService."""
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
@patch('src.data.providers.coingecko_provider.CoinGeckoProvider')
|
||||
def test_init(self, mock_coingecko, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test service initialization."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_ccxt_instance = MockProvider()
|
||||
mock_ccxt.return_value = mock_ccxt_instance
|
||||
mock_coingecko_instance = MockProvider()
|
||||
mock_coingecko.return_value = mock_coingecko_instance
|
||||
|
||||
service = PricingService()
|
||||
|
||||
assert service.cache is not None
|
||||
assert service.health_monitor is not None
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_get_ticker(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test getting ticker data."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
ticker = service.get_ticker("BTC/USD")
|
||||
|
||||
assert ticker['symbol'] == "BTC/USD"
|
||||
assert isinstance(ticker['last'], Decimal)
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_get_ohlcv(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test getting OHLCV data."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
ohlcv = service.get_ohlcv("BTC/USD", "1h", limit=10)
|
||||
|
||||
assert len(ohlcv) > 0
|
||||
assert len(ohlcv[0]) == 6
|
||||
|
||||
@patch('src.data.pricing_service.get_config')
|
||||
@patch('src.data.providers.ccxt_provider.CCXTProvider')
|
||||
def test_subscribe_ticker(self, mock_ccxt, mock_get_config, mock_config):
|
||||
"""Test subscribing to ticker updates."""
|
||||
mock_get_config.return_value = mock_config
|
||||
mock_provider = MockProvider()
|
||||
mock_ccxt.return_value = mock_provider
|
||||
|
||||
service = PricingService()
|
||||
service._providers["MockProvider"] = mock_provider
|
||||
service._active_provider = "MockProvider"
|
||||
|
||||
callback = Mock()
|
||||
result = service.subscribe_ticker("BTC/USD", callback)
|
||||
|
||||
assert result is True
|
||||
118
tests/unit/data/test_redis_cache.py
Normal file
118
tests/unit/data/test_redis_cache.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tests for Redis cache."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
|
||||
class TestRedisCache:
|
||||
"""Tests for RedisCache class."""
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ticker_cache_hit(self, mock_get_client):
|
||||
"""Test getting cached ticker data."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = '{"price": 45000.0, "symbol": "BTC/USD"}'
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ticker("BTC/USD")
|
||||
|
||||
assert result is not None
|
||||
assert result["price"] == 45000.0
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ticker_cache_miss(self, mock_get_client):
|
||||
"""Test ticker cache miss."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = None
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ticker("BTC/USD")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_ticker(self, mock_get_client):
|
||||
"""Test setting ticker cache."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.set_ticker("BTC/USD", {"price": 45000.0})
|
||||
|
||||
assert result is True
|
||||
mock_client.setex.assert_called_once()
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ohlcv(self, mock_get_client):
|
||||
"""Test getting cached OHLCV data."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = '[[1700000000, 45000, 45500, 44500, 45200, 1000]]'
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
result = await cache.get_ohlcv("BTC/USD", "1h", 100)
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == 1700000000
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_ohlcv(self, mock_get_client):
|
||||
"""Test setting OHLCV cache."""
|
||||
mock_redis = Mock()
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.get_client.return_value = mock_client
|
||||
mock_get_client.return_value = mock_redis
|
||||
|
||||
from src.data.redis_cache import RedisCache
|
||||
cache = RedisCache()
|
||||
|
||||
ohlcv_data = [[1700000000, 45000, 45500, 44500, 45200, 1000]]
|
||||
result = await cache.set_ohlcv("BTC/USD", "1h", ohlcv_data)
|
||||
|
||||
assert result is True
|
||||
mock_client.setex.assert_called_once()
|
||||
|
||||
|
||||
class TestGetRedisCache:
|
||||
"""Tests for get_redis_cache singleton."""
|
||||
|
||||
@patch('src.data.redis_cache.get_redis_client')
|
||||
def test_returns_singleton(self, mock_get_client):
|
||||
"""Test that get_redis_cache returns same instance."""
|
||||
mock_get_client.return_value = Mock()
|
||||
|
||||
# Reset the global
|
||||
import src.data.redis_cache as cache_module
|
||||
cache_module._redis_cache = None
|
||||
|
||||
from src.data.redis_cache import get_redis_cache
|
||||
|
||||
cache1 = get_redis_cache()
|
||||
cache2 = get_redis_cache()
|
||||
|
||||
assert cache1 is cache2
|
||||
2
tests/unit/exchanges/__init__.py
Normal file
2
tests/unit/exchanges/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for exchange adapters."""
|
||||
|
||||
24
tests/unit/exchanges/test_base.py
Normal file
24
tests/unit/exchanges/test_base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Tests for base exchange adapter."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from src.exchanges.base import BaseExchange
|
||||
|
||||
|
||||
class TestBaseExchange:
|
||||
"""Tests for BaseExchange abstract class."""
|
||||
|
||||
def test_base_exchange_init(self):
|
||||
"""Test base exchange initialization."""
|
||||
# Can't instantiate abstract class, test through concrete implementation
|
||||
from src.exchanges.coinbase import CoinbaseExchange
|
||||
|
||||
exchange = CoinbaseExchange(
|
||||
name="test",
|
||||
api_key="test_key",
|
||||
secret_key="test_secret"
|
||||
)
|
||||
assert exchange.name == "test"
|
||||
assert exchange.api_key == "test_key"
|
||||
assert not exchange.is_connected
|
||||
|
||||
56
tests/unit/exchanges/test_coinbase.py
Normal file
56
tests/unit/exchanges/test_coinbase.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for Coinbase exchange adapter."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from src.exchanges.coinbase import CoinbaseExchange
|
||||
|
||||
|
||||
class TestCoinbaseExchange:
|
||||
"""Tests for CoinbaseExchange."""
|
||||
|
||||
@pytest.fixture
|
||||
def exchange(self):
|
||||
"""Create Coinbase exchange instance."""
|
||||
return CoinbaseExchange(
|
||||
name="test_coinbase",
|
||||
api_key="test_key",
|
||||
secret_key="test_secret",
|
||||
permissions="read_only"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(self, exchange):
|
||||
"""Test connection to Coinbase."""
|
||||
with patch.object(exchange.exchange, 'load_markets', new_callable=AsyncMock):
|
||||
await exchange.connect()
|
||||
assert exchange.is_connected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_balance(self, exchange):
|
||||
"""Test fetching balance."""
|
||||
mock_balance = {'USD': {'free': 1000.0, 'used': 0.0, 'total': 1000.0}}
|
||||
exchange.exchange.fetch_balance = AsyncMock(return_value=mock_balance)
|
||||
exchange.is_connected = True
|
||||
|
||||
balance = await exchange.fetch_balance()
|
||||
assert balance == mock_balance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_place_order_readonly(self, exchange):
|
||||
"""Test placing order with read-only permissions."""
|
||||
exchange.permissions = "read_only"
|
||||
exchange.is_connected = True
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await exchange.place_order("BTC/USD", "buy", "market", 0.01)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_ohlcv(self, exchange):
|
||||
"""Test fetching OHLCV data."""
|
||||
mock_ohlcv = [[1609459200000, 29000, 29500, 28800, 29300, 1000]]
|
||||
exchange.exchange.fetch_ohlcv = AsyncMock(return_value=mock_ohlcv)
|
||||
exchange.is_connected = True
|
||||
|
||||
ohlcv = await exchange.fetch_ohlcv("BTC/USD", "1h")
|
||||
assert ohlcv == mock_ohlcv
|
||||
|
||||
40
tests/unit/exchanges/test_factory.py
Normal file
40
tests/unit/exchanges/test_factory.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Tests for exchange factory."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from src.exchanges.factory import ExchangeFactory
|
||||
from src.exchanges.coinbase import CoinbaseExchange
|
||||
|
||||
|
||||
class TestExchangeFactory:
|
||||
"""Tests for ExchangeFactory."""
|
||||
|
||||
def test_register_exchange(self):
|
||||
"""Test exchange registration."""
|
||||
ExchangeFactory.register_exchange("test_exchange", CoinbaseExchange)
|
||||
assert "test_exchange" in ExchangeFactory.list_available()
|
||||
|
||||
def test_get_exchange(self):
|
||||
"""Test getting exchange instance."""
|
||||
with patch('src.exchanges.factory.get_key_manager') as mock_km:
|
||||
mock_km.return_value.get_exchange_keys.return_value = {
|
||||
'api_key': 'test_key',
|
||||
'secret_key': 'test_secret',
|
||||
'permissions': 'read_only'
|
||||
}
|
||||
|
||||
exchange = ExchangeFactory.get_exchange("coinbase")
|
||||
assert exchange is not None
|
||||
assert isinstance(exchange, CoinbaseExchange)
|
||||
|
||||
def test_get_nonexistent_exchange(self):
|
||||
"""Test getting non-existent exchange."""
|
||||
with pytest.raises(ValueError, match="not registered"):
|
||||
ExchangeFactory.get_exchange("nonexistent")
|
||||
|
||||
def test_list_available(self):
|
||||
"""Test listing available exchanges."""
|
||||
exchanges = ExchangeFactory.list_available()
|
||||
assert isinstance(exchanges, list)
|
||||
assert "coinbase" in exchanges
|
||||
|
||||
44
tests/unit/exchanges/test_websocket.py
Normal file
44
tests/unit/exchanges/test_websocket.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for WebSocket functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from src.exchanges.coinbase import CoinbaseAdapter
|
||||
|
||||
|
||||
def test_subscribe_ticker():
|
||||
"""Test ticker subscription."""
|
||||
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
|
||||
callback = Mock()
|
||||
|
||||
adapter.subscribe_ticker("BTC/USD", callback)
|
||||
|
||||
assert f'ticker_BTC/USD' in adapter._ws_callbacks
|
||||
assert adapter._ws_callbacks[f'ticker_BTC/USD'] == callback
|
||||
|
||||
|
||||
def test_subscribe_orderbook():
|
||||
"""Test orderbook subscription."""
|
||||
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
|
||||
callback = Mock()
|
||||
|
||||
adapter.subscribe_orderbook("BTC/USD", callback)
|
||||
|
||||
assert f'orderbook_BTC/USD' in adapter._ws_callbacks
|
||||
|
||||
|
||||
def test_subscribe_trades():
|
||||
"""Test trades subscription."""
|
||||
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
|
||||
callback = Mock()
|
||||
|
||||
adapter.subscribe_trades("BTC/USD", callback)
|
||||
|
||||
assert f'trades_BTC/USD' in adapter._ws_callbacks
|
||||
|
||||
|
||||
def test_normalize_symbol():
|
||||
"""Test symbol normalization."""
|
||||
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
|
||||
|
||||
assert adapter.normalize_symbol("BTC/USD") == "BTC-USD"
|
||||
assert adapter.normalize_symbol("ETH/USDT") == "ETH-USDT"
|
||||
2
tests/unit/optimization/__init__.py
Normal file
2
tests/unit/optimization/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for optimization."""
|
||||
|
||||
44
tests/unit/optimization/test_grid_search.py
Normal file
44
tests/unit/optimization/test_grid_search.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for grid search optimization."""
|
||||
|
||||
import pytest
|
||||
from src.optimization.grid_search import GridSearchOptimizer
|
||||
|
||||
|
||||
class TestGridSearchOptimizer:
|
||||
"""Tests for GridSearchOptimizer."""
|
||||
|
||||
@pytest.fixture
|
||||
def optimizer(self):
|
||||
"""Create grid search optimizer."""
|
||||
return GridSearchOptimizer()
|
||||
|
||||
def test_optimize_maximize(self, optimizer):
|
||||
"""Test optimization with maximize."""
|
||||
param_grid = {
|
||||
'param1': [1, 2, 3],
|
||||
'param2': [10, 20]
|
||||
}
|
||||
|
||||
def objective(params):
|
||||
return params['param1'] * params['param2']
|
||||
|
||||
result = optimizer.optimize(param_grid, objective, maximize=True)
|
||||
|
||||
assert result['best_params'] is not None
|
||||
assert result['best_score'] is not None
|
||||
assert result['best_score'] > 0
|
||||
|
||||
def test_optimize_minimize(self, optimizer):
|
||||
"""Test optimization with minimize."""
|
||||
param_grid = {
|
||||
'param1': [1, 2, 3]
|
||||
}
|
||||
|
||||
def objective(params):
|
||||
return params['param1'] * 10
|
||||
|
||||
result = optimizer.optimize(param_grid, objective, maximize=False)
|
||||
|
||||
assert result['best_params'] is not None
|
||||
assert result['best_score'] is not None
|
||||
|
||||
2
tests/unit/portfolio/__init__.py
Normal file
2
tests/unit/portfolio/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for portfolio management."""
|
||||
|
||||
35
tests/unit/portfolio/test_analytics.py
Normal file
35
tests/unit/portfolio/test_analytics.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tests for portfolio analytics."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from src.portfolio.analytics import get_portfolio_analytics, PortfolioAnalytics
|
||||
|
||||
|
||||
class TestPortfolioAnalytics:
|
||||
"""Tests for PortfolioAnalytics."""
|
||||
|
||||
@pytest.fixture
|
||||
def analytics(self):
|
||||
"""Create portfolio analytics instance."""
|
||||
return get_portfolio_analytics()
|
||||
|
||||
def test_calculate_sharpe_ratio(self, analytics):
|
||||
"""Test Sharpe ratio calculation."""
|
||||
returns = pd.Series([0.01, -0.005, 0.02, -0.01, 0.015])
|
||||
sharpe = analytics.calculate_sharpe_ratio(returns, risk_free_rate=0.0)
|
||||
assert isinstance(sharpe, float)
|
||||
|
||||
def test_calculate_sortino_ratio(self, analytics):
|
||||
"""Test Sortino ratio calculation."""
|
||||
returns = pd.Series([0.01, -0.005, 0.02, -0.01, 0.015])
|
||||
sortino = analytics.calculate_sortino_ratio(returns, risk_free_rate=0.0)
|
||||
assert isinstance(sortino, float)
|
||||
|
||||
def test_calculate_max_drawdown(self, analytics):
|
||||
"""Test max drawdown calculation."""
|
||||
equity_curve = pd.Series([10000, 10500, 10200, 11000, 10800])
|
||||
drawdown = analytics.calculate_max_drawdown(equity_curve)
|
||||
assert isinstance(drawdown, float)
|
||||
assert 0 <= drawdown <= 1
|
||||
|
||||
29
tests/unit/portfolio/test_tracker.py
Normal file
29
tests/unit/portfolio/test_tracker.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Tests for portfolio tracker."""
|
||||
|
||||
import pytest
|
||||
from src.portfolio.tracker import get_portfolio_tracker, PortfolioTracker
|
||||
|
||||
|
||||
class TestPortfolioTracker:
|
||||
"""Tests for PortfolioTracker."""
|
||||
|
||||
@pytest.fixture
|
||||
def tracker(self):
|
||||
"""Create portfolio tracker instance."""
|
||||
return get_portfolio_tracker()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_portfolio(self, tracker):
|
||||
"""Test getting current portfolio."""
|
||||
portfolio = await tracker.get_current_portfolio(paper_trading=True)
|
||||
assert portfolio is not None
|
||||
assert "positions" in portfolio
|
||||
assert "performance" in portfolio
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_positions_prices(self, tracker):
|
||||
"""Test updating position prices."""
|
||||
prices = {"BTC/USD": 50000.0}
|
||||
await tracker.update_positions_prices(prices, paper_trading=True)
|
||||
# Should not raise exception
|
||||
|
||||
2
tests/unit/reporting/__init__.py
Normal file
2
tests/unit/reporting/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for reporting."""
|
||||
|
||||
35
tests/unit/reporting/test_csv_exporter.py
Normal file
35
tests/unit/reporting/test_csv_exporter.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tests for CSV exporter."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from src.reporting.csv_exporter import get_csv_exporter, CSVExporter
|
||||
|
||||
|
||||
class TestCSVExporter:
|
||||
"""Tests for CSVExporter."""
|
||||
|
||||
@pytest.fixture
|
||||
def exporter(self):
|
||||
"""Create CSV exporter instance."""
|
||||
return get_csv_exporter()
|
||||
|
||||
def test_export_trades(self, exporter, mock_database):
|
||||
"""Test exporting trades."""
|
||||
engine, Session = mock_database
|
||||
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / "trades.csv"
|
||||
|
||||
# Export (may be empty if no trades)
|
||||
result = exporter.export_trades(filepath, paper_trading=True)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_export_portfolio(self, exporter):
|
||||
"""Test exporting portfolio."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / "portfolio.csv"
|
||||
|
||||
result = exporter.export_portfolio(filepath)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
2
tests/unit/resilience/__init__.py
Normal file
2
tests/unit/resilience/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for resilience."""
|
||||
|
||||
31
tests/unit/resilience/test_state_manager.py
Normal file
31
tests/unit/resilience/test_state_manager.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Tests for state manager."""
|
||||
|
||||
import pytest
|
||||
from src.resilience.state_manager import get_state_manager, StateManager
|
||||
|
||||
|
||||
class TestStateManager:
|
||||
"""Tests for StateManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def state_manager(self):
|
||||
"""Create state manager instance."""
|
||||
return get_state_manager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_state(self, state_manager):
|
||||
"""Test saving state."""
|
||||
result = await state_manager.save_state("test_key", {"data": "value"})
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_state(self, state_manager):
|
||||
"""Test loading state."""
|
||||
# Save first
|
||||
await state_manager.save_state("test_key", {"data": "value"})
|
||||
|
||||
# Load
|
||||
state = await state_manager.load_state("test_key")
|
||||
assert state is not None
|
||||
assert state.get("data") == "value"
|
||||
|
||||
2
tests/unit/risk/__init__.py
Normal file
2
tests/unit/risk/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for risk management."""
|
||||
|
||||
55
tests/unit/risk/test_manager.py
Normal file
55
tests/unit/risk/test_manager.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Tests for risk manager."""
|
||||
|
||||
import pytest
|
||||
from src.risk.manager import get_risk_manager, RiskManager
|
||||
|
||||
|
||||
class TestRiskManager:
|
||||
"""Tests for RiskManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def risk_manager(self):
|
||||
"""Create risk manager instance."""
|
||||
return get_risk_manager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_trade_risk(self, risk_manager):
|
||||
"""Test trade risk checking."""
|
||||
# Test with valid trade
|
||||
result = await risk_manager.check_trade_risk(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
amount=0.01,
|
||||
price=50000.0,
|
||||
current_portfolio_value=10000.0
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_max_drawdown(self, risk_manager):
|
||||
"""Test max drawdown check."""
|
||||
result = await risk_manager.check_max_drawdown(
|
||||
current_portfolio_value=9000.0,
|
||||
peak_portfolio_value=10000.0
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_risk_limit(self, risk_manager, mock_database):
|
||||
"""Test adding risk limit."""
|
||||
engine, Session = mock_database
|
||||
|
||||
await risk_manager.add_risk_limit(
|
||||
limit_type="max_drawdown",
|
||||
value=0.10, # 10%
|
||||
is_active=True
|
||||
)
|
||||
|
||||
# Verify limit was added
|
||||
await risk_manager.load_risk_limits()
|
||||
assert len(risk_manager.risk_limits) > 0
|
||||
|
||||
41
tests/unit/risk/test_position_sizing.py
Normal file
41
tests/unit/risk/test_position_sizing.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Tests for position sizing."""
|
||||
|
||||
import pytest
|
||||
from src.risk.position_sizing import get_position_sizer, PositionSizing
|
||||
|
||||
|
||||
class TestPositionSizing:
|
||||
"""Tests for PositionSizing."""
|
||||
|
||||
@pytest.fixture
|
||||
def position_sizer(self):
|
||||
"""Create position sizer instance."""
|
||||
return get_position_sizer()
|
||||
|
||||
def test_fixed_percentage(self, position_sizer):
|
||||
"""Test fixed percentage sizing."""
|
||||
size = position_sizer.fixed_percentage(10000.0, 0.02) # 2%
|
||||
assert size == 200.0
|
||||
|
||||
def test_fixed_amount(self, position_sizer):
|
||||
"""Test fixed amount sizing."""
|
||||
size = position_sizer.fixed_amount(500.0)
|
||||
assert size == 500.0
|
||||
|
||||
def test_volatility_based(self, position_sizer):
|
||||
"""Test volatility-based sizing."""
|
||||
size = position_sizer.volatility_based(
|
||||
capital=10000.0,
|
||||
atr=100.0,
|
||||
risk_per_trade_percentage=0.01 # 1%
|
||||
)
|
||||
assert size > 0
|
||||
|
||||
def test_kelly_criterion(self, position_sizer):
|
||||
"""Test Kelly Criterion."""
|
||||
fraction = position_sizer.kelly_criterion(
|
||||
win_probability=0.6,
|
||||
payout_ratio=1.5
|
||||
)
|
||||
assert 0 <= fraction <= 1
|
||||
|
||||
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for ATR-based stop loss."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from src.risk.stop_loss import StopLossManager
|
||||
|
||||
|
||||
class TestATRStopLoss:
|
||||
"""Tests for ATR-based stop loss functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def stop_loss_manager(self):
|
||||
"""Create stop loss manager instance."""
|
||||
return StopLossManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ohlcv_data(self):
|
||||
"""Create sample OHLCV data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
|
||||
base_price = 50000
|
||||
return pd.DataFrame({
|
||||
'high': [base_price + 100 + i * 10 for i in range(50)],
|
||||
'low': [base_price - 100 + i * 10 for i in range(50)],
|
||||
'close': [base_price + i * 10 for i in range(50)],
|
||||
'open': [base_price - 50 + i * 10 for i in range(50)],
|
||||
'volume': [1000.0] * 50
|
||||
}, index=dates)
|
||||
|
||||
def test_set_atr_stop_loss(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test setting ATR-based stop loss."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
assert position_id in stop_loss_manager.stop_losses
|
||||
config = stop_loss_manager.stop_losses[position_id]
|
||||
assert config['use_atr'] is True
|
||||
assert config['atr_multiplier'] == Decimal('2.0')
|
||||
assert config['atr_period'] == 14
|
||||
assert 'stop_price' in config
|
||||
|
||||
def test_calculate_atr_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test calculating ATR stop price."""
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_price_long = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
stop_price_short = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=False,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Long position: stop should be below entry
|
||||
assert stop_price_long < entry_price
|
||||
|
||||
# Short position: stop should be above entry
|
||||
assert stop_price_short > entry_price
|
||||
|
||||
def test_atr_trailing_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test ATR-based trailing stop."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
trailing=True,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Check stop loss with higher price (should update trailing stop)
|
||||
current_price = Decimal("51000")
|
||||
triggered = stop_loss_manager.check_stop_loss(
|
||||
position_id=position_id,
|
||||
current_price=current_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Should not trigger at higher price
|
||||
assert triggered is False
|
||||
|
||||
def test_atr_stop_insufficient_data(self, stop_loss_manager):
|
||||
"""Test ATR stop with insufficient data falls back to percentage."""
|
||||
entry_price = Decimal("50000")
|
||||
insufficient_data = pd.DataFrame({
|
||||
'high': [51000],
|
||||
'low': [49000],
|
||||
'close': [50000]
|
||||
})
|
||||
|
||||
stop_price = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=insufficient_data,
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Should fall back to percentage-based stop (2%)
|
||||
assert stop_price == entry_price * Decimal('0.98')
|
||||
|
||||
2
tests/unit/security/__init__.py
Normal file
2
tests/unit/security/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for security."""
|
||||
|
||||
35
tests/unit/security/test_encryption.py
Normal file
35
tests/unit/security/test_encryption.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tests for encryption."""
|
||||
|
||||
import pytest
|
||||
from src.security.encryption import get_encryption_manager, EncryptionManager
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
"""Tests for encryption system."""
|
||||
|
||||
@pytest.fixture
|
||||
def encryptor(self):
|
||||
"""Create encryption manager."""
|
||||
return get_encryption_manager()
|
||||
|
||||
def test_encrypt_decrypt(self, encryptor):
|
||||
"""Test encryption and decryption."""
|
||||
plaintext = "test_api_key_12345"
|
||||
|
||||
encrypted = encryptor.encrypt(plaintext)
|
||||
assert encrypted != plaintext
|
||||
assert len(encrypted) > 0
|
||||
|
||||
decrypted = encryptor.decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_encrypt_different_values(self, encryptor):
|
||||
"""Test that different values encrypt differently."""
|
||||
value1 = "key1"
|
||||
value2 = "key2"
|
||||
|
||||
encrypted1 = encryptor.encrypt(value1)
|
||||
encrypted2 = encryptor.encrypt(value2)
|
||||
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
2
tests/unit/strategies/__init__.py
Normal file
2
tests/unit/strategies/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for strategy framework."""
|
||||
|
||||
89
tests/unit/strategies/test_base.py
Normal file
89
tests/unit/strategies/test_base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests for base strategy class."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from src.strategies.base import BaseStrategy, StrategyRegistry
|
||||
|
||||
|
||||
class ConcreteStrategy(BaseStrategy):
|
||||
"""Concrete strategy for testing."""
|
||||
|
||||
async def on_data(self, new_data: pd.DataFrame):
|
||||
"""Handle new data."""
|
||||
self.current_data = pd.concat([self.current_data, new_data]).tail(100)
|
||||
|
||||
async def generate_signal(self):
|
||||
"""Generate signal."""
|
||||
if len(self.current_data) > 0:
|
||||
return {"signal": "hold", "price": self.current_data['close'].iloc[-1]}
|
||||
return {"signal": "hold", "price": None}
|
||||
|
||||
async def calculate_position_size(self, capital: float, risk_percentage: float) -> float:
|
||||
"""Calculate position size."""
|
||||
return capital * risk_percentage
|
||||
|
||||
|
||||
class TestBaseStrategy:
|
||||
"""Tests for BaseStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create strategy instance."""
|
||||
return ConcreteStrategy(
|
||||
strategy_id=1,
|
||||
name="test_strategy",
|
||||
symbol="BTC/USD",
|
||||
timeframe="1h",
|
||||
parameters={}
|
||||
)
|
||||
|
||||
def test_strategy_initialization(self, strategy):
|
||||
"""Test strategy initialization."""
|
||||
assert strategy.strategy_id == 1
|
||||
assert strategy.name == "test_strategy"
|
||||
assert strategy.symbol == "BTC/USD"
|
||||
assert strategy.timeframe == "1h"
|
||||
assert not strategy.is_active
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strategy_start_stop(self, strategy):
|
||||
"""Test strategy start and stop."""
|
||||
await strategy.start()
|
||||
assert strategy.is_active
|
||||
|
||||
await strategy.stop()
|
||||
assert not strategy.is_active
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_signal(self, strategy):
|
||||
"""Test signal generation."""
|
||||
signal = await strategy.generate_signal()
|
||||
assert "signal" in signal
|
||||
assert signal["signal"] in ["buy", "sell", "hold"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_position_size(self, strategy):
|
||||
"""Test position size calculation."""
|
||||
size = await strategy.calculate_position_size(1000.0, 0.01)
|
||||
assert size == 10.0
|
||||
|
||||
|
||||
class TestStrategyRegistry:
|
||||
"""Tests for StrategyRegistry."""
|
||||
|
||||
def test_register_strategy(self):
|
||||
"""Test strategy registration."""
|
||||
StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
|
||||
assert "test_strategy" in StrategyRegistry.list_available()
|
||||
|
||||
def test_get_strategy_class(self):
|
||||
"""Test getting strategy class."""
|
||||
StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
|
||||
strategy_class = StrategyRegistry.get_strategy_class("test_strategy")
|
||||
assert strategy_class == ConcreteStrategy
|
||||
|
||||
def test_get_nonexistent_strategy(self):
|
||||
"""Test getting non-existent strategy."""
|
||||
with pytest.raises(ValueError, match="not registered"):
|
||||
StrategyRegistry.get_strategy_class("nonexistent")
|
||||
|
||||
55
tests/unit/strategies/test_bollinger_mean_reversion.py
Normal file
55
tests/unit/strategies/test_bollinger_mean_reversion.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Tests for Bollinger Bands mean reversion strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.strategies.technical.bollinger_mean_reversion import BollingerMeanReversionStrategy
|
||||
from src.strategies.base import SignalType
|
||||
|
||||
|
||||
class TestBollingerMeanReversionStrategy:
|
||||
"""Tests for BollingerMeanReversionStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create Bollinger mean reversion strategy instance."""
|
||||
return BollingerMeanReversionStrategy(
|
||||
name="test_bollinger_mr",
|
||||
parameters={
|
||||
'period': 20,
|
||||
'std_dev': 2.0,
|
||||
'trend_filter': True,
|
||||
'trend_ma_period': 50,
|
||||
'entry_threshold': 0.95,
|
||||
'exit_threshold': 0.5
|
||||
}
|
||||
)
|
||||
|
||||
def test_initialization(self, strategy):
|
||||
"""Test strategy initialization."""
|
||||
assert strategy.period == 20
|
||||
assert strategy.std_dev == 2.0
|
||||
assert strategy.trend_filter is True
|
||||
assert strategy.trend_ma_period == 50
|
||||
assert strategy.entry_threshold == 0.95
|
||||
assert strategy.exit_threshold == 0.5
|
||||
|
||||
def test_on_tick_insufficient_data(self, strategy):
|
||||
"""Test that strategy returns None with insufficient data."""
|
||||
signal = strategy.on_tick(
|
||||
symbol="BTC/USD",
|
||||
price=Decimal("50000"),
|
||||
timeframe="1h",
|
||||
data={'volume': 1000}
|
||||
)
|
||||
assert signal is None
|
||||
|
||||
def test_position_tracking(self, strategy):
|
||||
"""Test position tracking."""
|
||||
assert strategy._in_position is False
|
||||
assert strategy._entry_price is None
|
||||
|
||||
def test_strategy_metadata(self, strategy):
|
||||
"""Test strategy metadata."""
|
||||
assert strategy.name == "test_bollinger_mr"
|
||||
assert strategy.enabled is False
|
||||
|
||||
61
tests/unit/strategies/test_confirmed_strategy.py
Normal file
61
tests/unit/strategies/test_confirmed_strategy.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Tests for Confirmed strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.strategies.technical.confirmed_strategy import ConfirmedStrategy
|
||||
from src.strategies.base import SignalType
|
||||
|
||||
|
||||
class TestConfirmedStrategy:
|
||||
"""Tests for ConfirmedStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create Confirmed strategy instance."""
|
||||
return ConfirmedStrategy(
|
||||
name="test_confirmed",
|
||||
parameters={
|
||||
'rsi_period': 14,
|
||||
'macd_fast': 12,
|
||||
'macd_slow': 26,
|
||||
'macd_signal': 9,
|
||||
'ma_fast': 10,
|
||||
'ma_slow': 30,
|
||||
'min_confirmations': 2,
|
||||
'require_rsi': True,
|
||||
'require_macd': True,
|
||||
'require_ma': True
|
||||
}
|
||||
)
|
||||
|
||||
def test_initialization(self, strategy):
|
||||
"""Test strategy initialization."""
|
||||
assert strategy.rsi_period == 14
|
||||
assert strategy.macd_fast == 12
|
||||
assert strategy.ma_fast == 10
|
||||
assert strategy.min_confirmations == 2
|
||||
|
||||
def test_on_tick_insufficient_data(self, strategy):
|
||||
"""Test that strategy returns None with insufficient data."""
|
||||
signal = strategy.on_tick(
|
||||
symbol="BTC/USD",
|
||||
price=Decimal("50000"),
|
||||
timeframe="1h",
|
||||
data={'volume': 1000}
|
||||
)
|
||||
assert signal is None
|
||||
|
||||
def test_min_confirmations_requirement(self, strategy):
|
||||
"""Test that signal requires minimum confirmations."""
|
||||
# This would require actual price history to generate real signals
|
||||
# For now, we test the structure
|
||||
assert strategy.min_confirmations == 2
|
||||
assert strategy.require_rsi is True
|
||||
assert strategy.require_macd is True
|
||||
assert strategy.require_ma is True
|
||||
|
||||
def test_strategy_metadata(self, strategy):
|
||||
"""Test strategy metadata."""
|
||||
assert strategy.name == "test_confirmed"
|
||||
assert strategy.enabled is False
|
||||
|
||||
53
tests/unit/strategies/test_consensus_strategy.py
Normal file
53
tests/unit/strategies/test_consensus_strategy.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Tests for Consensus (ensemble) strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.strategies.ensemble.consensus_strategy import ConsensusStrategy
|
||||
from src.strategies.base import SignalType
|
||||
|
||||
|
||||
class TestConsensusStrategy:
|
||||
"""Tests for ConsensusStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create Consensus strategy instance."""
|
||||
return ConsensusStrategy(
|
||||
name="test_consensus",
|
||||
parameters={
|
||||
'strategy_names': ['rsi', 'macd'],
|
||||
'min_consensus': 2,
|
||||
'use_weights': True,
|
||||
'min_weight': 0.3,
|
||||
'exclude_strategies': []
|
||||
}
|
||||
)
|
||||
|
||||
def test_initialization(self, strategy):
|
||||
"""Test strategy initialization."""
|
||||
assert strategy.min_consensus == 2
|
||||
assert strategy.use_weights is True
|
||||
assert strategy.min_weight == 0.3
|
||||
|
||||
def test_on_tick_no_strategies(self, strategy):
|
||||
"""Test that strategy handles empty strategy list."""
|
||||
# Strategy should handle cases where no strategies are available
|
||||
signal = strategy.on_tick(
|
||||
symbol="BTC/USD",
|
||||
price=Decimal("50000"),
|
||||
timeframe="1h",
|
||||
data={'volume': 1000}
|
||||
)
|
||||
# May return None if no strategies available or no consensus
|
||||
assert signal is None or isinstance(signal, (type(None), object))
|
||||
|
||||
def test_strategy_metadata(self, strategy):
|
||||
"""Test strategy metadata."""
|
||||
assert strategy.name == "test_consensus"
|
||||
assert strategy.enabled is False
|
||||
|
||||
def test_consensus_calculation(self, strategy):
|
||||
"""Test consensus calculation parameters."""
|
||||
assert strategy.min_consensus == 2
|
||||
assert strategy.use_weights is True
|
||||
|
||||
52
tests/unit/strategies/test_dca_strategy.py
Normal file
52
tests/unit/strategies/test_dca_strategy.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Tests for DCA strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from src.strategies.dca.dca_strategy import DCAStrategy
|
||||
|
||||
|
||||
def test_dca_strategy_initialization():
|
||||
"""Test DCA strategy initializes correctly."""
|
||||
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
|
||||
assert strategy.name == "Test DCA"
|
||||
assert strategy.amount == Decimal("10")
|
||||
assert strategy.interval == "daily"
|
||||
|
||||
|
||||
def test_dca_daily_interval():
|
||||
"""Test DCA with daily interval."""
|
||||
strategy = DCAStrategy("Daily DCA", {"amount": 10, "interval": "daily"})
|
||||
assert strategy.interval_delta == timedelta(days=1)
|
||||
|
||||
|
||||
def test_dca_weekly_interval():
|
||||
"""Test DCA with weekly interval."""
|
||||
strategy = DCAStrategy("Weekly DCA", {"amount": 10, "interval": "weekly"})
|
||||
assert strategy.interval_delta == timedelta(weeks=1)
|
||||
|
||||
|
||||
def test_dca_monthly_interval():
|
||||
"""Test DCA with monthly interval."""
|
||||
strategy = DCAStrategy("Monthly DCA", {"amount": 10, "interval": "monthly"})
|
||||
assert strategy.interval_delta == timedelta(days=30)
|
||||
|
||||
|
||||
def test_dca_signal_generation():
|
||||
"""Test DCA generates buy signals."""
|
||||
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
|
||||
strategy.last_purchase_time = None
|
||||
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("100"), "1h", {})
|
||||
assert signal is not None
|
||||
assert signal.signal_type.value == "buy"
|
||||
assert signal.quantity == Decimal("0.1") # 10 / 100
|
||||
|
||||
|
||||
def test_dca_interval_respect():
|
||||
"""Test DCA respects interval timing."""
|
||||
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
|
||||
strategy.last_purchase_time = datetime.utcnow() - timedelta(hours=12)
|
||||
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("100"), "1h", {})
|
||||
assert signal is None # Should not generate signal yet
|
||||
59
tests/unit/strategies/test_divergence_strategy.py
Normal file
59
tests/unit/strategies/test_divergence_strategy.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Tests for Divergence strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.strategies.technical.divergence_strategy import DivergenceStrategy
|
||||
from src.strategies.base import SignalType
|
||||
|
||||
|
||||
class TestDivergenceStrategy:
|
||||
"""Tests for DivergenceStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create Divergence strategy instance."""
|
||||
return DivergenceStrategy(
|
||||
name="test_divergence",
|
||||
parameters={
|
||||
'indicator_type': 'rsi',
|
||||
'rsi_period': 14,
|
||||
'lookback': 20,
|
||||
'min_swings': 2,
|
||||
'min_confidence': 0.5
|
||||
}
|
||||
)
|
||||
|
||||
def test_initialization(self, strategy):
|
||||
"""Test strategy initialization."""
|
||||
assert strategy.indicator_type == 'rsi'
|
||||
assert strategy.rsi_period == 14
|
||||
assert strategy.lookback == 20
|
||||
assert strategy.min_swings == 2
|
||||
assert strategy.min_confidence == 0.5
|
||||
|
||||
def test_on_tick_insufficient_data(self, strategy):
|
||||
"""Test that strategy returns None with insufficient data."""
|
||||
signal = strategy.on_tick(
|
||||
symbol="BTC/USD",
|
||||
price=Decimal("50000"),
|
||||
timeframe="1h",
|
||||
data={'volume': 1000}
|
||||
)
|
||||
assert signal is None
|
||||
|
||||
def test_indicator_type_selection(self, strategy):
|
||||
"""Test indicator type selection."""
|
||||
assert strategy.indicator_type == 'rsi'
|
||||
|
||||
# Test MACD indicator type
|
||||
macd_strategy = DivergenceStrategy(
|
||||
name="test_divergence_macd",
|
||||
parameters={'indicator_type': 'macd'}
|
||||
)
|
||||
assert macd_strategy.indicator_type == 'macd'
|
||||
|
||||
def test_strategy_metadata(self, strategy):
|
||||
"""Test strategy metadata."""
|
||||
assert strategy.name == "test_divergence"
|
||||
assert strategy.enabled is False
|
||||
|
||||
69
tests/unit/strategies/test_grid_strategy.py
Normal file
69
tests/unit/strategies/test_grid_strategy.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for Grid strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from src.strategies.grid.grid_strategy import GridStrategy
|
||||
|
||||
|
||||
def test_grid_strategy_initialization():
|
||||
"""Test Grid strategy initializes correctly."""
|
||||
strategy = GridStrategy("Test Grid", {
|
||||
"grid_spacing": 1,
|
||||
"num_levels": 10,
|
||||
"profit_target": 2
|
||||
})
|
||||
assert strategy.name == "Test Grid"
|
||||
assert strategy.grid_spacing == Decimal("0.01")
|
||||
assert strategy.num_levels == 10
|
||||
|
||||
|
||||
def test_grid_levels_calculation():
|
||||
"""Test grid levels are calculated correctly."""
|
||||
strategy = GridStrategy("Test Grid", {
|
||||
"grid_spacing": 1,
|
||||
"num_levels": 5,
|
||||
"center_price": 100
|
||||
})
|
||||
|
||||
strategy._update_grid_levels(Decimal("100"))
|
||||
assert len(strategy.buy_levels) == 5
|
||||
assert len(strategy.sell_levels) == 5
|
||||
|
||||
# Buy levels should be below center
|
||||
assert all(level < Decimal("100") for level in strategy.buy_levels)
|
||||
# Sell levels should be above center
|
||||
assert all(level > Decimal("100") for level in strategy.sell_levels)
|
||||
|
||||
|
||||
def test_grid_buy_signal():
|
||||
"""Test grid generates buy signal at lower level."""
|
||||
strategy = GridStrategy("Test Grid", {
|
||||
"grid_spacing": 1,
|
||||
"num_levels": 5,
|
||||
"center_price": 100,
|
||||
"position_size": Decimal("0.1")
|
||||
})
|
||||
|
||||
# Price at buy level
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("99"), "1h", {})
|
||||
assert signal is not None
|
||||
assert signal.signal_type.value == "buy"
|
||||
|
||||
|
||||
def test_grid_profit_taking():
|
||||
"""Test grid takes profit at target."""
|
||||
strategy = GridStrategy("Test Grid", {
|
||||
"grid_spacing": 1,
|
||||
"num_levels": 5,
|
||||
"profit_target": 2
|
||||
})
|
||||
|
||||
# Simulate position
|
||||
entry_price = Decimal("100")
|
||||
strategy.positions[entry_price] = Decimal("0.1")
|
||||
|
||||
# Price with profit
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("102"), "1h", {})
|
||||
assert signal is not None
|
||||
assert signal.signal_type.value == "sell"
|
||||
assert entry_price not in strategy.positions # Position removed
|
||||
45
tests/unit/strategies/test_macd_strategy.py
Normal file
45
tests/unit/strategies/test_macd_strategy.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Tests for MACD strategy."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from src.strategies.technical.macd_strategy import MACDStrategy
|
||||
|
||||
|
||||
class TestMACDStrategy:
|
||||
"""Tests for MACDStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create MACD strategy instance."""
|
||||
return MACDStrategy(
|
||||
strategy_id=1,
|
||||
name="test_macd",
|
||||
symbol="BTC/USD",
|
||||
timeframe="1h",
|
||||
parameters={
|
||||
"fast_period": 12,
|
||||
"slow_period": 26,
|
||||
"signal_period": 9
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_macd_strategy_initialization(self, strategy):
|
||||
"""Test MACD strategy initialization."""
|
||||
assert strategy.fast_period == 12
|
||||
assert strategy.slow_period == 26
|
||||
assert strategy.signal_period == 9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_signal(self, strategy):
|
||||
"""Test signal generation."""
|
||||
# Create minimal data
|
||||
data = pd.DataFrame({
|
||||
'close': [100 + i * 0.1 for i in range(50)]
|
||||
})
|
||||
strategy.current_data = data
|
||||
|
||||
signal = await strategy.generate_signal()
|
||||
assert "signal" in signal
|
||||
assert signal["signal"] in ["buy", "sell", "hold"]
|
||||
|
||||
72
tests/unit/strategies/test_momentum_strategy.py
Normal file
72
tests/unit/strategies/test_momentum_strategy.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Tests for Momentum strategy."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
import pandas as pd
|
||||
from src.strategies.momentum.momentum_strategy import MomentumStrategy
|
||||
|
||||
|
||||
def test_momentum_strategy_initialization():
|
||||
"""Test Momentum strategy initializes correctly."""
|
||||
strategy = MomentumStrategy("Test Momentum", {
|
||||
"lookback_period": 20,
|
||||
"momentum_threshold": 0.05
|
||||
})
|
||||
assert strategy.name == "Test Momentum"
|
||||
assert strategy.lookback_period == 20
|
||||
assert strategy.momentum_threshold == Decimal("0.05")
|
||||
|
||||
|
||||
def test_momentum_calculation():
|
||||
"""Test momentum calculation."""
|
||||
strategy = MomentumStrategy("Test", {"lookback_period": 5})
|
||||
|
||||
# Create price history
|
||||
prices = pd.Series([100, 101, 102, 103, 104, 105])
|
||||
momentum = strategy._calculate_momentum(prices)
|
||||
|
||||
# Should be positive (price increased)
|
||||
assert momentum > 0
|
||||
assert momentum == 0.05 # (105 - 100) / 100
|
||||
|
||||
|
||||
def test_momentum_entry_signal():
|
||||
"""Test momentum generates entry signal."""
|
||||
strategy = MomentumStrategy("Test", {
|
||||
"lookback_period": 5,
|
||||
"momentum_threshold": 0.05,
|
||||
"volume_threshold": 1.0
|
||||
})
|
||||
|
||||
# Build price history with momentum
|
||||
for i in range(10):
|
||||
price = 100 + i * 2 # Strong upward momentum
|
||||
volume = 1000 * (1.5 if i >= 5 else 1.0) # Volume increase
|
||||
strategy.on_tick("BTC/USD", Decimal(str(price)), "1h", {"volume": volume})
|
||||
|
||||
# Should generate buy signal
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("120"), "1h", {"volume": 2000})
|
||||
assert signal is not None
|
||||
assert signal.signal_type.value == "buy"
|
||||
assert strategy._in_position == True
|
||||
|
||||
|
||||
def test_momentum_exit_signal():
|
||||
"""Test momentum generates exit signal on reversal."""
|
||||
strategy = MomentumStrategy("Test", {
|
||||
"lookback_period": 5,
|
||||
"exit_threshold": -0.02
|
||||
})
|
||||
|
||||
strategy._in_position = True
|
||||
strategy._entry_price = Decimal("100")
|
||||
|
||||
# Build history with reversal
|
||||
for i in range(10):
|
||||
price = 100 - i # Downward momentum
|
||||
strategy.on_tick("BTC/USD", Decimal(str(price)), "1h", {"volume": 1000})
|
||||
|
||||
signal = strategy.on_tick("BTC/USD", Decimal("90"), "1h", {"volume": 1000})
|
||||
assert signal is not None
|
||||
assert signal.signal_type.value == "sell"
|
||||
assert strategy._in_position == False
|
||||
45
tests/unit/strategies/test_moving_avg_strategy.py
Normal file
45
tests/unit/strategies/test_moving_avg_strategy.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Tests for Moving Average strategy."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from src.strategies.technical.moving_avg_strategy import MovingAverageStrategy
|
||||
|
||||
|
||||
class TestMovingAverageStrategy:
|
||||
"""Tests for MovingAverageStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create Moving Average strategy instance."""
|
||||
return MovingAverageStrategy(
|
||||
strategy_id=1,
|
||||
name="test_ma",
|
||||
symbol="BTC/USD",
|
||||
timeframe="1h",
|
||||
parameters={
|
||||
"short_period": 10,
|
||||
"long_period": 30,
|
||||
"ma_type": "SMA"
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ma_strategy_initialization(self, strategy):
|
||||
"""Test Moving Average strategy initialization."""
|
||||
assert strategy.short_period == 10
|
||||
assert strategy.long_period == 30
|
||||
assert strategy.ma_type == "SMA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_signal(self, strategy):
|
||||
"""Test signal generation."""
|
||||
# Create data with trend
|
||||
data = pd.DataFrame({
|
||||
'close': [100 + i * 0.5 for i in range(50)]
|
||||
})
|
||||
strategy.current_data = data
|
||||
|
||||
signal = await strategy.generate_signal()
|
||||
assert "signal" in signal
|
||||
assert signal["signal"] in ["buy", "sell", "hold"]
|
||||
|
||||
89
tests/unit/strategies/test_pairs_trading.py
Normal file
89
tests/unit/strategies/test_pairs_trading.py
Normal file
@@ -0,0 +1,89 @@
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from decimal import Decimal
|
||||
from src.strategies.technical.pairs_trading import PairsTradingStrategy
|
||||
from src.strategies.base import SignalType
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pricing_service():
|
||||
service = MagicMock()
|
||||
service.get_ohlcv = MagicMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(mock_pricing_service):
|
||||
with patch('src.strategies.technical.pairs_trading.get_pricing_service', return_value=mock_pricing_service):
|
||||
params = {
|
||||
'second_symbol': 'AVAX/USD',
|
||||
'lookback_period': 5,
|
||||
'z_score_threshold': 1.5,
|
||||
'symbol': 'SOL/USD'
|
||||
}
|
||||
strat = PairsTradingStrategy("test_pairs", params)
|
||||
strat.enabled = True
|
||||
return strat
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pairs_trading_short_spread_signal(strategy, mock_pricing_service):
|
||||
# Setup Data
|
||||
# Scenario: SOL (A) pumps relative to AVAX (B) -> Spread widens -> Z-Score High -> Sell A / Buy B
|
||||
|
||||
# Prices for A (SOL): 100, 100, 100, 100, 120 (Pump)
|
||||
ohlcv_a = [
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 120, 120, 120, 120, 1000],
|
||||
]
|
||||
|
||||
# Prices for B (AVAX): 25, 25, 25, 25, 25 (Flat)
|
||||
ohlcv_b = [
|
||||
[0, 25, 25, 25, 25, 1000],
|
||||
[0, 25, 25, 25, 25, 1000],
|
||||
[0, 25, 25, 25, 25, 1000],
|
||||
[0, 25, 25, 25, 25, 1000],
|
||||
[0, 25, 25, 25, 25, 1000],
|
||||
]
|
||||
|
||||
# Spread: 4, 4, 4, 4, 4.8
|
||||
# Mean: 4.16, StdDev: approx small but let's see.
|
||||
# Actually StdDev will be non-zero because of the last value.
|
||||
|
||||
mock_pricing_service.get_ohlcv.side_effect = [ohlcv_a, ohlcv_b]
|
||||
|
||||
# Execute
|
||||
signal = await strategy.on_tick("SOL/USD", Decimal(120), "1h", {})
|
||||
|
||||
# Verify
|
||||
assert signal is not None
|
||||
assert signal.signal_type == SignalType.SELL # Sell Primary (SOL)
|
||||
assert signal.metadata['secondary_action'] == 'buy' # Buy Secondary (AVAX)
|
||||
assert signal.metadata['z_score'] > 1.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pairs_trading_long_spread_signal(strategy, mock_pricing_service):
|
||||
# Scenario: SOL (A) dumps -> Spread drops -> Z-Score Low -> Buy A / Sell B
|
||||
|
||||
ohlcv_a = [
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 100, 100, 100, 100, 1000],
|
||||
[0, 80, 80, 80, 80, 1000], # Dump
|
||||
]
|
||||
ohlcv_b = [
|
||||
[0, 25, 25, 25, 25, 1000] for _ in range(5)
|
||||
]
|
||||
|
||||
mock_pricing_service.get_ohlcv.side_effect = [ohlcv_a, ohlcv_b]
|
||||
|
||||
signal = await strategy.on_tick("SOL/USD", Decimal(80), "1h", {})
|
||||
|
||||
assert signal is not None
|
||||
assert signal.signal_type == SignalType.BUY # Buy Primary
|
||||
assert signal.metadata['secondary_action'] == 'sell' # Sell Secondary
|
||||
assert signal.metadata['z_score'] < -1.5
|
||||
67
tests/unit/strategies/test_rsi_strategy.py
Normal file
67
tests/unit/strategies/test_rsi_strategy.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Tests for RSI strategy."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from src.strategies.technical.rsi_strategy import RSIStrategy
|
||||
|
||||
|
||||
class TestRSIStrategy:
|
||||
"""Tests for RSIStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create RSI strategy instance."""
|
||||
return RSIStrategy(
|
||||
strategy_id=1,
|
||||
name="test_rsi",
|
||||
symbol="BTC/USD",
|
||||
timeframe="1h",
|
||||
parameters={
|
||||
"rsi_period": 14,
|
||||
"overbought": 70,
|
||||
"oversold": 30
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample price data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
|
||||
# Create data with clear trend for RSI calculation
|
||||
prices = [100 - i * 0.5 for i in range(50)] # Downward trend
|
||||
return pd.DataFrame({
|
||||
'timestamp': dates,
|
||||
'open': prices,
|
||||
'high': [p + 1 for p in prices],
|
||||
'low': [p - 1 for p in prices],
|
||||
'close': prices,
|
||||
'volume': [1000.0] * 50
|
||||
})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rsi_strategy_initialization(self, strategy):
|
||||
"""Test RSI strategy initialization."""
|
||||
assert strategy.rsi_period == 14
|
||||
assert strategy.overbought == 70
|
||||
assert strategy.oversold == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_data(self, strategy, sample_data):
|
||||
"""Test on_data method."""
|
||||
await strategy.on_data(sample_data)
|
||||
assert len(strategy.current_data) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_signal_oversold(self, strategy, sample_data):
|
||||
"""Test signal generation for oversold condition."""
|
||||
await strategy.on_data(sample_data)
|
||||
# Calculate RSI - should be low for downward trend
|
||||
from src.data.indicators import get_indicators
|
||||
indicators = get_indicators()
|
||||
rsi = indicators.rsi(strategy.current_data['close'], period=14)
|
||||
|
||||
# If RSI is low, should generate buy signal
|
||||
signal = await strategy.generate_signal()
|
||||
assert "signal" in signal
|
||||
assert signal["signal"] in ["buy", "sell", "hold"]
|
||||
|
||||
88
tests/unit/strategies/test_trend_filter.py
Normal file
88
tests/unit/strategies/test_trend_filter.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Tests for trend filter functionality."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
|
||||
from src.strategies.technical.rsi_strategy import RSIStrategy
|
||||
|
||||
|
||||
class TestTrendFilter:
|
||||
"""Tests for trend filter in BaseStrategy."""
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self):
|
||||
"""Create strategy instance with trend filter enabled."""
|
||||
strategy = RSIStrategy(
|
||||
name="test_rsi_with_filter",
|
||||
parameters={'use_trend_filter': True}
|
||||
)
|
||||
return strategy
|
||||
|
||||
@pytest.fixture
|
||||
def ohlcv_data(self):
|
||||
"""Create sample OHLCV data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
|
||||
base_price = 50000
|
||||
return pd.DataFrame({
|
||||
'high': [base_price + 100 + i * 10 for i in range(50)],
|
||||
'low': [base_price - 100 + i * 10 for i in range(50)],
|
||||
'close': [base_price + i * 10 for i in range(50)],
|
||||
'open': [base_price - 50 + i * 10 for i in range(50)],
|
||||
'volume': [1000.0] * 50
|
||||
}, index=dates)
|
||||
|
||||
def test_trend_filter_method_exists(self, strategy):
|
||||
"""Test that apply_trend_filter method exists."""
|
||||
assert hasattr(strategy, 'apply_trend_filter')
|
||||
assert callable(getattr(strategy, 'apply_trend_filter'))
|
||||
|
||||
def test_trend_filter_insufficient_data(self, strategy):
|
||||
"""Test trend filter with insufficient data."""
|
||||
signal = StrategySignal(
|
||||
signal_type=SignalType.BUY,
|
||||
symbol="BTC/USD",
|
||||
strength=0.8,
|
||||
price=Decimal("50000")
|
||||
)
|
||||
|
||||
insufficient_data = pd.DataFrame({
|
||||
'high': [51000],
|
||||
'low': [49000],
|
||||
'close': [50000]
|
||||
})
|
||||
|
||||
# Should allow signal when insufficient data
|
||||
result = strategy.apply_trend_filter(signal, insufficient_data)
|
||||
assert result is not None
|
||||
|
||||
def test_trend_filter_none_data(self, strategy):
|
||||
"""Test trend filter with None data."""
|
||||
signal = StrategySignal(
|
||||
signal_type=SignalType.BUY,
|
||||
symbol="BTC/USD",
|
||||
strength=0.8,
|
||||
price=Decimal("50000")
|
||||
)
|
||||
|
||||
# Should allow signal when no data provided
|
||||
result = strategy.apply_trend_filter(signal, None)
|
||||
assert result is not None
|
||||
|
||||
def test_trend_filter_when_disabled(self, strategy):
|
||||
"""Test that trend filter doesn't filter when disabled."""
|
||||
strategy_no_filter = RSIStrategy(
|
||||
name="test_rsi_no_filter",
|
||||
parameters={'use_trend_filter': False}
|
||||
)
|
||||
|
||||
signal = StrategySignal(
|
||||
signal_type=SignalType.BUY,
|
||||
symbol="BTC/USD",
|
||||
strength=0.8,
|
||||
price=Decimal("50000")
|
||||
)
|
||||
|
||||
result = strategy_no_filter.apply_trend_filter(signal, None)
|
||||
assert result == signal
|
||||
|
||||
284
tests/unit/test_autopilot_training.py
Normal file
284
tests/unit/test_autopilot_training.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Tests for autopilot model training functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_train_task():
|
||||
"""Mock Celery train_model_task."""
|
||||
with patch('backend.api.autopilot.train_model_task') as mock:
|
||||
mock_result = Mock()
|
||||
mock_result.id = "test-task-id-12345"
|
||||
mock.delay.return_value = mock_result
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_result():
|
||||
"""Mock Celery AsyncResult."""
|
||||
with patch('backend.api.autopilot.AsyncResult') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_strategy_selector():
|
||||
"""Mock StrategySelector."""
|
||||
selector = Mock()
|
||||
selector.model = Mock()
|
||||
selector.model.is_trained = True
|
||||
selector.model.model_type = "classifier"
|
||||
selector.model.feature_names = ["rsi", "macd", "sma_20"]
|
||||
selector.model.training_metadata = {
|
||||
"trained_at": "2024-01-01T00:00:00",
|
||||
"metrics": {"test_accuracy": 0.85}
|
||||
}
|
||||
selector.get_model_info.return_value = {
|
||||
"is_trained": True,
|
||||
"model_type": "classifier",
|
||||
"available_strategies": ["rsi", "macd", "momentum"],
|
||||
"feature_count": 54,
|
||||
"training_metadata": {
|
||||
"trained_at": "2024-01-01T00:00:00",
|
||||
"metrics": {"test_accuracy": 0.85},
|
||||
"training_symbols": ["BTC/USD", "ETH/USD"]
|
||||
}
|
||||
}
|
||||
return selector
|
||||
|
||||
|
||||
class TestBootstrapConfig:
|
||||
"""Tests for bootstrap configuration endpoints."""
|
||||
|
||||
def test_get_bootstrap_config(self, client):
|
||||
"""Test getting bootstrap configuration."""
|
||||
response = client.get("/api/autopilot/bootstrap-config")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify required fields exist
|
||||
assert "days" in data
|
||||
assert "timeframe" in data
|
||||
assert "min_samples_per_strategy" in data
|
||||
assert "symbols" in data
|
||||
|
||||
# Verify types
|
||||
assert isinstance(data["days"], int)
|
||||
assert isinstance(data["timeframe"], str)
|
||||
assert isinstance(data["min_samples_per_strategy"], int)
|
||||
assert isinstance(data["symbols"], list)
|
||||
|
||||
def test_update_bootstrap_config(self, client):
|
||||
"""Test updating bootstrap configuration."""
|
||||
new_config = {
|
||||
"days": 365,
|
||||
"timeframe": "4h",
|
||||
"min_samples_per_strategy": 50,
|
||||
"symbols": ["BTC/USD", "ETH/USD", "SOL/USD"]
|
||||
}
|
||||
|
||||
response = client.put(
|
||||
"/api/autopilot/bootstrap-config",
|
||||
json=new_config
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "success"
|
||||
|
||||
# Verify the config was updated
|
||||
response = client.get("/api/autopilot/bootstrap-config")
|
||||
data = response.json()
|
||||
assert data["days"] == 365
|
||||
assert data["timeframe"] == "4h"
|
||||
assert data["min_samples_per_strategy"] == 50
|
||||
assert "SOL/USD" in data["symbols"]
|
||||
|
||||
|
||||
class TestModelTraining:
|
||||
"""Tests for model training endpoints."""
|
||||
|
||||
def test_trigger_retrain(self, client, mock_train_task):
|
||||
"""Test triggering model retraining."""
|
||||
response = client.post("/api/autopilot/intelligent/retrain?force=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "queued"
|
||||
assert "task_id" in data
|
||||
assert data["task_id"] == "test-task-id-12345"
|
||||
|
||||
# Verify task was called with correct parameters
|
||||
mock_train_task.delay.assert_called_once()
|
||||
call_kwargs = mock_train_task.delay.call_args.kwargs
|
||||
assert call_kwargs["force_retrain"] is True
|
||||
assert call_kwargs["bootstrap"] is True
|
||||
assert "symbols" in call_kwargs
|
||||
assert "days" in call_kwargs
|
||||
assert "timeframe" in call_kwargs
|
||||
assert "min_samples_per_strategy" in call_kwargs
|
||||
|
||||
def test_get_task_status_pending(self, client, mock_async_result):
|
||||
"""Test getting status of a pending task."""
|
||||
mock_result = Mock()
|
||||
mock_result.status = "PENDING"
|
||||
mock_result.result = None
|
||||
mock_async_result.return_value = mock_result
|
||||
|
||||
response = client.get("/api/autopilot/tasks/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "PENDING"
|
||||
|
||||
def test_get_task_status_progress(self, client, mock_async_result):
|
||||
"""Test getting status of a task in progress."""
|
||||
mock_result = Mock()
|
||||
mock_result.status = "PROGRESS"
|
||||
mock_result.result = None
|
||||
mock_result.info = {
|
||||
"step": "fetching",
|
||||
"progress": 50,
|
||||
"message": "Fetching BTC/USD data..."
|
||||
}
|
||||
mock_async_result.return_value = mock_result
|
||||
|
||||
response = client.get("/api/autopilot/tasks/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "PROGRESS"
|
||||
assert data["meta"]["progress"] == 50
|
||||
|
||||
def test_get_task_status_success(self, client, mock_async_result):
|
||||
"""Test getting status of a successful task."""
|
||||
mock_result = Mock()
|
||||
mock_result.status = "SUCCESS"
|
||||
mock_result.result = {
|
||||
"train_accuracy": 0.85,
|
||||
"test_accuracy": 0.78,
|
||||
"n_samples": 1000,
|
||||
"best_model": "xgboost"
|
||||
}
|
||||
mock_async_result.return_value = mock_result
|
||||
|
||||
response = client.get("/api/autopilot/tasks/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "SUCCESS"
|
||||
assert data["result"]["best_model"] == "xgboost"
|
||||
|
||||
def test_get_task_status_failure(self, client, mock_async_result):
|
||||
"""Test getting status of a failed task."""
|
||||
mock_result = Mock()
|
||||
mock_result.status = "FAILURE"
|
||||
mock_result.result = Exception("Training failed: insufficient data")
|
||||
mock_result.traceback = "Traceback (most recent call last)..."
|
||||
mock_async_result.return_value = mock_result
|
||||
|
||||
response = client.get("/api/autopilot/tasks/test-task-id")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "FAILURE"
|
||||
assert "error" in data["result"]
|
||||
|
||||
|
||||
class TestModelInfo:
|
||||
"""Tests for model info endpoint."""
|
||||
|
||||
@patch('backend.api.autopilot.get_strategy_selector')
|
||||
def test_get_model_info_trained(self, mock_get_selector, client, mock_strategy_selector):
|
||||
"""Test getting info for a trained model."""
|
||||
mock_get_selector.return_value = mock_strategy_selector
|
||||
|
||||
response = client.get("/api/autopilot/intelligent/model-info")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_trained"] is True
|
||||
assert "available_strategies" in data
|
||||
assert "feature_count" in data
|
||||
|
||||
@patch('backend.api.autopilot.get_strategy_selector')
|
||||
def test_get_model_info_untrained(self, mock_get_selector, client):
|
||||
"""Test getting info for an untrained model."""
|
||||
mock_selector = Mock()
|
||||
mock_selector.get_model_info.return_value = {
|
||||
"is_trained": False,
|
||||
"model_type": "classifier",
|
||||
"available_strategies": ["rsi", "macd"],
|
||||
"feature_count": 0
|
||||
}
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
response = client.get("/api/autopilot/intelligent/model-info")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_trained"] is False
|
||||
assert data["feature_count"] == 0
|
||||
|
||||
|
||||
class TestModelReset:
|
||||
"""Tests for model reset endpoint."""
|
||||
|
||||
@patch('backend.api.autopilot.get_strategy_selector')
|
||||
def test_reset_model(self, mock_get_selector, client):
|
||||
"""Test resetting the model."""
|
||||
mock_selector = Mock()
|
||||
mock_selector.reset_model = AsyncMock(return_value={"status": "success"})
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
response = client.post("/api/autopilot/intelligent/reset")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTrainingWorkflow:
|
||||
"""Integration tests for the complete training workflow."""
|
||||
|
||||
@patch('backend.api.autopilot.train_model_task')
|
||||
def test_config_and_retrain_workflow(self, mock_train_task, client):
|
||||
"""Test configure -> train workflow passes config correctly."""
|
||||
# Setup mock
|
||||
mock_task_result = Mock()
|
||||
mock_task_result.id = "test-task-123"
|
||||
mock_train_task.delay.return_value = mock_task_result
|
||||
|
||||
# 1. Configure bootstrap settings with specific values
|
||||
config = {
|
||||
"days": 180,
|
||||
"timeframe": "4h",
|
||||
"min_samples_per_strategy": 25,
|
||||
"symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
|
||||
}
|
||||
response = client.put("/api/autopilot/bootstrap-config", json=config)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 2. Trigger retraining
|
||||
response = client.post("/api/autopilot/intelligent/retrain?force=true")
|
||||
assert response.status_code == 200
|
||||
|
||||
# 3. Verify the task was called with the correct config
|
||||
mock_train_task.delay.assert_called_once()
|
||||
call_kwargs = mock_train_task.delay.call_args.kwargs
|
||||
|
||||
# All config should be passed to the task
|
||||
assert call_kwargs["days"] == 180
|
||||
assert call_kwargs["timeframe"] == "4h"
|
||||
assert call_kwargs["min_samples_per_strategy"] == 25
|
||||
assert call_kwargs["symbols"] == ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
|
||||
assert call_kwargs["force_retrain"] is True
|
||||
assert call_kwargs["bootstrap"] is True
|
||||
2
tests/unit/trading/__init__.py
Normal file
2
tests/unit/trading/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for trading engine."""
|
||||
|
||||
88
tests/unit/trading/test_engine.py
Normal file
88
tests/unit/trading/test_engine.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Tests for trading engine."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from decimal import Decimal
|
||||
from src.trading.engine import get_trading_engine, TradingEngine
|
||||
from src.core.database import OrderSide, OrderType, OrderStatus
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTradingEngine:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_data(self, db_session):
|
||||
"""Setup test data."""
|
||||
from src.core.database import Exchange
|
||||
from sqlalchemy import select
|
||||
|
||||
# Check if exchange exists
|
||||
result = await db_session.execute(select(Exchange).where(Exchange.id == 1))
|
||||
if not result.scalar_one_or_none():
|
||||
try:
|
||||
# Try enabled (new schema)
|
||||
exchange = Exchange(id=1, name="coinbase", enabled=True)
|
||||
db_session.add(exchange)
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
# Fallback if I was wrong about schema, but I checked it.
|
||||
await db_session.rollback()
|
||||
exchange = Exchange(id=1, name="coinbase")
|
||||
db_session.add(exchange)
|
||||
print("Exchange created in setup_data")
|
||||
|
||||
async def test_execute_order_paper_trading(self, mock_exchange_adapter):
|
||||
"""Test executing paper trading order."""
|
||||
engine = TradingEngine()
|
||||
|
||||
# Mock dependencies
|
||||
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
|
||||
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
|
||||
engine.paper_trading.get_balance = Mock(return_value=Decimal("10000"))
|
||||
engine.paper_trading.execute_order = Mock(return_value=True)
|
||||
# Mock logger
|
||||
engine.logger = Mock()
|
||||
|
||||
order = await engine.execute_order(
|
||||
exchange_id=1,
|
||||
strategy_id=None,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=Decimal("0.1"),
|
||||
price=Decimal("50000"),
|
||||
paper_trading=True
|
||||
)
|
||||
|
||||
# Check if error occurred
|
||||
if order is None:
|
||||
engine.logger.error.assert_called()
|
||||
call_args = engine.logger.error.call_args
|
||||
print(f"\nCaught exception in engine: {call_args}")
|
||||
|
||||
assert order is not None
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.quantity == Decimal("0.1")
|
||||
# Status might be PENDING/OPEN depending on implementation
|
||||
assert order.status in [OrderStatus.PENDING, OrderStatus.OPEN]
|
||||
|
||||
async def test_execute_order_live_trading(self, mock_exchange_adapter):
|
||||
"""Test executing live trading order."""
|
||||
engine = TradingEngine()
|
||||
|
||||
# Mock dependencies
|
||||
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
|
||||
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
|
||||
|
||||
order = await engine.execute_order(
|
||||
exchange_id=1,
|
||||
strategy_id=None,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1"),
|
||||
paper_trading=False
|
||||
)
|
||||
|
||||
assert order is not None
|
||||
assert order.paper_trading is False
|
||||
mock_exchange_adapter.place_order.assert_called_once()
|
||||
161
tests/unit/trading/test_fee_calculator.py
Normal file
161
tests/unit/trading/test_fee_calculator.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for fee calculator functionality."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
from src.trading.fee_calculator import FeeCalculator, get_fee_calculator
|
||||
from src.core.database import OrderType
|
||||
|
||||
|
||||
class TestFeeCalculator:
|
||||
"""Tests for FeeCalculator class."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_get_fee_calculator_singleton(self):
|
||||
"""Test that get_fee_calculator returns singleton."""
|
||||
calc1 = get_fee_calculator()
|
||||
calc2 = get_fee_calculator()
|
||||
assert calc1 is calc2
|
||||
|
||||
def test_calculate_fee_basic(self, calculator):
|
||||
"""Test basic fee calculation."""
|
||||
fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee > 0
|
||||
assert isinstance(fee, Decimal)
|
||||
|
||||
def test_calculate_fee_zero_quantity(self, calculator):
|
||||
"""Test fee calculation with zero quantity."""
|
||||
fee = calculator.calculate_fee(
|
||||
quantity=Decimal('0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee == Decimal('0')
|
||||
|
||||
def test_calculate_fee_maker_vs_taker(self, calculator):
|
||||
"""Test maker fees are typically lower than taker fees."""
|
||||
maker_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.LIMIT,
|
||||
is_maker=True
|
||||
)
|
||||
taker_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET,
|
||||
is_maker=False
|
||||
)
|
||||
# Maker fees should be <= taker fees
|
||||
assert maker_fee <= taker_fee
|
||||
|
||||
|
||||
class TestFeeCalculatorPaperTrading:
|
||||
"""Tests for paper trading fee calculation."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_get_fee_structure_by_exchange_name_coinbase(self, calculator):
|
||||
"""Test getting Coinbase fee structure."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
|
||||
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006}
|
||||
}.get(key, default)
|
||||
|
||||
fees = calculator.get_fee_structure_by_exchange_name('coinbase')
|
||||
assert fees['maker'] == 0.004
|
||||
assert fees['taker'] == 0.006
|
||||
|
||||
def test_get_fee_structure_by_exchange_name_unknown(self, calculator):
|
||||
"""Test getting fee structure for unknown exchange returns defaults."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
|
||||
'trading.exchanges.unknown.fees': None
|
||||
}.get(key, default)
|
||||
|
||||
fees = calculator.get_fee_structure_by_exchange_name('unknown')
|
||||
assert fees['maker'] == 0.001
|
||||
assert fees['taker'] == 0.001
|
||||
|
||||
def test_calculate_fee_for_paper_trading(self, calculator):
|
||||
"""Test paper trading fee calculation."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'paper_trading.fee_exchange': 'coinbase',
|
||||
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006},
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001}
|
||||
}.get(key, default)
|
||||
|
||||
fee = calculator.calculate_fee_for_paper_trading(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET,
|
||||
is_maker=False
|
||||
)
|
||||
# Taker fee at 0.6% of $1000 = $6
|
||||
expected = Decimal('1000.0') * Decimal('0.006')
|
||||
assert fee == expected
|
||||
|
||||
def test_calculate_fee_for_paper_trading_zero(self, calculator):
|
||||
"""Test paper trading fee with zero values."""
|
||||
fee = calculator.calculate_fee_for_paper_trading(
|
||||
quantity=Decimal('0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee == Decimal('0')
|
||||
|
||||
|
||||
class TestRoundTripFees:
|
||||
"""Tests for round-trip fee calculations."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_estimate_round_trip_fee(self, calculator):
|
||||
"""Test round-trip fee estimation."""
|
||||
fee = calculator.estimate_round_trip_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0')
|
||||
)
|
||||
# Round-trip should include buy and sell fees
|
||||
assert fee > 0
|
||||
|
||||
single_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
# Round-trip should be approximately 2x single fee
|
||||
assert fee >= single_fee
|
||||
|
||||
def test_get_minimum_profit_threshold(self, calculator):
|
||||
"""Test minimum profit threshold calculation."""
|
||||
threshold = calculator.get_minimum_profit_threshold(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
multiplier=2.0
|
||||
)
|
||||
assert threshold > 0
|
||||
|
||||
round_trip_fee = calculator.estimate_round_trip_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0')
|
||||
)
|
||||
# Threshold should be 2x round-trip fee
|
||||
assert threshold == round_trip_fee * Decimal('2.0')
|
||||
70
tests/unit/trading/test_order_manager.py
Normal file
70
tests/unit/trading/test_order_manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tests for order manager."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from src.trading.order_manager import get_order_manager, OrderManager
|
||||
|
||||
|
||||
class TestOrderManager:
|
||||
"""Tests for OrderManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def order_manager(self, mock_database):
|
||||
"""Create order manager instance."""
|
||||
return get_order_manager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_order(self, order_manager, mock_database):
|
||||
"""Test order creation."""
|
||||
engine, Session = mock_database
|
||||
|
||||
order = await order_manager.create_order(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
order_type="market",
|
||||
amount=0.01,
|
||||
price=50000.0,
|
||||
is_paper_trade=True
|
||||
)
|
||||
|
||||
assert order is not None
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.side == "buy"
|
||||
assert order.status == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_order_status(self, order_manager, mock_database):
|
||||
"""Test order status update."""
|
||||
engine, Session = mock_database
|
||||
|
||||
# Create order first
|
||||
order = await order_manager.create_order(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
order_type="market",
|
||||
amount=0.01,
|
||||
is_paper_trade=True
|
||||
)
|
||||
|
||||
# Update status
|
||||
updated = await order_manager.update_order_status(
|
||||
client_order_id=order.client_order_id,
|
||||
new_status="filled",
|
||||
filled_amount=0.01,
|
||||
cost=500.0
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.status == "filled"
|
||||
|
||||
def test_get_order(self, order_manager, mock_database):
|
||||
"""Test getting order."""
|
||||
# This would require creating an order first
|
||||
# Simplified test
|
||||
order = order_manager.get_order(client_order_id="nonexistent")
|
||||
assert order is None
|
||||
|
||||
32
tests/unit/trading/test_paper_trading.py
Normal file
32
tests/unit/trading/test_paper_trading.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Tests for paper trading simulator."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock
|
||||
from src.trading.paper_trading import get_paper_trading, PaperTradingSimulator
|
||||
from src.core.database import Order, OrderSide, OrderType
|
||||
|
||||
|
||||
class TestPaperTradingSimulator:
|
||||
"""Tests for PaperTradingSimulator."""
|
||||
|
||||
@pytest.fixture
|
||||
def simulator(self):
|
||||
"""Create paper trading simulator."""
|
||||
return PaperTradingSimulator(initial_capital=Decimal('1000.0'))
|
||||
|
||||
def test_initialization(self, simulator):
|
||||
"""Test simulator initialization."""
|
||||
assert simulator.initial_capital == Decimal('1000.0')
|
||||
assert simulator.cash == Decimal('1000.0')
|
||||
|
||||
def test_get_balance(self, simulator):
|
||||
"""Test getting balance."""
|
||||
balance = simulator.get_balance()
|
||||
assert balance == Decimal('1000.0')
|
||||
|
||||
def test_get_positions(self, simulator):
|
||||
"""Test getting positions."""
|
||||
positions = simulator.get_positions()
|
||||
assert isinstance(positions, list)
|
||||
|
||||
1
tests/unit/worker/__init__.py
Normal file
1
tests/unit/worker/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
178
tests/unit/worker/test_tasks.py
Normal file
178
tests/unit/worker/test_tasks.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for Celery tasks."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
class TestAsyncToSync:
|
||||
"""Tests for async_to_sync helper."""
|
||||
|
||||
def test_runs_awaitable(self):
|
||||
"""Test that async_to_sync runs awaitable and returns result."""
|
||||
from src.worker.tasks import async_to_sync
|
||||
|
||||
async def async_func():
|
||||
return "test_result"
|
||||
|
||||
result = async_to_sync(async_func())
|
||||
assert result == "test_result"
|
||||
|
||||
def test_handles_exception(self):
|
||||
"""Test that async_to_sync propagates exceptions."""
|
||||
from src.worker.tasks import async_to_sync
|
||||
|
||||
async def async_error():
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
async_to_sync(async_error())
|
||||
|
||||
|
||||
class TestTrainModelTask:
|
||||
"""Tests for train_model_task."""
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_train_model_basic(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test basic model training task."""
|
||||
# Setup mocks
|
||||
mock_selector = Mock()
|
||||
mock_selector.bootstrap_symbols = ["BTC/USD"]
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
mock_async_to_sync.side_effect = [
|
||||
{"X": [1, 2, 3]}, # prepare_training_data result
|
||||
{"accuracy": 0.9} # train_model result
|
||||
]
|
||||
|
||||
from src.worker.tasks import train_model_task
|
||||
|
||||
# Call the task directly - Celery will bind self automatically
|
||||
# For testing, we need to access the underlying function
|
||||
result = train_model_task.run(force_retrain=True, bootstrap=False)
|
||||
|
||||
assert result == {"accuracy": 0.9}
|
||||
mock_get_selector.assert_called_once()
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_train_model_with_bootstrap(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test model training with bootstrapping."""
|
||||
mock_selector = Mock()
|
||||
mock_selector.bootstrap_symbols = ["BTC/USD", "ETH/USD"]
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
# First call returns empty data, triggering bootstrap
|
||||
mock_async_to_sync.side_effect = [
|
||||
{"X": []}, # Empty training data
|
||||
{"total_samples": 100}, # First symbol bootstrap
|
||||
{"total_samples": 50}, # Second symbol bootstrap
|
||||
{"accuracy": 0.85} # Final training
|
||||
]
|
||||
|
||||
from src.worker.tasks import train_model_task
|
||||
|
||||
result = train_model_task.run(force_retrain=False, bootstrap=True)
|
||||
|
||||
assert result == {"accuracy": 0.85}
|
||||
|
||||
|
||||
class TestBootstrapTask:
|
||||
"""Tests for bootstrap_task."""
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_bootstrap_basic(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test basic bootstrap task."""
|
||||
mock_selector = Mock()
|
||||
mock_get_selector.return_value = mock_selector
|
||||
mock_async_to_sync.return_value = {"total_samples": 200}
|
||||
|
||||
from src.worker.tasks import bootstrap_task
|
||||
|
||||
result = bootstrap_task.run(days=90, timeframe="1h")
|
||||
|
||||
assert result == {"total_samples": 200}
|
||||
|
||||
|
||||
class TestGenerateReportTask:
|
||||
"""Tests for generate_report_task."""
|
||||
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_generate_report_unknown_type(self, mock_async_to_sync):
|
||||
"""Test report generation with unknown type."""
|
||||
from src.worker.tasks import generate_report_task
|
||||
|
||||
result = generate_report_task.run("unknown", {})
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "Unknown report type" in result["message"]
|
||||
|
||||
|
||||
class TestOptimizeStrategyTask:
|
||||
"""Tests for optimize_strategy_task."""
|
||||
|
||||
@patch('src.optimization.genetic.GeneticOptimizer')
|
||||
def test_optimize_genetic_basic(self, mock_optimizer_class):
|
||||
"""Test basic genetic optimization."""
|
||||
from src.worker.tasks import optimize_strategy_task
|
||||
|
||||
mock_optimizer = Mock()
|
||||
mock_optimizer.optimize.return_value = {
|
||||
"best_params": {"period": 14},
|
||||
"best_score": 0.85
|
||||
}
|
||||
mock_optimizer_class.return_value = mock_optimizer
|
||||
|
||||
result = optimize_strategy_task.run(
|
||||
strategy_type="rsi",
|
||||
symbol="BTC/USD",
|
||||
param_ranges={"period": (5, 50)},
|
||||
method="genetic",
|
||||
population_size=10,
|
||||
generations=5
|
||||
)
|
||||
|
||||
assert result["best_params"] == {"period": 14}
|
||||
assert result["best_score"] == 0.85
|
||||
|
||||
def test_optimize_unknown_method(self):
|
||||
"""Test optimization with unknown method."""
|
||||
from src.worker.tasks import optimize_strategy_task
|
||||
|
||||
result = optimize_strategy_task.run(
|
||||
strategy_type="rsi",
|
||||
symbol="BTC/USD",
|
||||
param_ranges={"period": (5, 50)},
|
||||
method="unknown_method"
|
||||
)
|
||||
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestExportDataTask:
|
||||
"""Tests for export_data_task."""
|
||||
|
||||
@patch('src.reporting.csv_exporter.get_csv_exporter')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_export_orders(self, mock_async_to_sync, mock_exporter_func):
|
||||
"""Test order export."""
|
||||
mock_exporter = Mock()
|
||||
mock_exporter.export_orders.return_value = True
|
||||
mock_exporter_func.return_value = mock_exporter
|
||||
mock_async_to_sync.return_value = [] # Empty orders list
|
||||
|
||||
from src.worker.tasks import export_data_task
|
||||
|
||||
result = export_data_task.run("orders", {})
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["export_type"] == "orders"
|
||||
|
||||
def test_export_unknown_type(self):
|
||||
"""Test export with unknown type."""
|
||||
from src.worker.tasks import export_data_task
|
||||
|
||||
result = export_data_task.run("unknown", {})
|
||||
|
||||
assert result["status"] == "error"
|
||||
Reference in New Issue
Block a user