Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

2
tests/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Test suite for Crypto Trader."""

86
tests/conftest.py Normal file
View File

@@ -0,0 +1,86 @@
"""Pytest configuration and fixtures."""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, PropertyMock
from decimal import Decimal
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from src.core.database import Base, Database, get_database
from sqlalchemy.pool import StaticPool
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def db_engine():
"""Create async database engine."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=False
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
"""Create async database session."""
async_session = async_sessionmaker(bind=db_engine, expire_on_commit=False)
async with async_session() as session:
yield session
@pytest.fixture(autouse=True)
def override_get_database(db_engine, monkeypatch):
"""Override get_database to use test engine."""
test_db = Database()
# We mock the internal attributes to return our test engine/session
test_db.engine = db_engine
test_db.SessionLocal = async_sessionmaker(bind=db_engine, class_=AsyncSession, expire_on_commit=False)
# Patch the global get_database
monkeypatch.setattr("src.core.database._db_instance", test_db)
return test_db
@pytest.fixture
def mock_exchange_adapter():
"""Mock exchange adapter."""
from src.exchanges.base import BaseExchangeAdapter
adapter = AsyncMock(spec=BaseExchangeAdapter)
adapter.get_ticker.return_value = {'last': Decimal("50000")}
adapter.place_order.return_value = {'id': 'test_order_123', 'status': 'open'}
adapter.get_balance.return_value = {'USD': Decimal("10000")}
# Helper methods should be sync mocks
# Note: If extract_fee... is not part of BaseExchangeAdapter spec, we have to attach it manually
# But checking base.py, it likely IS or isn't.
# Safe to attach it manually even with spec if we traverse __dict__ or simply assign.
# However, standard mock might block unknown attribs.
# Actually BaseExchangeAdapter is abstract.
# Let's inspect BaseExchangeAdapter structure if needed.
# For now, let's assume usage of spec is the right direction.
# But if extract_fee... is NOT in BaseExchangeAdapter, we might need to mock a Concrete class like Coinbase
pass
# Better approach: Just Delete get_fee_structure from the mock to ensure AttributeError
adapter = AsyncMock()
del adapter.get_fee_structure
# Wait, AsyncMock creates attrs on access. del might not work if not existing.
# We can se side_effect to raise AttributeError
adapter.get_ticker.return_value = {'last': Decimal("50000")}
adapter.place_order.return_value = {'id': 'test_order_123', 'status': 'open'}
adapter.get_balance.return_value = {'USD': Decimal("10000")}
adapter.extract_fee_from_order_response = Mock(return_value=Decimal("1.0"))
adapter.name = "coinbase" # FeeCalculator accesses .name
type(adapter).get_fee_structure = PropertyMock(side_effect=AttributeError)
# Accessing methods on AsyncMock...
return adapter

2
tests/e2e/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""End-to-end tests."""

View File

@@ -0,0 +1,18 @@
"""End-to-end tests for backtesting."""
import pytest
from datetime import datetime, timedelta
from src.backtesting.engine import get_backtest_engine
@pytest.mark.e2e
class TestBacktestE2E:
"""End-to-end tests for backtesting."""
@pytest.mark.asyncio
async def test_backtest_scenario(self):
"""Test complete backtesting scenario."""
engine = get_backtest_engine()
assert engine is not None
# Full E2E test would require strategy and historical data setup

View File

@@ -0,0 +1,51 @@
"""End-to-end tests for paper trading."""
import pytest
from src.trading.paper_trading import get_paper_trading
from src.trading.engine import get_trading_engine
@pytest.mark.e2e
class TestPaperTradingE2E:
"""End-to-end tests for paper trading."""
@pytest.mark.asyncio
async def test_paper_trading_scenario(self):
"""Test complete paper trading scenario."""
# Initialize
paper_trading = get_paper_trading()
engine = get_trading_engine()
await engine.initialize()
# Place buy order
result1 = await engine.execute_trade(
exchange_name="paper_trading",
strategy_id=1,
symbol="BTC/USD",
side="buy",
order_type="market",
amount=0.01,
price=50000.0,
is_paper_trade=True
)
assert result1 is not None
# Place sell order
result2 = await engine.execute_trade(
exchange_name="paper_trading",
strategy_id=1,
symbol="BTC/USD",
side="sell",
order_type="market",
amount=0.01,
price=51000.0,
is_paper_trade=True
)
assert result2 is not None
# Check balance
balance = paper_trading.get_balance()
assert balance is not None

View File

@@ -0,0 +1,340 @@
"""End-to-end tests for pricing data flow."""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from decimal import Decimal
from datetime import datetime
from src.data.pricing_service import get_pricing_service, PricingService
from src.data.providers.base_provider import BasePricingProvider
class MockProvider(BasePricingProvider):
"""Mock provider for E2E testing."""
def __init__(self, name: str = "MockProvider"):
super().__init__()
self._name = name
self._ticker_data = {
'symbol': 'BTC/USD',
'bid': Decimal('50000'),
'ask': Decimal('50001'),
'last': Decimal('50000.5'),
'high': Decimal('51000'),
'low': Decimal('49000'),
'volume': Decimal('1000000'),
'timestamp': int(datetime.now().timestamp() * 1000),
}
self._ohlcv_data = [
[int(datetime.now().timestamp() * 1000), 50000, 51000, 49000, 50000, 1000],
]
self._callbacks = []
@property
def name(self) -> str:
return self._name
@property
def supports_websocket(self) -> bool:
return False
def connect(self) -> bool:
self._connected = True
return True
def disconnect(self):
self._connected = False
self._callbacks.clear()
def get_ticker(self, symbol: str):
return self._ticker_data.copy()
def get_ohlcv(self, symbol: str, timeframe: str = '1h', since=None, limit: int = 100):
return self._ohlcv_data.copy()
def subscribe_ticker(self, symbol: str, callback) -> bool:
if symbol not in self._subscribers:
self._subscribers[symbol] = []
self._subscribers[symbol].append(callback)
self._callbacks.append((symbol, callback))
# Simulate price update
import threading
def send_update():
import time
time.sleep(0.1)
if callback and symbol in self._subscribers:
callback(self._ticker_data.copy())
thread = threading.Thread(target=send_update, daemon=True)
thread.start()
return True
@pytest.mark.e2e
class TestPricingDataE2E:
"""End-to-end tests for pricing data system."""
@pytest.fixture(autouse=True)
def reset_service(self):
"""Reset pricing service between tests."""
import src.data.pricing_service
src.data.pricing_service._pricing_service = None
yield
src.data.pricing_service._pricing_service = None
@patch('src.data.pricing_service.get_config')
def test_pricing_service_initialization(self, mock_get_config):
"""Test pricing service initializes correctly."""
mock_config = Mock()
mock_config.get = Mock(side_effect=lambda key, default=None: {
"data_providers.primary": [
{"name": "mock", "enabled": True, "priority": 1}
],
"data_providers.fallback": {"enabled": True, "api_key": ""},
"data_providers.caching.ticker_ttl": 2,
"data_providers.caching.ohlcv_ttl": 60,
"data_providers.caching.max_cache_size": 1000,
}.get(key, default))
mock_get_config.return_value = mock_config
# Patch provider initialization to use mock
with patch('src.data.pricing_service.CCXTProvider') as mock_ccxt:
mock_provider = MockProvider("CCXT-Mock")
mock_ccxt.return_value = mock_provider
service = get_pricing_service()
assert service is not None
assert service.cache is not None
assert service.health_monitor is not None
def test_get_ticker_with_failover(self):
"""Test getting ticker with provider failover."""
# Create service with mock providers
service = PricingService()
# Create two providers - one will fail, one will succeed
provider1 = MockProvider("Provider1")
provider1.get_ticker = Mock(side_effect=Exception("Provider1 failed"))
provider1.connect = Mock(return_value=True)
provider2 = MockProvider("Provider2")
provider2.connect = Mock(return_value=True)
service._providers = {"Provider1": provider1, "Provider2": provider2}
service._provider_priority = ["Provider1", "Provider2"]
service._active_provider = "Provider1"
# Get ticker - should failover to Provider2
ticker = service.get_ticker("BTC/USD", use_cache=False)
assert ticker is not None
assert ticker['symbol'] == 'BTC/USD'
assert provider1.get_ticker.called
assert provider2.get_ticker.called
def test_caching_works(self):
"""Test that caching works correctly."""
service = PricingService()
provider = MockProvider()
provider.connect()
service._providers["MockProvider"] = provider
service._active_provider = "MockProvider"
# First call - should hit provider
ticker1 = service.get_ticker("BTC/USD", use_cache=True)
# Modify provider response
provider._ticker_data['last'] = Decimal('60000')
# Second call - should get cached value
ticker2 = service.get_ticker("BTC/USD", use_cache=True)
# Should be same as first call (cached)
assert ticker1['last'] == ticker2['last']
def test_subscription_and_updates(self):
"""Test subscribing to price updates."""
service = PricingService()
provider = MockProvider()
provider.connect()
service._providers["MockProvider"] = provider
service._active_provider = "MockProvider"
received_updates = []
def callback(data):
received_updates.append(data)
# Subscribe
success = service.subscribe_ticker("BTC/USD", callback)
assert success is True
# Wait for update
import time
time.sleep(0.2)
# Should have received at least one update
assert len(received_updates) > 0
assert received_updates[0]['symbol'] == 'BTC/USD'
def test_health_monitoring(self):
"""Test health monitoring tracks provider status."""
service = PricingService()
provider = MockProvider("TestProvider")
provider.connect()
service._providers["TestProvider"] = provider
service._active_provider = "TestProvider"
# Record some operations
service.health_monitor.record_success("TestProvider", 0.1)
service.health_monitor.record_success("TestProvider", 0.2)
service.health_monitor.record_failure("TestProvider")
# Check health
health = service.get_provider_health("TestProvider")
assert health is not None
assert health['success_count'] == 2
assert health['failure_count'] == 1
assert health['avg_response_time'] > 0
def test_provider_priority_selection(self):
"""Test that providers are selected by priority."""
service = PricingService()
provider1 = MockProvider("Provider1")
provider1.connect()
provider2 = MockProvider("Provider2")
provider2.connect()
service._providers = {"Provider1": provider1, "Provider2": provider2}
service._provider_priority = ["Provider1", "Provider2"]
# Select active provider
active = service._select_active_provider()
assert active == "Provider1" # Should select first in priority
def test_cache_stats(self):
"""Test cache statistics are tracked."""
service = PricingService()
# Perform some cache operations
service.cache.set("key1", "value1")
service.cache.get("key1") # Hit
service.cache.get("missing") # Miss
stats = service.get_cache_stats()
assert stats['hits'] >= 1
assert stats['misses'] >= 1
assert stats['size'] >= 1
def test_get_ohlcv_flow(self):
"""Test complete OHLCV data flow."""
service = PricingService()
provider = MockProvider()
provider.connect()
service._providers["MockProvider"] = provider
service._active_provider = "MockProvider"
# Get OHLCV data
ohlcv = service.get_ohlcv("BTC/USD", "1h", limit=10, use_cache=False)
assert len(ohlcv) > 0
assert len(ohlcv[0]) == 6 # timestamp, open, high, low, close, volume
assert ohlcv[0][0] > 0 # Valid timestamp
def test_unsubscribe_ticker(self):
"""Test unsubscribing from ticker updates."""
service = PricingService()
provider = MockProvider()
provider.connect()
service._providers["MockProvider"] = provider
service._active_provider = "MockProvider"
callback = Mock()
# Subscribe
service.subscribe_ticker("BTC/USD", callback)
assert "ticker:BTC/USD" in service._subscriptions
# Unsubscribe
service.unsubscribe_ticker("BTC/USD", callback)
assert "ticker:BTC/USD" not in service._subscriptions
def test_multiple_symbol_subscriptions(self):
"""Test subscribing to multiple symbols."""
service = PricingService()
provider = MockProvider()
provider.connect()
service._providers["MockProvider"] = provider
service._active_provider = "MockProvider"
callback1 = Mock()
callback2 = Mock()
# Subscribe to multiple symbols
service.subscribe_ticker("BTC/USD", callback1)
service.subscribe_ticker("ETH/USD", callback2)
assert "ticker:BTC/USD" in service._subscriptions
assert "ticker:ETH/USD" in service._subscriptions
assert len(service._subscriptions) == 2
@pytest.mark.e2e
@pytest.mark.asyncio
class TestPricingDataWebSocketE2E:
"""E2E tests for WebSocket pricing updates."""
async def test_websocket_price_broadcast(self):
"""Test WebSocket broadcasts price updates."""
# This test would require a running WebSocket server
# For now, we test the integration point
from backend.api.websocket import ConnectionManager
manager = ConnectionManager()
# Mock pricing service
with patch('backend.api.websocket.get_pricing_service') as mock_get_service:
mock_service = Mock()
mock_service.subscribe_ticker = Mock(return_value=True)
mock_get_service.return_value = mock_service
# Subscribe to symbol
manager.subscribe_to_symbol("BTC/USD")
assert "BTC/USD" in manager.subscribed_symbols
assert mock_service.subscribe_ticker.called
async def test_websocket_subscription_flow(self):
"""Test WebSocket subscription and unsubscription."""
from backend.api.websocket import ConnectionManager
manager = ConnectionManager()
with patch('backend.api.websocket.get_pricing_service') as mock_get_service:
mock_service = Mock()
mock_service.subscribe_ticker = Mock(return_value=True)
mock_service.unsubscribe_ticker = Mock()
mock_get_service.return_value = mock_service
# Subscribe
manager.subscribe_to_symbol("BTC/USD")
assert "BTC/USD" in manager.subscribed_symbols
# Unsubscribe
manager.unsubscribe_from_symbol("BTC/USD")
assert "BTC/USD" not in manager.subscribed_symbols
assert mock_service.unsubscribe_ticker.called

View File

@@ -0,0 +1,34 @@
"""End-to-end tests for strategy lifecycle."""
import pytest
from src.strategies.technical.rsi_strategy import RSIStrategy
from src.trading.engine import get_trading_engine
@pytest.mark.e2e
class TestStrategyLifecycle:
"""End-to-end tests for strategy lifecycle."""
@pytest.mark.asyncio
async def test_strategy_lifecycle(self):
"""Test complete strategy lifecycle."""
engine = get_trading_engine()
await engine.initialize()
# Create strategy
strategy = RSIStrategy(
strategy_id=1,
name="test_rsi",
symbol="BTC/USD",
timeframe="1h",
parameters={"rsi_period": 14}
)
# Start
await engine.start_strategy(strategy)
# Stop
await engine.stop_strategy(1)
await engine.shutdown()

2
tests/fixtures/__init__.py vendored Normal file
View File

@@ -0,0 +1,2 @@
"""Test fixtures and mocks."""

76
tests/fixtures/mock_exchange.py vendored Normal file
View File

@@ -0,0 +1,76 @@
"""Mock exchange adapter for testing."""
from unittest.mock import Mock, AsyncMock
from src.exchanges.base import BaseExchange
class MockExchange(BaseExchange):
"""Mock exchange adapter for testing."""
def __init__(self, name: str = "mock_exchange", **kwargs):
super().__init__(name, "mock_key", "mock_secret")
self.is_connected = True
self._mock_responses = {}
async def connect(self):
"""Mock connection."""
self.is_connected = True
async def disconnect(self):
"""Mock disconnection."""
self.is_connected = False
async def fetch_balance(self):
"""Mock balance fetch."""
return self._mock_responses.get('balance', {
'USD': {'free': 1000.0, 'used': 0.0, 'total': 1000.0}
})
async def place_order(self, symbol, side, order_type, amount, price=None, params=None):
"""Mock order placement."""
return self._mock_responses.get('order', {
'id': 'mock_order_123',
'status': 'filled',
'filled': amount,
'price': price or 50000.0
})
async def cancel_order(self, order_id, symbol=None):
"""Mock order cancellation."""
return {'id': order_id, 'status': 'canceled'}
async def fetch_order_status(self, order_id, symbol=None):
"""Mock order status fetch."""
return self._mock_responses.get('order_status', {
'id': order_id,
'status': 'filled'
})
async def fetch_ohlcv(self, symbol, timeframe, since=None, limit=None):
"""Mock OHLCV fetch."""
return self._mock_responses.get('ohlcv', [])
async def subscribe_ohlcv(self, symbol, timeframe, callback):
"""Mock OHLCV subscription."""
pass
async def subscribe_trades(self, symbol, callback):
"""Mock trades subscription."""
pass
async def subscribe_order_book(self, symbol, callback):
"""Mock order book subscription."""
pass
async def fetch_open_orders(self, symbol=None):
"""Mock open orders fetch."""
return self._mock_responses.get('open_orders', [])
async def fetch_positions(self, symbol=None):
"""Mock positions fetch."""
return self._mock_responses.get('positions', [])
async def fetch_markets(self):
"""Mock markets fetch."""
return self._mock_responses.get('markets', [])

81
tests/fixtures/sample_data.py vendored Normal file
View File

@@ -0,0 +1,81 @@
"""Sample data generators for testing."""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from faker import Faker
fake = Faker()
def generate_ohlcv_data(
symbol: str = "BTC/USD",
periods: int = 100,
start_price: float = 50000.0,
timeframe: str = "1h"
) -> pd.DataFrame:
"""Generate sample OHLCV data.
Args:
symbol: Trading pair
periods: Number of periods
start_price: Starting price
timeframe: Data timeframe
Returns:
DataFrame with OHLCV data
"""
dates = pd.date_range(
start=datetime.now() - timedelta(hours=periods),
periods=periods,
freq=timeframe
)
# Generate realistic price movement
prices = [start_price]
for i in range(1, periods):
change = np.random.randn() * 0.02 # 2% volatility
prices.append(prices[-1] * (1 + change))
return pd.DataFrame({
'timestamp': dates,
'open': prices,
'high': [p * 1.01 for p in prices],
'low': [p * 0.99 for p in prices],
'close': prices,
'volume': [1000.0 + np.random.randn() * 100 for _ in range(periods)]
})
def generate_trade_data(
symbol: str = "BTC/USD",
count: int = 10
) -> list:
"""Generate sample trade data.
Args:
symbol: Trading pair
count: Number of trades
Returns:
List of trade dictionaries
"""
trades = []
base_price = 50000.0
for i in range(count):
trades.append({
'order_id': f'trade_{i}',
'symbol': symbol,
'side': 'buy' if i % 2 == 0 else 'sell',
'type': 'market',
'price': base_price + np.random.randn() * 100,
'amount': 0.01,
'cost': 500.0,
'fee': 0.5,
'status': 'filled',
'timestamp': datetime.now() - timedelta(hours=count-i)
})
return trades

View File

@@ -0,0 +1,2 @@
"""Integration tests."""

View File

@@ -0,0 +1,2 @@
"""Backend integration tests."""

View File

@@ -0,0 +1,95 @@
"""Integration tests for API workflows."""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, Mock
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.mark.integration
class TestTradingWorkflow:
"""Test complete trading workflow through API."""
@patch('backend.api.trading.get_trading_engine')
@patch('backend.api.trading.get_db')
def test_complete_trading_workflow(self, mock_get_db, mock_get_engine, client):
"""Test: Place order → Check order status → Get positions."""
# Setup mocks
mock_engine = Mock()
mock_order = Mock()
mock_order.id = 1
mock_order.symbol = "BTC/USD"
mock_order.side = "buy"
mock_order.status = "filled"
mock_engine.execute_order.return_value = mock_order
mock_get_engine.return_value = mock_engine
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
# Mock order query
mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_order]
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
# Place order
order_data = {
"exchange_id": 1,
"symbol": "BTC/USD",
"side": "buy",
"order_type": "market",
"quantity": "0.1",
"paper_trading": True
}
create_response = client.post("/api/trading/orders", json=order_data)
assert create_response.status_code == 200
order_id = create_response.json()["id"]
# Get order status
status_response = client.get(f"/api/trading/orders/{order_id}")
assert status_response.status_code == 200
assert status_response.json()["id"] == order_id
# Get orders list
orders_response = client.get("/api/trading/orders")
assert orders_response.status_code == 200
assert isinstance(orders_response.json(), list)
@pytest.mark.integration
class TestPortfolioWorkflow:
"""Test portfolio workflow through API."""
@patch('backend.api.portfolio.get_portfolio_tracker')
def test_portfolio_workflow(self, mock_get_tracker, client):
"""Test: Get current portfolio → Get portfolio history."""
mock_tracker = Mock()
mock_tracker.get_current_portfolio.return_value = {
"positions": [],
"performance": {"total_return": 0.1},
"timestamp": "2025-01-01T00:00:00"
}
mock_tracker.get_portfolio_history.return_value = [
{"timestamp": "2025-01-01T00:00:00", "total_value": 1000.0, "total_pnl": 0.0}
]
mock_get_tracker.return_value = mock_tracker
# Get current portfolio
current_response = client.get("/api/portfolio/current?paper_trading=true")
assert current_response.status_code == 200
assert "positions" in current_response.json()
# Get portfolio history
history_response = client.get("/api/portfolio/history?paper_trading=true&days=30")
assert history_response.status_code == 200
assert "dates" in history_response.json()

View File

@@ -0,0 +1,323 @@
"""Integration tests for frontend API workflows.
These tests verify that all frontend-accessible API endpoints work correctly
and return data in the expected format for the React frontend.
"""
import pytest
from fastapi.testclient import TestClient
from decimal import Decimal
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from backend.main import app
@pytest.fixture
def client():
"""Create test client."""
return TestClient(app)
@pytest.fixture
def mock_strategy():
"""Mock strategy data."""
return {
"id": 1,
"name": "Test RSI Strategy",
"strategy_type": "rsi",
"class_name": "rsi",
"parameters": {
"symbol": "BTC/USD",
"exchange_id": 1,
"rsi_period": 14,
"oversold": 30,
"overbought": 70
},
"timeframes": ["1h"],
"enabled": False,
"paper_trading": True,
"version": "1.0.0",
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
@pytest.fixture
def mock_exchange():
"""Mock exchange data."""
return {
"id": 1,
"name": "coinbase",
"sandbox": False,
"read_only": True,
"enabled": True,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
class TestStrategyManagementAPI:
"""Test strategy management API endpoints."""
def test_list_strategies(self, client):
"""Test listing all strategies."""
with patch('backend.api.strategies.get_db') as mock_db:
mock_session = Mock()
mock_strategy = Mock()
mock_strategy.id = 1
mock_strategy.name = "Test Strategy"
mock_strategy.strategy_type = "rsi"
mock_strategy.class_name = "rsi"
mock_strategy.parameters = {}
mock_strategy.timeframes = ["1h"]
mock_strategy.enabled = False
mock_strategy.paper_trading = True
mock_strategy.version = "1.0.0"
mock_strategy.description = None
mock_strategy.schedule = None
mock_strategy.created_at = datetime.now()
mock_strategy.updated_at = datetime.now()
mock_session.query.return_value.order_by.return_value.all.return_value = [mock_strategy]
mock_db.return_value.get_session.return_value.__enter__.return_value = mock_session
mock_db.return_value.get_session.return_value.__exit__ = Mock(return_value=None)
response = client.get("/api/strategies/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
def test_get_available_strategies(self, client):
"""Test getting available strategy types."""
with patch('backend.api.strategies.get_strategy_registry') as mock_registry:
mock_registry.return_value.list_available.return_value = [
"rsi", "macd", "moving_average", "dca", "grid", "momentum"
]
response = client.get("/api/strategies/available")
assert response.status_code == 200
data = response.json()
assert "strategies" in data
assert isinstance(data["strategies"], list)
def test_create_strategy(self, client, mock_strategy):
"""Test creating a new strategy."""
with patch('backend.api.strategies.get_db') as mock_db:
mock_session = Mock()
mock_db.return_value.get_session.return_value.__enter__.return_value = mock_session
mock_db.return_value.get_session.return_value.__exit__ = Mock(return_value=None)
mock_session.add = Mock()
mock_session.commit = Mock()
mock_session.refresh = Mock()
# Create strategy instance
created_strategy = Mock()
created_strategy.id = 1
created_strategy.name = mock_strategy["name"]
created_strategy.strategy_type = mock_strategy["strategy_type"]
created_strategy.class_name = mock_strategy["class_name"]
created_strategy.parameters = mock_strategy["parameters"]
created_strategy.timeframes = mock_strategy["timeframes"]
created_strategy.enabled = False
created_strategy.paper_trading = mock_strategy["paper_trading"]
created_strategy.version = "1.0.0"
created_strategy.description = None
created_strategy.schedule = None
created_strategy.created_at = datetime.now()
created_strategy.updated_at = datetime.now()
mock_session.refresh.side_effect = lambda x: setattr(x, 'id', 1)
response = client.post(
"/api/strategies/",
json={
"name": mock_strategy["name"],
"strategy_type": mock_strategy["strategy_type"],
"class_name": mock_strategy["class_name"],
"parameters": mock_strategy["parameters"],
"timeframes": mock_strategy["timeframes"],
"paper_trading": mock_strategy["paper_trading"]
}
)
# May return 200 or 500 depending on implementation
assert response.status_code in [200, 201, 500]
class TestTradingAPI:
"""Test trading API endpoints."""
def test_get_positions(self, client):
"""Test getting positions."""
with patch('src.trading.paper_trading.get_paper_trading') as mock_pt:
mock_pt.return_value.get_positions.return_value = []
response = client.get("/api/trading/positions?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
def test_get_orders(self, client):
"""Test getting orders."""
with patch('backend.api.trading.get_db') as mock_db:
mock_session = Mock()
mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = []
mock_db.return_value.get_session.return_value.__enter__.return_value = mock_session
mock_db.return_value.get_session.return_value.__exit__ = Mock(return_value=None)
response = client.get("/api/trading/orders?paper_trading=true&limit=10")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
def test_get_balance(self, client):
"""Test getting balance."""
with patch('src.trading.paper_trading.get_paper_trading') as mock_pt:
mock_pt.return_value.get_balance.return_value = Decimal("100.00")
mock_pt.return_value.get_performance.return_value = {}
response = client.get("/api/trading/balance?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert "balance" in data
class TestPortfolioAPI:
"""Test portfolio API endpoints."""
def test_get_current_portfolio(self, client):
"""Test getting current portfolio."""
with patch('backend.api.portfolio.get_portfolio_tracker') as mock_tracker:
mock_tracker.return_value.get_current_portfolio.return_value = {
"positions": [],
"performance": {
"current_value": 100.0,
"unrealized_pnl": 0.0,
"realized_pnl": 0.0
},
"timestamp": datetime.now().isoformat()
}
response = client.get("/api/portfolio/current?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert "positions" in data
assert "performance" in data
def test_get_portfolio_history(self, client):
"""Test getting portfolio history."""
with patch('backend.api.portfolio.get_portfolio_tracker') as mock_tracker:
mock_tracker.return_value.get_portfolio_history.return_value = {
"dates": [datetime.now().isoformat()],
"values": [100.0],
"pnl": [0.0]
}
response = client.get("/api/portfolio/history?days=30&paper_trading=true")
assert response.status_code == 200
data = response.json()
assert "dates" in data
assert "values" in data
class TestBacktestingAPI:
"""Test backtesting API endpoints."""
def test_run_backtest(self, client):
"""Test running a backtest."""
with patch('backend.api.backtesting.get_backtesting_engine') as mock_engine, \
patch('backend.api.backtesting.get_db') as mock_db:
mock_session = Mock()
mock_strategy = Mock()
mock_strategy.id = 1
mock_strategy.class_name = "rsi"
mock_strategy.parameters = {}
mock_strategy.timeframes = ["1h"]
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_strategy
mock_db.return_value.get_session.return_value.__enter__.return_value = mock_session
mock_db.return_value.get_session.return_value.__exit__ = Mock(return_value=None)
mock_engine.return_value.run_backtest.return_value = {
"total_return": 0.1,
"sharpe_ratio": 1.5,
"max_drawdown": -0.05,
"win_rate": 0.6,
"total_trades": 10,
"final_value": 110.0
}
response = client.post(
"/api/backtesting/run",
json={
"strategy_id": 1,
"symbol": "BTC/USD",
"exchange": "coinbase",
"timeframe": "1h",
"start_date": (datetime.now() - timedelta(days=30)).isoformat(),
"end_date": datetime.now().isoformat(),
"initial_capital": 100.0,
"slippage": 0.001,
"fee_rate": 0.001
}
)
# May return 200 or error depending on implementation
assert response.status_code in [200, 400, 500]
class TestAlertsAPI:
"""Test alerts API endpoints."""
def test_list_alerts(self, client):
"""Test listing alerts."""
with patch('backend.api.alerts.get_alert_manager') as mock_manager:
mock_manager.return_value.list_alerts.return_value = []
response = client.get("/api/alerts/?enabled_only=false")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
def test_create_alert(self, client):
"""Test creating an alert."""
with patch('backend.api.alerts.get_alert_manager') as mock_manager:
mock_alert = Mock()
mock_alert.id = 1
mock_alert.name = "Test Alert"
mock_alert.alert_type = "price"
mock_alert.condition = {"symbol": "BTC/USD", "price_threshold": 50000}
mock_alert.enabled = True
mock_alert.triggered = False
mock_alert.triggered_at = None
mock_alert.created_at = datetime.now()
mock_alert.updated_at = datetime.now()
mock_manager.return_value.create_alert.return_value = mock_alert
response = client.post(
"/api/alerts/",
json={
"name": "Test Alert",
"alert_type": "price",
"condition": {"symbol": "BTC/USD", "price_threshold": 50000}
}
)
# May return 200 or 500 depending on implementation
assert response.status_code in [200, 201, 500]
class TestExchangesAPI:
"""Test exchanges API endpoints."""
def test_list_exchanges(self, client):
"""Test listing exchanges."""
with patch('backend.api.exchanges.get_db') as mock_db:
mock_session = Mock()
mock_session.query.return_value.all.return_value = []
mock_db.return_value.get_session.return_value.__enter__.return_value = mock_session
mock_db.return_value.get_session.return_value.__exit__ = Mock(return_value=None)
response = client.get("/api/exchanges/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)

View File

@@ -0,0 +1,32 @@
"""Integration tests for WebSocket connections."""
import pytest
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.mark.integration
class TestWebSocketConnection:
"""Test WebSocket connections."""
def test_websocket_connection(self, client):
"""Test WebSocket connection."""
with client.websocket_connect("/ws/") as websocket:
# Connection should be established
assert websocket is not None
def test_websocket_message_handling(self, client):
"""Test WebSocket message handling."""
with client.websocket_connect("/ws/") as websocket:
# Send a test message
websocket.send_json({"type": "ping"})
# WebSocket should accept the connection
# Note: Actual message handling depends on implementation

View File

@@ -0,0 +1,138 @@
"""Integration tests for autopilot workflows."""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_intelligent_autopilot():
"""Mock intelligent autopilot."""
autopilot = Mock()
autopilot.symbol = "BTC/USD"
autopilot.is_running = False
autopilot.enable_auto_execution = False
autopilot.get_status.return_value = {
"symbol": "BTC/USD",
"timeframe": "1h",
"running": False,
"selected_strategy": None,
"trades_today": 0,
"max_trades_per_day": 10,
"min_confidence_threshold": 0.75,
"enable_auto_execution": False,
"last_analysis": None,
"model_info": {},
}
return autopilot
@pytest.mark.integration
class TestAutopilotWorkflow:
"""Integration tests for autopilot workflows."""
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_start_intelligent_mode_workflow(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test complete workflow for starting intelligent mode autopilot."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
# Start autopilot
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "intelligent",
"auto_execute": True,
"exchange_id": 1,
"timeframe": "1h",
},
)
assert response.status_code == 200
assert mock_intelligent_autopilot.enable_auto_execution is True
# Check status
response = client.get(
"/api/autopilot/status-unified/BTC/USD?mode=intelligent&timeframe=1h"
)
assert response.status_code == 200
data = response.json()
assert data["mode"] == "intelligent"
# Stop autopilot
response = client.post(
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=intelligent&timeframe=1h"
)
assert response.status_code == 200
assert mock_intelligent_autopilot.stop.called
@pytest.mark.integration
class TestTrainingConfigWorkflow:
"""Integration tests for training configuration workflow."""
def test_configure_and_verify_bootstrap(self, client):
"""Test configuring bootstrap settings and verifying they persist."""
# Set custom config
custom_config = {
"days": 180,
"timeframe": "4h",
"min_samples_per_strategy": 25,
"symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "DOGE/USD"]
}
response = client.put("/api/autopilot/bootstrap-config", json=custom_config)
assert response.status_code == 200
# Verify it was saved
response = client.get("/api/autopilot/bootstrap-config")
assert response.status_code == 200
data = response.json()
assert data["days"] == 180
assert data["timeframe"] == "4h"
assert data["min_samples_per_strategy"] == 25
assert len(data["symbols"]) == 4
assert "DOGE/USD" in data["symbols"]
@patch('backend.api.autopilot.train_model_task')
def test_training_uses_configured_symbols(self, mock_task, client):
"""Test that training uses the configured symbols."""
# Setup mock
mock_result = Mock()
mock_result.id = "test-task-123"
mock_task.delay.return_value = mock_result
# Configure with specific symbols
config = {
"days": 90,
"timeframe": "1h",
"min_samples_per_strategy": 10,
"symbols": ["BTC/USD", "ETH/USD", "XRP/USD"]
}
client.put("/api/autopilot/bootstrap-config", json=config)
# Trigger training
response = client.post("/api/autopilot/intelligent/retrain")
assert response.status_code == 200
# Verify the symbols were passed
call_kwargs = mock_task.delay.call_args.kwargs
assert call_kwargs["symbols"] == ["BTC/USD", "ETH/USD", "XRP/USD"]
assert call_kwargs["days"] == 90
assert call_kwargs["timeframe"] == "1h"
assert call_kwargs["min_samples_per_strategy"] == 10

View File

@@ -0,0 +1,18 @@
"""Integration tests for backtesting workflow."""
import pytest
from datetime import datetime, timedelta
from src.backtesting.engine import get_backtest_engine
@pytest.mark.integration
class TestBacktestingWorkflow:
"""Integration tests for backtesting workflow."""
@pytest.mark.asyncio
async def test_backtesting_workflow(self):
"""Test complete backtesting workflow."""
engine = get_backtest_engine()
assert engine is not None
# Full workflow test would require strategy and data setup

View File

@@ -0,0 +1,21 @@
"""Integration tests for data pipeline."""
import pytest
from src.data.collector import get_data_collector
from src.data.storage import get_data_storage
@pytest.mark.integration
class TestDataPipeline:
"""Integration tests for data collection and storage."""
@pytest.mark.asyncio
async def test_data_collection_storage(self):
"""Test data collection and storage pipeline."""
collector = get_data_collector()
storage = get_data_storage()
# Test components exist
assert collector is not None
assert storage is not None

View File

@@ -0,0 +1,31 @@
"""Integration tests for portfolio tracking."""
import pytest
from decimal import Decimal
from src.portfolio.tracker import get_portfolio_tracker
from src.trading.paper_trading import get_paper_trading
@pytest.mark.integration
class TestPortfolioTracking:
"""Integration tests for portfolio tracking."""
@pytest.mark.asyncio
async def test_portfolio_tracking_workflow(self):
"""Test portfolio tracking workflow."""
tracker = get_portfolio_tracker()
paper_trading = get_paper_trading()
# Get initial portfolio
portfolio1 = await tracker.get_current_portfolio(paper_trading=True)
# Portfolio should have structure
assert "positions" in portfolio1
assert "performance" in portfolio1
# Get updated portfolio
portfolio2 = await tracker.get_current_portfolio(paper_trading=True)
# Should return valid portfolio
assert portfolio2 is not None

View File

@@ -0,0 +1,71 @@
"""Integration tests for pricing provider system."""
import pytest
from datetime import datetime
from src.data.pricing_service import get_pricing_service
@pytest.mark.integration
class TestPricingProviderIntegration:
"""Integration tests for pricing providers."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up test fixtures."""
# Reset global service instance
import src.data.pricing_service
src.data.pricing_service._pricing_service = None
def test_service_initialization(self):
"""Test that pricing service initializes correctly."""
service = get_pricing_service()
assert service is not None
assert service.cache is not None
assert service.health_monitor is not None
@pytest.mark.skip(reason="Requires network access - run manually")
def test_get_ticker_integration(self):
"""Test getting ticker data from real providers."""
service = get_pricing_service()
# This will try to connect to real providers
ticker = service.get_ticker("BTC/USD", use_cache=False)
# Should get some data if providers are available
if ticker:
assert 'symbol' in ticker
assert 'last' in ticker
assert ticker['last'] > 0
@pytest.mark.skip(reason="Requires network access - run manually")
def test_provider_failover(self):
"""Test provider failover mechanism."""
service = get_pricing_service()
# Get active provider
active = service.get_active_provider()
assert active is not None or len(service._providers) == 0
def test_cache_integration(self):
"""Test cache integration with service."""
service = get_pricing_service()
# Set a value
service.cache.set("test:key", "test_value", ttl=60)
# Get it back
value = service.cache.get("test:key")
assert value == "test_value"
def test_health_monitoring(self):
"""Test health monitoring integration."""
service = get_pricing_service()
# Record some metrics
service.health_monitor.record_success("test_provider", 0.5)
service.health_monitor.record_failure("test_provider")
# Check health
is_healthy = service.health_monitor.is_healthy("test_provider")
assert isinstance(is_healthy, bool)

View File

@@ -0,0 +1,36 @@
"""Integration tests for strategy execution."""
import pytest
from src.strategies.technical.rsi_strategy import RSIStrategy
from src.trading.engine import get_trading_engine
@pytest.mark.integration
class TestStrategyExecution:
"""Integration tests for strategy execution."""
@pytest.mark.asyncio
async def test_strategy_execution_workflow(self):
"""Test complete strategy execution workflow."""
engine = get_trading_engine()
await engine.initialize()
# Create strategy
strategy = RSIStrategy(
strategy_id=1,
name="test_rsi",
symbol="BTC/USD",
timeframe="1h",
parameters={"rsi_period": 14}
)
# Start strategy
await engine.start_strategy(strategy)
assert strategy.is_active
# Stop strategy
await engine.stop_strategy(1)
assert not strategy.is_active
await engine.shutdown()

View File

@@ -0,0 +1,47 @@
"""Integration tests for trading workflow."""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.trading.engine import get_trading_engine
from src.strategies.technical.rsi_strategy import RSIStrategy
@pytest.mark.integration
class TestTradingWorkflow:
"""Integration tests for complete trading workflow."""
@pytest.mark.asyncio
async def test_complete_trading_workflow(self, mock_database):
"""Test complete trading workflow."""
# Initialize trading engine
engine = get_trading_engine()
await engine.initialize()
# Create strategy
strategy = RSIStrategy(
strategy_id=1,
name="test_rsi",
symbol="BTC/USD",
timeframe="1h",
parameters={"rsi_period": 14}
)
# Start strategy
await engine.start_strategy(strategy)
# Execute trade (paper trading)
result = await engine.execute_trade(
exchange_name="paper_trading",
strategy_id=1,
symbol="BTC/USD",
side="buy",
order_type="market",
amount=0.01,
is_paper_trade=True
)
assert result is not None
# Cleanup
await engine.shutdown()

View File

@@ -0,0 +1,65 @@
"""Integration tests for UI backtest workflow."""
import pytest
from datetime import datetime, timedelta
from PyQt6.QtWidgets import QApplication
from unittest.mock import Mock, patch
from src.ui.widgets.backtest_view import BacktestViewWidget
from src.strategies.technical.rsi_strategy import RSIStrategy
@pytest.fixture
def app():
"""Create QApplication for tests."""
if not QApplication.instance():
return QApplication([])
return QApplication.instance()
@pytest.fixture
def backtest_view(app):
"""Create BacktestViewWidget."""
view = BacktestViewWidget()
# Mock backtest engine
mock_engine = Mock()
mock_engine.run_backtest.return_value = {
'total_return': 10.5,
'sharpe_ratio': 1.2,
'max_drawdown': -5.0,
'win_rate': 0.55,
'total_trades': 50,
'final_value': 110.5,
'trades': []
}
view.backtest_engine = mock_engine
return view
def test_backtest_configuration(backtest_view):
"""Test backtest configuration form."""
assert backtest_view.strategy_combo is not None
assert backtest_view.symbol_input is not None
assert backtest_view.start_date is not None
assert backtest_view.end_date is not None
def test_backtest_results_display(backtest_view):
"""Test backtest results are displayed."""
results = {
'total_return': 10.5,
'sharpe_ratio': 1.2,
'max_drawdown': -5.0,
'win_rate': 0.55,
'total_trades': 50,
'final_value': 110.5,
'trades': []
}
backtest_view._display_results(results)
# Verify metrics text contains results
metrics_text = backtest_view.metrics_text.toPlainText()
assert "10.5" in metrics_text # Total return
assert "1.2" in metrics_text # Sharpe ratio

View File

@@ -0,0 +1,62 @@
"""Integration tests for UI strategy workflow."""
import pytest
from PyQt6.QtWidgets import QApplication
from unittest.mock import Mock, patch
from src.ui.widgets.strategy_manager import StrategyManagerWidget, StrategyDialog
from src.core.database import Strategy
@pytest.fixture
def app():
"""Create QApplication for tests."""
if not QApplication.instance():
return QApplication([])
return QApplication.instance()
@pytest.fixture
def strategy_manager(app):
"""Create StrategyManagerWidget."""
return StrategyManagerWidget()
def test_strategy_creation_workflow(strategy_manager, app):
"""Test creating strategy via UI."""
# Mock database
with patch.object(strategy_manager.db, 'get_session') as mock_session:
mock_session.return_value.__enter__.return_value.add = Mock()
mock_session.return_value.__enter__.return_value.commit = Mock()
# Simulate add strategy
dialog = StrategyDialog(strategy_manager)
dialog.name_input.setText("Test Strategy")
dialog.symbol_input.setText("BTC/USD")
dialog.type_combo.setCurrentText("rsi")
# Would need to set exchange combo data
# For now, just verify dialog structure
assert dialog.name_input.text() == "Test Strategy"
dialog.close()
def test_strategy_table_population(strategy_manager):
"""Test strategies table is populated from database."""
# Mock database query
mock_strategy = Mock(spec=Strategy)
mock_strategy.id = 1
mock_strategy.name = "Test Strategy"
mock_strategy.strategy_type = "rsi"
mock_strategy.parameters = {"symbol": "BTC/USD"}
mock_strategy.enabled = True
with patch.object(strategy_manager.db, 'get_session') as mock_session:
mock_query = Mock()
mock_query.all.return_value = [mock_strategy]
mock_session.return_value.__enter__.return_value.query.return_value.filter_by.return_value.all = Mock(return_value=[mock_strategy])
mock_session.return_value.__enter__.return_value.query.return_value.all = Mock(return_value=[mock_strategy])
strategy_manager._refresh_strategies()
# Verify table has data
# Note: Actual implementation would verify table contents

View File

@@ -0,0 +1,77 @@
"""Integration tests for UI trading workflow."""
import pytest
from decimal import Decimal
from PyQt6.QtWidgets import QApplication
from unittest.mock import Mock, patch
from src.ui.widgets.trading_view import TradingView
from src.core.database import Order, OrderStatus, OrderSide, OrderType
@pytest.fixture
def app():
"""Create QApplication for tests."""
if not QApplication.instance():
return QApplication([])
return QApplication.instance()
@pytest.fixture
def trading_view(app):
"""Create TradingView with mocked backend."""
view = TradingView()
# Mock trading engine
mock_engine = Mock()
mock_order = Mock(spec=Order)
mock_order.id = 1
mock_engine.execute_order.return_value = mock_order
view.trading_engine = mock_engine
return view
def test_order_placement_workflow(trading_view):
"""Test complete order placement workflow."""
# Set up form
trading_view.current_exchange_id = 1
trading_view.current_symbol = "BTC/USD"
trading_view.order_type_combo.setCurrentText("Market")
trading_view.side_combo.setCurrentText("Buy")
trading_view.quantity_input.setValue(0.1)
# Place order
trading_view._place_order()
# Verify engine was called
trading_view.trading_engine.execute_order.assert_called_once()
call_args = trading_view.trading_engine.execute_order.call_args
assert call_args[1]['symbol'] == "BTC/USD"
assert call_args[1]['side'] == OrderSide.BUY
def test_position_table_update(trading_view):
"""Test positions table updates with portfolio data."""
# Mock portfolio data
mock_portfolio = {
'positions': [
{
'symbol': 'BTC/USD',
'quantity': 0.1,
'entry_price': 50000,
'current_price': 51000,
'unrealized_pnl': 100
}
],
'performance': {
'current_value': 5100,
'unrealized_pnl': 100,
'realized_pnl': 0
}
}
trading_view.portfolio_tracker.get_current_portfolio = Mock(return_value=mock_portfolio)
trading_view._update_positions()
assert trading_view.positions_table.rowCount() == 1
assert trading_view.positions_table.item(0, 0).text() == "BTC/USD"

View File

@@ -0,0 +1,2 @@
"""Performance tests."""

View File

@@ -0,0 +1,26 @@
"""Performance benchmarks for backtesting."""
import pytest
import time
from src.backtesting.engine import get_backtest_engine
@pytest.mark.slow
class TestBacktestPerformance:
"""Performance tests for backtesting."""
@pytest.mark.asyncio
async def test_backtest_speed(self):
"""Test backtesting speed."""
engine = get_backtest_engine()
# Create minimal backtest scenario
start_time = time.time()
# Run minimal backtest (would need actual implementation)
# This is a placeholder
elapsed = time.time() - start_time
# Backtest should complete in reasonable time
assert elapsed < 60 # Less than 60 seconds for basic test

8
tests/requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
pytest>=7.4.0
pytest-cov>=4.1.0
pytest-asyncio>=0.21.0
pytest-mock>=3.11.0
coverage>=7.3.0
faker>=19.0.0
freezegun>=1.2.0

2
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Unit tests."""

View File

@@ -0,0 +1,2 @@
"""Unit tests for alert system."""

View File

@@ -0,0 +1,33 @@
"""Tests for alert engine."""
import pytest
from src.alerts.engine import get_alert_engine, AlertEngine
class TestAlertEngine:
"""Tests for AlertEngine."""
@pytest.fixture
def alert_engine(self):
"""Create alert engine instance."""
return get_alert_engine()
def test_alert_engine_initialization(self, alert_engine):
"""Test alert engine initialization."""
assert alert_engine is not None
@pytest.mark.asyncio
async def test_process_alert(self, alert_engine):
"""Test processing alert."""
# Create test alert
alert = {
'name': 'test_alert',
'type': 'price',
'condition': 'BTC/USD > 50000',
'is_active': True
}
# Process alert (simplified)
result = await alert_engine.process_alert(alert, {'BTC/USD': 51000.0})
assert isinstance(result, bool)

View File

@@ -0,0 +1,2 @@
"""Tests for autopilot module."""

View File

@@ -0,0 +1,175 @@
"""Tests for intelligent autopilot functionality."""
import pytest
from decimal import Decimal
from unittest.mock import Mock, patch, AsyncMock
from src.core.database import OrderSide, OrderType
class TestPreFlightValidation:
"""Tests for pre-flight order validation."""
@pytest.fixture
def mock_autopilot(self):
"""Create mock autopilot with necessary attributes."""
from src.autopilot.intelligent_autopilot import IntelligentAutopilot
with patch.object(IntelligentAutopilot, '__init__', lambda x, *args, **kwargs: None):
autopilot = IntelligentAutopilot.__new__(IntelligentAutopilot)
autopilot.symbol = 'BTC/USD'
autopilot.paper_trading = True
autopilot.logger = Mock()
# Mock trading engine
autopilot.trading_engine = Mock()
autopilot.trading_engine.paper_trading = Mock()
autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('1000.0')
autopilot.trading_engine.paper_trading.get_positions.return_value = []
return autopilot
@pytest.mark.asyncio
async def test_can_execute_order_insufficient_funds(self, mock_autopilot):
"""Test that insufficient funds returns False."""
mock_autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('10.0')
can_execute, reason = await mock_autopilot._can_execute_order(
side=OrderSide.BUY,
quantity=Decimal('1.0'),
price=Decimal('100.0')
)
assert can_execute is False
assert 'Insufficient funds' in reason
@pytest.mark.asyncio
async def test_can_execute_order_sufficient_funds(self, mock_autopilot):
"""Test that sufficient funds returns True."""
mock_autopilot.trading_engine.paper_trading.get_balance.return_value = Decimal('1000.0')
can_execute, reason = await mock_autopilot._can_execute_order(
side=OrderSide.BUY,
quantity=Decimal('1.0'),
price=Decimal('100.0')
)
assert can_execute is True
assert reason == 'OK'
@pytest.mark.asyncio
async def test_can_execute_order_no_position_for_sell(self, mock_autopilot):
"""Test that SELL without position returns False."""
mock_autopilot.trading_engine.paper_trading.get_positions.return_value = []
can_execute, reason = await mock_autopilot._can_execute_order(
side=OrderSide.SELL,
quantity=Decimal('1.0'),
price=Decimal('100.0')
)
assert can_execute is False
assert 'No position to sell' in reason
@pytest.mark.asyncio
async def test_can_execute_order_minimum_value(self, mock_autopilot):
"""Test that order below minimum value returns False."""
can_execute, reason = await mock_autopilot._can_execute_order(
side=OrderSide.BUY,
quantity=Decimal('0.001'),
price=Decimal('0.10') # Order value = $0.0001
)
assert can_execute is False
assert 'below minimum' in reason
class TestSmartOrderTypeSelection:
"""Tests for smart order type selection."""
@pytest.fixture
def mock_autopilot(self):
"""Create mock autopilot for order type tests."""
from src.autopilot.intelligent_autopilot import IntelligentAutopilot
with patch.object(IntelligentAutopilot, '__init__', lambda x, *args, **kwargs: None):
autopilot = IntelligentAutopilot.__new__(IntelligentAutopilot)
autopilot.logger = Mock()
return autopilot
def test_strong_signal_uses_market(self, mock_autopilot):
"""Test that strong signals (>80%) use MARKET orders."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.BUY,
signal_strength=0.85,
current_price=Decimal('100.0'),
is_stop_loss=False
)
assert order_type == OrderType.MARKET
assert price is None
def test_normal_signal_uses_limit(self, mock_autopilot):
"""Test that normal signals use LIMIT orders."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.BUY,
signal_strength=0.65,
current_price=Decimal('100.0'),
is_stop_loss=False
)
assert order_type == OrderType.LIMIT
assert price is not None
# BUY limit should be below market
assert price < Decimal('100.0')
def test_stop_loss_uses_market(self, mock_autopilot):
"""Test that stop-loss exits use MARKET orders."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.SELL,
signal_strength=0.5,
current_price=Decimal('100.0'),
is_stop_loss=True
)
assert order_type == OrderType.MARKET
assert price is None
def test_take_profit_uses_limit(self, mock_autopilot):
"""Test that take-profit exits can use LIMIT orders."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.SELL,
signal_strength=0.6,
current_price=Decimal('100.0'),
is_stop_loss=False
)
assert order_type == OrderType.LIMIT
assert price is not None
# SELL limit should be above market
assert price > Decimal('100.0')
def test_buy_limit_price_discount(self, mock_autopilot):
"""Test BUY LIMIT price is 0.1% below market."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.BUY,
signal_strength=0.6,
current_price=Decimal('1000.00'),
is_stop_loss=False
)
# 0.1% discount = 999.00
expected = Decimal('999.00')
assert price == expected
def test_sell_limit_price_premium(self, mock_autopilot):
"""Test SELL LIMIT price is 0.1% above market."""
order_type, price = mock_autopilot._determine_order_type_and_price(
side=OrderSide.SELL,
signal_strength=0.6,
current_price=Decimal('1000.00'),
is_stop_loss=False
)
# 0.1% premium = 1001.00
expected = Decimal('1001.00')
assert price == expected

View File

@@ -0,0 +1,161 @@
"""Tests for strategy grouping module."""
import pytest
from src.autopilot.strategy_groups import (
StrategyGroup,
STRATEGY_TO_GROUP,
GROUP_TO_STRATEGIES,
get_strategy_group,
get_strategies_in_group,
get_all_groups,
get_best_strategy_in_group,
convert_strategy_to_group_label,
)
class TestStrategyGroupMappings:
"""Tests for strategy group mappings."""
def test_all_strategies_have_group(self):
"""Verify all registered strategies are mapped to a group."""
# These are the strategies registered in src/strategies/__init__.py
expected_strategies = [
"rsi", "macd", "moving_average", "confirmed", "divergence",
"bollinger_mean_reversion", "dca", "grid", "momentum",
"consensus", "pairs_trading", "volatility_breakout",
"sentiment", "market_making"
]
for strategy in expected_strategies:
group = get_strategy_group(strategy)
assert group is not None, f"Strategy '{strategy}' is not mapped to any group"
def test_get_strategy_group_case_insensitive(self):
"""Test that strategy lookup is case-insensitive."""
assert get_strategy_group("RSI") == get_strategy_group("rsi")
assert get_strategy_group("MACD") == get_strategy_group("macd")
assert get_strategy_group("Moving_Average") == get_strategy_group("moving_average")
def test_get_strategy_group_unknown(self):
"""Test that unknown strategies return None."""
assert get_strategy_group("nonexistent_strategy") is None
assert get_strategy_group("") is None
def test_group_to_strategies_reverse_mapping(self):
"""Verify GROUP_TO_STRATEGIES is the reverse of STRATEGY_TO_GROUP."""
for strategy, group in STRATEGY_TO_GROUP.items():
assert strategy in GROUP_TO_STRATEGIES[group]
def test_all_groups_have_strategies(self):
"""Verify all groups have at least one strategy."""
for group in StrategyGroup:
strategies = get_strategies_in_group(group)
assert len(strategies) > 0, f"Group '{group}' has no strategies"
class TestGetAllGroups:
"""Tests for get_all_groups function."""
def test_returns_all_groups(self):
"""Verify all groups are returned."""
groups = get_all_groups()
assert len(groups) == 5
assert StrategyGroup.TREND_FOLLOWING in groups
assert StrategyGroup.MEAN_REVERSION in groups
assert StrategyGroup.MOMENTUM in groups
assert StrategyGroup.MARKET_MAKING in groups
assert StrategyGroup.SENTIMENT_BASED in groups
class TestGetStrategiesInGroup:
"""Tests for get_strategies_in_group function."""
def test_trend_following_strategies(self):
"""Test trend following group contains expected strategies."""
strategies = get_strategies_in_group(StrategyGroup.TREND_FOLLOWING)
assert "moving_average" in strategies
assert "macd" in strategies
assert "confirmed" in strategies
def test_mean_reversion_strategies(self):
"""Test mean reversion group contains expected strategies."""
strategies = get_strategies_in_group(StrategyGroup.MEAN_REVERSION)
assert "rsi" in strategies
assert "bollinger_mean_reversion" in strategies
assert "grid" in strategies
def test_momentum_strategies(self):
"""Test momentum group contains expected strategies."""
strategies = get_strategies_in_group(StrategyGroup.MOMENTUM)
assert "momentum" in strategies
assert "volatility_breakout" in strategies
class TestGetBestStrategyInGroup:
"""Tests for get_best_strategy_in_group function."""
def test_trend_following_high_adx(self):
"""Test trend following selection with high ADX."""
features = {"adx": 35.0, "rsi": 50.0, "atr_percent": 2.0, "volume_ratio": 1.0}
strategy, confidence = get_best_strategy_in_group(
StrategyGroup.TREND_FOLLOWING, features
)
assert strategy == "confirmed"
assert confidence > 0.5
def test_mean_reversion_extreme_rsi(self):
"""Test mean reversion selection with extreme RSI."""
features = {"adx": 15.0, "rsi": 25.0, "atr_percent": 1.5, "volume_ratio": 1.0}
strategy, confidence = get_best_strategy_in_group(
StrategyGroup.MEAN_REVERSION, features
)
assert strategy == "rsi"
assert confidence > 0.5
def test_momentum_high_volume(self):
"""Test momentum selection with high volume."""
features = {"adx": 25.0, "rsi": 55.0, "atr_percent": 3.0, "volume_ratio": 2.0}
strategy, confidence = get_best_strategy_in_group(
StrategyGroup.MOMENTUM, features
)
assert strategy == "volatility_breakout"
assert confidence > 0.5
def test_respects_available_strategies(self):
"""Test that only available strategies are selected."""
features = {"adx": 35.0, "rsi": 50.0}
# Only MACD available from trend following
strategy, confidence = get_best_strategy_in_group(
StrategyGroup.TREND_FOLLOWING,
features,
available_strategies=["macd"]
)
assert strategy == "macd"
def test_fallback_when_no_strategies_available(self):
"""Test fallback when no strategies in group are available."""
features = {"adx": 25.0, "rsi": 50.0}
strategy, confidence = get_best_strategy_in_group(
StrategyGroup.TREND_FOLLOWING,
features,
available_strategies=["some_other_strategy"]
)
# Should return safe default
assert strategy == "rsi"
assert confidence == 0.5
class TestConvertStrategyToGroupLabel:
"""Tests for convert_strategy_to_group_label function."""
def test_converts_known_strategies(self):
"""Test conversion of known strategies."""
assert convert_strategy_to_group_label("rsi") == "mean_reversion"
assert convert_strategy_to_group_label("macd") == "trend_following"
assert convert_strategy_to_group_label("momentum") == "momentum"
assert convert_strategy_to_group_label("dca") == "market_making"
assert convert_strategy_to_group_label("sentiment") == "sentiment_based"
def test_unknown_strategy_returns_original(self):
"""Test that unknown strategies return original name."""
assert convert_strategy_to_group_label("unknown") == "unknown"

View File

@@ -0,0 +1,2 @@
"""Backend API tests."""

View File

@@ -0,0 +1,2 @@
"""Backend API endpoint tests."""

View File

@@ -0,0 +1,379 @@
"""Tests for autopilot API endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock, AsyncMock
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_autopilot():
"""Mock autopilot instance."""
autopilot = Mock()
autopilot.symbol = "BTC/USD"
autopilot.is_running = False
autopilot.get_status.return_value = {
"symbol": "BTC/USD",
"running": False,
"interval": 60.0,
"has_market_data": True,
"market_data_length": 100,
"headlines_count": 5,
"last_sentiment_score": 0.5,
"last_pattern": "head_and_shoulders",
"last_signal": None,
}
autopilot.last_signal = None
autopilot.analyze_once.return_value = None
return autopilot
@pytest.fixture
def mock_intelligent_autopilot():
"""Mock intelligent autopilot instance."""
autopilot = Mock()
autopilot.symbol = "BTC/USD"
autopilot.is_running = False
autopilot.enable_auto_execution = False
autopilot.get_status.return_value = {
"symbol": "BTC/USD",
"timeframe": "1h",
"running": False,
"selected_strategy": None,
"trades_today": 0,
"max_trades_per_day": 10,
"min_confidence_threshold": 0.75,
"enable_auto_execution": False,
"last_analysis": None,
"model_info": {},
}
return autopilot
class TestUnifiedAutopilotEndpoints:
"""Tests for unified autopilot endpoints."""
@patch('backend.api.autopilot.get_autopilot_mode_info')
def test_get_modes(self, mock_get_mode_info, client):
"""Test getting autopilot mode information."""
mock_get_mode_info.return_value = {
"modes": {
"pattern": {"name": "Pattern-Based Autopilot"},
"intelligent": {"name": "ML-Based Autopilot"},
},
"comparison": {},
}
response = client.get("/api/autopilot/modes")
assert response.status_code == 200
data = response.json()
assert "modes" in data
assert "pattern" in data["modes"]
assert "intelligent" in data["modes"]
@patch('backend.api.autopilot.get_autopilot')
@patch('backend.api.autopilot.run_autopilot_loop')
def test_start_unified_pattern_mode(
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
):
"""Test starting unified autopilot in pattern mode."""
mock_get_autopilot.return_value = mock_autopilot
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "pattern",
"auto_execute": False,
"interval": 60.0,
"pattern_order": 5,
"auto_fetch_news": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "started"
assert data["mode"] == "pattern"
assert data["symbol"] == "BTC/USD"
assert data["auto_execute"] is False
mock_get_autopilot.assert_called_once()
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_start_unified_intelligent_mode(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test starting unified autopilot in intelligent mode."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "intelligent",
"auto_execute": True,
"exchange_id": 1,
"timeframe": "1h",
"interval": 60.0,
"paper_trading": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "started"
assert data["mode"] == "intelligent"
assert data["symbol"] == "BTC/USD"
assert data["auto_execute"] is True
assert mock_intelligent_autopilot.enable_auto_execution is True
mock_get_intelligent.assert_called_once()
def test_start_unified_invalid_mode(self, client):
"""Test starting unified autopilot with invalid mode."""
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "invalid_mode",
"auto_execute": False,
},
)
assert response.status_code == 400
assert "Invalid mode" in response.json()["detail"]
@patch('backend.api.autopilot.get_autopilot')
def test_stop_unified_pattern_mode(
self, mock_get_autopilot, client, mock_autopilot
):
"""Test stopping unified autopilot in pattern mode."""
mock_get_autopilot.return_value = mock_autopilot
response = client.post(
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=pattern"
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "stopped"
assert data["symbol"] == "BTC/USD"
assert data["mode"] == "pattern"
mock_autopilot.stop.assert_called_once()
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_stop_unified_intelligent_mode(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test stopping unified autopilot in intelligent mode."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.post(
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=intelligent&timeframe=1h"
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "stopped"
assert data["symbol"] == "BTC/USD"
assert data["mode"] == "intelligent"
mock_intelligent_autopilot.stop.assert_called_once()
def test_stop_unified_invalid_mode(self, client):
"""Test stopping unified autopilot with invalid mode."""
response = client.post(
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=invalid_mode"
)
assert response.status_code == 400
assert "Invalid mode" in response.json()["detail"]
@patch('backend.api.autopilot.get_autopilot')
def test_get_unified_status_pattern_mode(
self, mock_get_autopilot, client, mock_autopilot
):
"""Test getting unified autopilot status in pattern mode."""
mock_get_autopilot.return_value = mock_autopilot
response = client.get(
"/api/autopilot/status-unified/BTC/USD?mode=pattern"
)
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "BTC/USD"
assert data["mode"] == "pattern"
assert "running" in data
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_get_unified_status_intelligent_mode(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test getting unified autopilot status in intelligent mode."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.get(
"/api/autopilot/status-unified/BTC/USD?mode=intelligent&timeframe=1h"
)
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "BTC/USD"
assert data["mode"] == "intelligent"
assert "running" in data
def test_get_unified_status_invalid_mode(self, client):
"""Test getting unified autopilot status with invalid mode."""
response = client.get(
"/api/autopilot/status-unified/BTC/USD?mode=invalid_mode"
)
assert response.status_code == 400
assert "Invalid mode" in response.json()["detail"]
class TestModeSelection:
"""Tests for mode selection logic."""
@patch('backend.api.autopilot.get_autopilot_mode_info')
def test_mode_info_structure(self, mock_get_mode_info):
"""Test that mode info has correct structure."""
from src.autopilot import get_autopilot_mode_info
mode_info = get_autopilot_mode_info()
assert "modes" in mode_info
assert "pattern" in mode_info["modes"]
assert "intelligent" in mode_info["modes"]
assert "comparison" in mode_info
# Check pattern mode structure
pattern = mode_info["modes"]["pattern"]
assert "name" in pattern
assert "description" in pattern
assert "how_it_works" in pattern
assert "best_for" in pattern
assert "tradeoffs" in pattern
assert "features" in pattern
assert "requirements" in pattern
# Check intelligent mode structure
intelligent = mode_info["modes"]["intelligent"]
assert "name" in intelligent
assert "description" in intelligent
assert "how_it_works" in intelligent
assert "best_for" in intelligent
assert "tradeoffs" in intelligent
assert "features" in intelligent
assert "requirements" in intelligent
# Check comparison structure
comparison = mode_info["comparison"]
assert "transparency" in comparison
assert "adaptability" in comparison
assert "setup_time" in comparison
assert "resource_usage" in comparison
class TestAutoExecution:
"""Tests for auto-execution functionality."""
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_auto_execute_enabled(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test that auto-execute is set when enabled."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "intelligent",
"auto_execute": True,
"exchange_id": 1,
"timeframe": "1h",
},
)
assert response.status_code == 200
assert mock_intelligent_autopilot.enable_auto_execution is True
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_auto_execute_disabled(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test that auto-execute is not set when disabled."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.post(
"/api/autopilot/start-unified",
json={
"symbol": "BTC/USD",
"mode": "intelligent",
"auto_execute": False,
"exchange_id": 1,
"timeframe": "1h",
},
)
assert response.status_code == 200
# Note: enable_auto_execution may have a default value, so we check it's not True
# The actual behavior depends on the implementation
class TestBackwardCompatibility:
"""Tests for backward compatibility with old endpoints."""
@patch('backend.api.autopilot.get_autopilot')
@patch('backend.api.autopilot.run_autopilot_loop')
def test_old_start_endpoint_still_works(
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
):
"""Test that old /start endpoint still works (deprecated but functional)."""
mock_get_autopilot.return_value = mock_autopilot
response = client.post(
"/api/autopilot/start",
json={
"symbol": "BTC/USD",
"interval": 60.0,
"pattern_order": 5,
"auto_fetch_news": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "started"
assert data["symbol"] == "BTC/USD"
@patch('backend.api.autopilot.get_intelligent_autopilot')
def test_old_intelligent_start_endpoint_still_works(
self, mock_get_intelligent, client, mock_intelligent_autopilot
):
"""Test that old /intelligent/start endpoint still works (deprecated but functional)."""
mock_get_intelligent.return_value = mock_intelligent_autopilot
response = client.post(
"/api/autopilot/intelligent/start",
json={
"symbol": "BTC/USD",
"exchange_id": 1,
"timeframe": "1h",
"interval": 60.0,
"paper_trading": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "started"
assert data["symbol"] == "BTC/USD"

View File

@@ -0,0 +1,81 @@
"""Tests for exchanges API endpoints."""
import pytest
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from backend.main import app
from src.core.database import Exchange
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_exchange():
"""Mock exchange object."""
exchange = Mock(spec=Exchange)
exchange.id = 1
exchange.name = "coinbase"
exchange.is_enabled = True
exchange.api_permissions = "read_only"
return exchange
class TestListExchanges:
"""Tests for GET /api/exchanges."""
@patch('backend.api.exchanges.get_db')
def test_list_exchanges_success(self, mock_get_db, client):
"""Test listing exchanges."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.all.return_value = []
response = client.get("/api/exchanges")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
class TestGetExchange:
"""Tests for GET /api/exchanges/{exchange_id}."""
@patch('backend.api.exchanges.get_db')
def test_get_exchange_success(self, mock_get_db, client, mock_exchange):
"""Test getting exchange by ID."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_exchange
response = client.get("/api/exchanges/1")
assert response.status_code == 200
data = response.json()
assert data["id"] == 1
@patch('backend.api.exchanges.get_db')
def test_get_exchange_not_found(self, mock_get_db, client):
"""Test getting non-existent exchange."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = None
response = client.get("/api/exchanges/999")
assert response.status_code == 404
assert "Exchange not found" in response.json()["detail"]

View File

@@ -0,0 +1,65 @@
"""Tests for portfolio API endpoints."""
import pytest
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_portfolio_tracker():
"""Mock portfolio tracker."""
tracker = Mock()
tracker.get_current_portfolio.return_value = {
"positions": [],
"performance": {"total_return": 0.1},
"timestamp": "2025-01-01T00:00:00"
}
tracker.get_portfolio_history.return_value = {
"dates": ["2025-01-01"],
"values": [1000.0],
"pnl": [0.0]
}
return tracker
class TestGetCurrentPortfolio:
"""Tests for GET /api/portfolio/current."""
@patch('backend.api.portfolio.get_portfolio_tracker')
def test_get_current_portfolio_success(self, mock_get_tracker, client, mock_portfolio_tracker):
"""Test getting current portfolio."""
mock_get_tracker.return_value = mock_portfolio_tracker
response = client.get("/api/portfolio/current?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert "positions" in data
assert "performance" in data
assert "timestamp" in data
class TestGetPortfolioHistory:
"""Tests for GET /api/portfolio/history."""
@patch('backend.api.portfolio.get_portfolio_tracker')
def test_get_portfolio_history_success(self, mock_get_tracker, client, mock_portfolio_tracker):
"""Test getting portfolio history."""
mock_get_tracker.return_value = mock_portfolio_tracker
response = client.get("/api/portfolio/history?paper_trading=true&days=30")
assert response.status_code == 200
data = response.json()
assert "dates" in data
assert "values" in data
assert "pnl" in data

View File

@@ -0,0 +1,108 @@
"""Tests for strategies API endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from backend.main import app
from src.core.database import Strategy
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_strategy_registry():
"""Mock strategy registry."""
registry = Mock()
registry.list_available.return_value = ["RSIStrategy", "MACDStrategy"]
return registry
@pytest.fixture
def mock_strategy():
"""Mock strategy object."""
strategy = Mock(spec=Strategy)
strategy.id = 1
strategy.name = "Test Strategy"
strategy.type = "RSIStrategy"
strategy.status = "active" # Use string instead of enum
strategy.symbol = "BTC/USD"
strategy.params = {}
return strategy
class TestListStrategies:
"""Tests for GET /api/strategies."""
@patch('backend.api.strategies.get_db')
def test_list_strategies_success(self, mock_get_db, client):
"""Test listing strategies."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.all.return_value = []
response = client.get("/api/strategies")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
class TestGetStrategy:
"""Tests for GET /api/strategies/{strategy_id}."""
@patch('backend.api.strategies.get_db')
def test_get_strategy_success(self, mock_get_db, client, mock_strategy):
"""Test getting strategy by ID."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_strategy
response = client.get("/api/strategies/1")
assert response.status_code == 200
data = response.json()
assert data["id"] == 1
@patch('backend.api.strategies.get_db')
def test_get_strategy_not_found(self, mock_get_db, client):
"""Test getting non-existent strategy."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = None
response = client.get("/api/strategies/999")
assert response.status_code == 404
assert "Strategy not found" in response.json()["detail"]
class TestListAvailableStrategyTypes:
"""Tests for GET /api/strategies/types."""
@patch('backend.api.strategies.get_strategy_registry')
def test_list_strategy_types_success(self, mock_get_registry, client, mock_strategy_registry):
"""Test listing available strategy types."""
mock_get_registry.return_value = mock_strategy_registry
response = client.get("/api/strategies/types")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert "RSIStrategy" in data

View File

@@ -0,0 +1,293 @@
"""Tests for trading API endpoints."""
import pytest
from decimal import Decimal
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient
from datetime import datetime
from backend.main import app
from backend.core.schemas import OrderCreate, OrderSide, OrderType
from src.core.database import Order, OrderStatus
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_trading_engine():
"""Mock trading engine."""
engine = Mock()
engine.order_manager = Mock()
return engine
@pytest.fixture
def mock_order():
"""Mock order object."""
order = Mock()
order.id = 1
order.exchange_id = 1
order.strategy_id = None
order.symbol = "BTC/USD"
order.order_type = OrderType.MARKET
order.side = OrderSide.BUY
order.status = OrderStatus.FILLED
order.quantity = Decimal("0.1")
order.price = Decimal("50000")
order.filled_quantity = Decimal("0.1")
order.average_fill_price = Decimal("50000")
order.fee = Decimal("5")
order.paper_trading = True
order.created_at = datetime.now()
order.updated_at = datetime.now()
order.filled_at = datetime.now()
return order
class TestCreateOrder:
"""Tests for POST /api/trading/orders."""
@patch('backend.api.trading.get_trading_engine')
def test_create_order_success(self, mock_get_engine, client, mock_trading_engine, mock_order):
"""Test successful order creation."""
mock_get_engine.return_value = mock_trading_engine
mock_trading_engine.execute_order.return_value = mock_order
order_data = {
"exchange_id": 1,
"symbol": "BTC/USD",
"side": "buy",
"order_type": "market",
"quantity": "0.1",
"paper_trading": True
}
response = client.post("/api/trading/orders", json=order_data)
assert response.status_code == 200
data = response.json()
assert data["id"] == 1
assert data["symbol"] == "BTC/USD"
assert data["side"] == "buy"
mock_trading_engine.execute_order.assert_called_once()
@patch('backend.api.trading.get_trading_engine')
def test_create_order_execution_failed(self, mock_get_engine, client, mock_trading_engine):
"""Test order creation when execution fails."""
mock_get_engine.return_value = mock_trading_engine
mock_trading_engine.execute_order.return_value = None
order_data = {
"exchange_id": 1,
"symbol": "BTC/USD",
"side": "buy",
"order_type": "market",
"quantity": "0.1",
"paper_trading": True
}
response = client.post("/api/trading/orders", json=order_data)
assert response.status_code == 400
assert "Order execution failed" in response.json()["detail"]
@patch('backend.api.trading.get_trading_engine')
def test_create_order_invalid_data(self, client):
"""Test order creation with invalid data."""
order_data = {
"exchange_id": "invalid",
"symbol": "BTC/USD",
"side": "buy",
"order_type": "market",
"quantity": "0.1"
}
response = client.post("/api/trading/orders", json=order_data)
assert response.status_code == 422 # Validation error
class TestGetOrders:
"""Tests for GET /api/trading/orders."""
@patch('backend.api.trading.get_db')
def test_get_orders_success(self, mock_get_db, client, mock_database):
"""Test getting orders."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
# Create mock orders
mock_order1 = Mock(spec=Order)
mock_order1.id = 1
mock_order1.symbol = "BTC/USD"
mock_order1.paper_trading = True
mock_order2 = Mock(spec=Order)
mock_order2.id = 2
mock_order2.symbol = "ETH/USD"
mock_order2.paper_trading = True
mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_order1, mock_order2]
response = client.get("/api/trading/orders?paper_trading=true&limit=10")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
class TestGetOrder:
"""Tests for GET /api/trading/orders/{order_id}."""
@patch('backend.api.trading.get_db')
def test_get_order_success(self, mock_get_db, client, mock_database):
"""Test getting order by ID."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_order = Mock(spec=Order)
mock_order.id = 1
mock_order.symbol = "BTC/USD"
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
response = client.get("/api/trading/orders/1")
assert response.status_code == 200
data = response.json()
assert data["id"] == 1
@patch('backend.api.trading.get_db')
def test_get_order_not_found(self, mock_get_db, client):
"""Test getting non-existent order."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = None
response = client.get("/api/trading/orders/999")
assert response.status_code == 404
assert "Order not found" in response.json()["detail"]
class TestCancelOrder:
"""Tests for POST /api/trading/orders/{order_id}/cancel."""
@patch('backend.api.trading.get_trading_engine')
@patch('backend.api.trading.get_db')
def test_cancel_order_success(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
"""Test successful order cancellation."""
mock_get_engine.return_value = mock_trading_engine
mock_trading_engine.order_manager.cancel_order.return_value = True
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_order = Mock(spec=Order)
mock_order.id = 1
mock_order.status = OrderStatus.OPEN
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
response = client.post("/api/trading/orders/1/cancel")
assert response.status_code == 200
data = response.json()
assert data["status"] == "cancelled"
assert data["order_id"] == 1
@patch('backend.api.trading.get_db')
def test_cancel_order_not_found(self, mock_get_db, client):
"""Test cancelling non-existent order."""
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_session.query.return_value.filter_by.return_value.first.return_value = None
response = client.post("/api/trading/orders/999/cancel")
assert response.status_code == 404
assert "Order not found" in response.json()["detail"]
@patch('backend.api.trading.get_trading_engine')
@patch('backend.api.trading.get_db')
def test_cancel_order_already_filled(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
"""Test cancelling already filled order."""
mock_get_engine.return_value = mock_trading_engine
mock_db = Mock()
mock_session = Mock()
mock_db.get_session.return_value = mock_session
mock_get_db.return_value = mock_db
mock_order = Mock(spec=Order)
mock_order.id = 1
mock_order.status = OrderStatus.FILLED
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
response = client.post("/api/trading/orders/1/cancel")
assert response.status_code == 400
assert "cannot be cancelled" in response.json()["detail"]
class TestGetPositions:
"""Tests for GET /api/trading/positions."""
@patch('backend.api.trading.get_paper_trading')
def test_get_positions_paper_trading(self, mock_get_paper, client):
"""Test getting positions for paper trading."""
mock_paper = Mock()
mock_position = Mock()
mock_position.symbol = "BTC/USD"
mock_position.quantity = Decimal("0.1")
mock_position.entry_price = Decimal("50000")
mock_position.current_price = Decimal("51000")
mock_position.unrealized_pnl = Decimal("100")
mock_position.realized_pnl = Decimal("0")
mock_paper.get_positions.return_value = [mock_position]
mock_get_paper.return_value = mock_paper
response = client.get("/api/trading/positions?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
if len(data) > 0:
assert data[0]["symbol"] == "BTC/USD"
class TestGetBalance:
"""Tests for GET /api/trading/balance."""
@patch('backend.api.trading.get_paper_trading')
def test_get_balance_paper_trading(self, mock_get_paper, client):
"""Test getting balance for paper trading."""
mock_paper = Mock()
mock_paper.get_balance.return_value = Decimal("1000")
mock_paper.get_performance.return_value = {"total_return": 0.1}
mock_get_paper.return_value = mock_paper
response = client.get("/api/trading/balance?paper_trading=true")
assert response.status_code == 200
data = response.json()
assert "balance" in data
assert "performance" in data
assert data["balance"] == 1000.0

View File

@@ -0,0 +1,77 @@
"""Tests for backend dependencies."""
import pytest
from unittest.mock import patch, Mock
from backend.core.dependencies import (
get_database, get_trading_engine, get_portfolio_tracker,
get_strategy_registry, get_backtesting_engine, get_exchange_factory
)
class TestGetDatabase:
"""Tests for get_database dependency."""
def test_get_database_singleton(self):
"""Test that get_database returns same instance."""
db1 = get_database()
db2 = get_database()
# Should be cached/same instance due to lru_cache
assert db1 is db2
class TestGetTradingEngine:
"""Tests for get_trading_engine dependency."""
@patch('backend.core.dependencies.get_trading_engine')
def test_get_trading_engine(self, mock_get_engine):
"""Test getting trading engine."""
mock_engine = Mock()
mock_get_engine.return_value = mock_engine
engine = get_trading_engine()
assert engine is not None
class TestGetPortfolioTracker:
"""Tests for get_portfolio_tracker dependency."""
@patch('backend.core.dependencies.get_portfolio_tracker')
def test_get_portfolio_tracker(self, mock_get_tracker):
"""Test getting portfolio tracker."""
mock_tracker = Mock()
mock_get_tracker.return_value = mock_tracker
tracker = get_portfolio_tracker()
assert tracker is not None
class TestGetStrategyRegistry:
"""Tests for get_strategy_registry dependency."""
@patch('backend.core.dependencies.get_strategy_registry')
def test_get_strategy_registry(self, mock_get_registry):
"""Test getting strategy registry."""
mock_registry = Mock()
mock_get_registry.return_value = mock_registry
registry = get_strategy_registry()
assert registry is not None
class TestGetBacktestingEngine:
"""Tests for get_backtesting_engine dependency."""
@patch('backend.core.dependencies.get_backtesting_engine')
def test_get_backtesting_engine(self, mock_get_engine):
"""Test getting backtesting engine."""
mock_engine = Mock()
mock_get_engine.return_value = mock_engine
engine = get_backtesting_engine()
assert engine is not None
class TestGetExchangeFactory:
"""Tests for get_exchange_factory dependency."""
def test_get_exchange_factory(self):
"""Test getting exchange factory."""
factory = get_exchange_factory()
assert factory is not None

View File

@@ -0,0 +1,128 @@
"""Tests for Pydantic schemas."""
import pytest
from decimal import Decimal
from datetime import datetime
from pydantic import ValidationError
from backend.core.schemas import (
OrderCreate, OrderResponse, OrderSide, OrderType, OrderStatus,
PositionResponse, PortfolioResponse, PortfolioHistoryResponse
)
class TestOrderCreate:
"""Tests for OrderCreate schema."""
def test_order_create_valid(self):
"""Test valid order creation."""
order = OrderCreate(
exchange_id=1,
symbol="BTC/USD",
side=OrderSide.BUY,
order_type=OrderType.MARKET,
quantity=Decimal("0.1"),
paper_trading=True
)
assert order.exchange_id == 1
assert order.symbol == "BTC/USD"
assert order.side == OrderSide.BUY
def test_order_create_with_price(self):
"""Test order creation with price."""
order = OrderCreate(
exchange_id=1,
symbol="BTC/USD",
side=OrderSide.BUY,
order_type=OrderType.LIMIT,
quantity=Decimal("0.1"),
price=Decimal("50000"),
paper_trading=True
)
assert order.price == Decimal("50000")
def test_order_create_invalid_side(self):
"""Test order creation with invalid side."""
with pytest.raises(ValidationError):
OrderCreate(
exchange_id=1,
symbol="BTC/USD",
side="invalid",
order_type=OrderType.MARKET,
quantity=Decimal("0.1")
)
class TestOrderResponse:
"""Tests for OrderResponse schema."""
def test_order_response_from_dict(self):
"""Test creating OrderResponse from dictionary."""
order_data = {
"id": 1,
"exchange_id": 1,
"strategy_id": None,
"symbol": "BTC/USD",
"order_type": OrderType.MARKET,
"side": OrderSide.BUY,
"status": OrderStatus.FILLED,
"quantity": Decimal("0.1"),
"price": Decimal("50000"),
"filled_quantity": Decimal("0.1"),
"average_fill_price": Decimal("50000"),
"fee": Decimal("5"),
"paper_trading": True,
"created_at": datetime.now(),
"updated_at": datetime.now(),
"filled_at": datetime.now()
}
order = OrderResponse(**order_data)
assert order.id == 1
assert order.symbol == "BTC/USD"
class TestPositionResponse:
"""Tests for PositionResponse schema."""
def test_position_response_valid(self):
"""Test valid position response."""
position = PositionResponse(
symbol="BTC/USD",
quantity=Decimal("0.1"),
entry_price=Decimal("50000"),
current_price=Decimal("51000"),
unrealized_pnl=Decimal("100"),
realized_pnl=Decimal("0")
)
assert position.symbol == "BTC/USD"
assert position.unrealized_pnl == Decimal("100")
class TestPortfolioResponse:
"""Tests for PortfolioResponse schema."""
def test_portfolio_response_valid(self):
"""Test valid portfolio response."""
portfolio = PortfolioResponse(
positions=[],
performance={"total_return": 0.1},
timestamp="2025-01-01T00:00:00"
)
assert portfolio.positions == []
assert portfolio.performance["total_return"] == 0.1
class TestPortfolioHistoryResponse:
"""Tests for PortfolioHistoryResponse schema."""
def test_portfolio_history_response_valid(self):
"""Test valid portfolio history response."""
history = PortfolioHistoryResponse(
dates=["2025-01-01", "2025-01-02"],
values=[1000.0, 1100.0],
pnl=[0.0, 100.0]
)
assert len(history.dates) == 2
assert len(history.values) == 2

View File

@@ -0,0 +1,2 @@
"""Unit tests for backtesting."""

View File

@@ -0,0 +1,26 @@
"""Tests for backtesting engine."""
import pytest
from datetime import datetime, timedelta
from src.backtesting.engine import get_backtest_engine, BacktestingEngine
class TestBacktestingEngine:
"""Tests for BacktestingEngine."""
@pytest.fixture
def backtest_engine(self):
"""Create backtesting engine instance."""
return get_backtest_engine()
def test_engine_initialization(self, backtest_engine):
"""Test backtesting engine initialization."""
assert backtest_engine is not None
@pytest.mark.asyncio
async def test_run_backtest(self, backtest_engine):
"""Test running a backtest."""
# This would require a full strategy implementation
# Simplified test
assert backtest_engine is not None

View File

@@ -0,0 +1,85 @@
"""Tests for slippage model."""
import pytest
from decimal import Decimal
from src.backtesting.slippage import SlippageModel, FeeModel
class TestSlippageModel:
"""Tests for SlippageModel."""
@pytest.fixture
def slippage_model(self):
"""Create slippage model."""
return SlippageModel(slippage_rate=0.001)
def test_calculate_fill_price_market_buy(self, slippage_model):
"""Test fill price calculation for market buy."""
order_price = Decimal('50000.0')
market_price = Decimal('50000.0')
fill_price = slippage_model.calculate_fill_price(
order_price, "buy", "market", market_price
)
assert fill_price > market_price # Buy orders pay more
def test_calculate_fill_price_market_sell(self, slippage_model):
"""Test fill price calculation for market sell."""
order_price = Decimal('50000.0')
market_price = Decimal('50000.0')
fill_price = slippage_model.calculate_fill_price(
order_price, "sell", "market", market_price
)
assert fill_price < market_price # Sell orders receive less
def test_calculate_fill_price_limit(self, slippage_model):
"""Test fill price for limit orders."""
order_price = Decimal('49000.0')
market_price = Decimal('50000.0')
fill_price = slippage_model.calculate_fill_price(
order_price, "buy", "limit", market_price
)
assert fill_price == order_price # Limit orders fill at order price
class TestFeeModel:
"""Tests for FeeModel."""
@pytest.fixture
def fee_model(self):
"""Create fee model."""
return FeeModel(maker_fee=0.001, taker_fee=0.002)
def test_calculate_fee_maker(self, fee_model):
"""Test maker fee calculation."""
fee = fee_model.calculate_fee(
quantity=Decimal('0.01'),
price=Decimal('50000.0'),
is_maker=True
)
assert fee > 0
# Fee should be 0.1% of trade value
expected = Decimal('0.01') * Decimal('50000.0') * Decimal('0.001')
assert abs(float(fee - expected)) < 0.01
def test_calculate_fee_taker(self, fee_model):
"""Test taker fee calculation."""
fee = fee_model.calculate_fee(
quantity=Decimal('0.01'),
price=Decimal('50000.0'),
is_maker=False
)
assert fee > 0
# Taker fee should be higher than maker
maker_fee = fee_model.calculate_fee(
Decimal('0.01'), Decimal('50000.0'), is_maker=True
)
assert fee > maker_fee

View File

@@ -0,0 +1 @@
"""Test init file."""

View File

@@ -0,0 +1,44 @@
"""Tests for configuration system."""
import pytest
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
from src.core.config import Config, get_config
class TestConfig:
"""Tests for Config class."""
def test_config_initialization(self, tmp_path):
"""Test config initialization."""
config_file = tmp_path / "config.yaml"
config = Config(config_file=str(config_file))
assert config is not None
assert config.config_dir is not None
assert config.data_dir is not None
def test_config_get(self, tmp_path):
"""Test config get method."""
config_file = tmp_path / "config.yaml"
config = Config(config_file=str(config_file))
# Test nested key access
value = config.get('paper_trading.default_capital')
assert value is not None
def test_config_set(self, tmp_path):
"""Test config set method."""
config_file = tmp_path / "config.yaml"
config = Config(config_file=str(config_file))
config.set('paper_trading.default_capital', 200.0)
value = config.get('paper_trading.default_capital')
assert value == 200.0
def test_config_defaults(self, tmp_path):
"""Test default configuration values."""
config_file = tmp_path / "config.yaml"
config = Config(config_file=str(config_file))
assert config.get('paper_trading.default_capital') == 100.0
assert config.get('database.type') == 'postgresql'

View File

@@ -0,0 +1,97 @@
"""Tests for database system."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.core.database import (
get_database, Base, Exchange, Strategy, Trade, Position, Order
)
class TestDatabase:
"""Tests for database operations."""
def test_database_initialization(self):
"""Test database initialization."""
db = get_database()
assert db is not None
assert db.engine is not None
def test_table_creation(self, mock_database):
"""Test table creation."""
engine, Session = mock_database
# Verify tables exist
assert Base.metadata.tables.get('exchanges') is not None
assert Base.metadata.tables.get('strategies') is not None
assert Base.metadata.tables.get('trades') is not None
def test_exchange_model(self, mock_database):
"""Test Exchange model."""
engine, Session = mock_database
session = Session()
exchange = Exchange(
name="test_exchange",
api_key="encrypted_key",
secret_key="encrypted_secret",
api_permissions="read_only",
is_enabled=True
)
session.add(exchange)
session.commit()
retrieved = session.query(Exchange).filter_by(name="test_exchange").first()
assert retrieved is not None
assert retrieved.name == "test_exchange"
assert retrieved.api_permissions == "read_only"
session.close()
def test_strategy_model(self, mock_database):
"""Test Strategy model."""
engine, Session = mock_database
session = Session()
strategy = Strategy(
name="test_strategy",
strategy_type="RSI",
parameters='{"rsi_period": 14}',
is_enabled=True,
is_paper_trading=True
)
session.add(strategy)
session.commit()
retrieved = session.query(Strategy).filter_by(name="test_strategy").first()
assert retrieved is not None
assert retrieved.strategy_type == "RSI"
session.close()
def test_trade_model(self, mock_database):
"""Test Trade model."""
engine, Session = mock_database
session = Session()
trade = Trade(
order_id="test_order_123",
symbol="BTC/USD",
side="buy",
type="market",
price=50000.0,
amount=0.01,
cost=500.0,
fee=0.5,
status="filled",
is_paper_trade=True
)
session.add(trade)
session.commit()
retrieved = session.query(Trade).filter_by(order_id="test_order_123").first()
assert retrieved is not None
assert retrieved.symbol == "BTC/USD"
assert retrieved.status == "filled"
session.close()

View File

@@ -0,0 +1,43 @@
"""Tests for logging system."""
import pytest
import logging
from pathlib import Path
from unittest.mock import patch
from src.core.logger import setup_logging, get_logger
class TestLogger:
"""Tests for logging system."""
def test_logger_setup(self, test_log_dir):
"""Test logger setup."""
with patch('src.core.logger.get_config') as mock_get_config:
mock_config = mock_get_config.return_value
mock_config.get.side_effect = lambda key, default=None: {
'logging.dir': str(test_log_dir),
'logging.retention_days': 30,
'logging.level': 'INFO'
}.get(key, default)
setup_logging()
logger = get_logger('test')
assert logger is not None
assert isinstance(logger, logging.Logger)
def test_logger_get(self):
"""Test getting logger instance."""
logger = get_logger('test_module')
assert logger is not None
assert logger.name == 'test_module'
def test_logger_levels(self):
"""Test different log levels."""
logger = get_logger('test')
# Should not raise exceptions
logger.debug("Debug message")
logger.info("Info message")
logger.warning("Warning message")
logger.error("Error message")
logger.critical("Critical message")

View File

@@ -0,0 +1,92 @@
"""Tests for Redis client wrapper."""
import pytest
from unittest.mock import Mock, patch, MagicMock, AsyncMock
class TestRedisClient:
"""Tests for RedisClient class."""
@patch('src.core.redis.get_config')
def test_get_client_creates_connection(self, mock_config):
"""Test that get_client creates a Redis connection."""
# Setup mock config
mock_config.return_value.get.return_value = {
"host": "localhost",
"port": 6379,
"db": 0,
"password": None,
"socket_connect_timeout": 5
}
from src.core.redis import RedisClient
client = RedisClient()
# Should not have connected yet
assert client._client is None
# Get client should trigger connection
with patch('src.core.redis.redis.ConnectionPool') as mock_pool:
with patch('src.core.redis.redis.Redis') as mock_redis:
redis_client = client.get_client()
mock_pool.assert_called_once()
mock_redis.assert_called_once()
@patch('src.core.redis.get_config')
def test_get_client_reuses_existing(self, mock_config):
"""Test that get_client reuses existing connection."""
mock_config.return_value.get.return_value = {
"host": "localhost",
"port": 6379,
"db": 0,
}
from src.core.redis import RedisClient
client = RedisClient()
# Pre-set a mock client
mock_redis = Mock()
client._client = mock_redis
# Should return existing
result = client.get_client()
assert result is mock_redis
@patch('src.core.redis.get_config')
@pytest.mark.asyncio
async def test_close_connection(self, mock_config):
"""Test closing Redis connection."""
mock_config.return_value.get.return_value = {"host": "localhost"}
from src.core.redis import RedisClient
client = RedisClient()
mock_redis = AsyncMock()
client._client = mock_redis
await client.close()
mock_redis.aclose.assert_called_once()
class TestGetRedisClient:
"""Tests for get_redis_client singleton."""
@patch('src.core.redis.get_config')
def test_returns_singleton(self, mock_config):
"""Test that get_redis_client returns same instance."""
mock_config.return_value.get.return_value = {"host": "localhost"}
# Reset the global
import src.core.redis as redis_module
redis_module._redis_client = None
from src.core.redis import get_redis_client
client1 = get_redis_client()
client2 = get_redis_client()
assert client1 is client2

View File

@@ -0,0 +1 @@
"""Test init file."""

View File

@@ -0,0 +1,139 @@
"""Unit tests for CCXT pricing provider."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
from datetime import datetime
from src.data.providers.ccxt_provider import CCXTProvider
@pytest.fixture
def mock_ccxt_exchange():
"""Create a mock CCXT exchange."""
exchange = Mock()
exchange.markets = {
'BTC/USDT': {},
'ETH/USDT': {},
'BTC/USD': {},
}
exchange.id = 'kraken'
exchange.fetch_ticker = Mock(return_value={
'bid': 50000.0,
'ask': 50001.0,
'last': 50000.5,
'high': 51000.0,
'low': 49000.0,
'quoteVolume': 1000000.0,
'timestamp': 1609459200000,
})
exchange.fetch_ohlcv = Mock(return_value=[
[1609459200000, 50000, 51000, 49000, 50000, 1000],
])
exchange.load_markets = Mock(return_value=exchange.markets)
return exchange
@pytest.fixture
def provider():
"""Create a CCXT provider instance."""
return CCXTProvider(exchange_name='kraken')
class TestCCXTProvider:
"""Tests for CCXTProvider."""
def test_init(self, provider):
"""Test provider initialization."""
assert provider.name == "CCXT Provider"
assert not provider._connected
assert provider.exchange is None
def test_name_property(self, provider):
"""Test name property."""
provider._selected_exchange_id = 'kraken'
assert provider.name == "CCXT-Kraken"
@patch('src.data.providers.ccxt_provider.ccxt')
def test_connect_success(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test successful connection."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
result = provider.connect()
assert result is True
assert provider._connected is True
assert provider.exchange == mock_ccxt_exchange
assert provider._selected_exchange_id == 'kraken'
@patch('src.data.providers.ccxt_provider.ccxt')
def test_connect_failure(self, mock_ccxt, provider):
"""Test connection failure."""
mock_ccxt.kraken = Mock(side_effect=Exception("Connection failed"))
result = provider.connect()
assert result is False
assert not provider._connected
@patch('src.data.providers.ccxt_provider.ccxt')
def test_get_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test getting ticker data."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
ticker = provider.get_ticker('BTC/USDT')
assert ticker['symbol'] == 'BTC/USDT'
assert isinstance(ticker['bid'], Decimal)
assert isinstance(ticker['last'], Decimal)
assert ticker['last'] > 0
@patch('src.data.providers.ccxt_provider.ccxt')
def test_get_ohlcv(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test getting OHLCV data."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
ohlcv = provider.get_ohlcv('BTC/USDT', '1h', limit=10)
assert len(ohlcv) > 0
assert len(ohlcv[0]) == 6 # timestamp, open, high, low, close, volume
@patch('src.data.providers.ccxt_provider.ccxt')
def test_subscribe_ticker(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test subscribing to ticker updates."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
callback = Mock()
result = provider.subscribe_ticker('BTC/USDT', callback)
assert result is True
assert 'ticker_BTC/USDT' in provider._subscribers
def test_normalize_symbol(self, provider):
"""Test symbol normalization."""
# Test with exchange
with patch.object(provider, 'exchange') as mock_exchange:
mock_exchange.markets = {'BTC/USDT': {}}
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'
# Test without exchange
provider.exchange = None
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'
@patch('src.data.providers.ccxt_provider.ccxt')
def test_disconnect(self, mock_ccxt, provider, mock_ccxt_exchange):
"""Test disconnection."""
mock_ccxt.kraken = Mock(return_value=mock_ccxt_exchange)
provider.connect()
provider.subscribe_ticker('BTC/USDT', Mock())
provider.disconnect()
assert not provider._connected
assert provider.exchange is None
assert len(provider._subscribers) == 0

View File

@@ -0,0 +1,113 @@
"""Unit tests for CoinGecko pricing provider."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
import httpx
from src.data.providers.coingecko_provider import CoinGeckoProvider
@pytest.fixture
def provider():
"""Create a CoinGecko provider instance."""
return CoinGeckoProvider()
@pytest.fixture
def mock_response():
"""Create a mock HTTP response."""
response = Mock()
response.status_code = 200
response.json = Mock(return_value={
'bitcoin': {
'usd': 50000.0,
'usd_24h_change': 2.5,
'usd_24h_vol': 1000000.0,
}
})
return response
class TestCoinGeckoProvider:
"""Tests for CoinGeckoProvider."""
def test_init(self, provider):
"""Test provider initialization."""
assert provider.name == "CoinGecko"
assert not provider.supports_websocket
assert not provider._connected
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_connect_success(self, mock_client_class, provider, mock_response):
"""Test successful connection."""
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
result = provider.connect()
assert result is True
assert provider._connected is True
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_connect_failure(self, mock_client_class, provider):
"""Test connection failure."""
mock_client = Mock()
mock_client.get = Mock(side_effect=Exception("Connection failed"))
mock_client_class.return_value = mock_client
result = provider.connect()
assert result is False
assert not provider._connected
def test_parse_symbol(self, provider):
"""Test symbol parsing."""
coin_id, currency = provider._parse_symbol('BTC/USD')
assert coin_id == 'bitcoin'
assert currency == 'usd'
coin_id, currency = provider._parse_symbol('ETH/USDT')
assert coin_id == 'ethereum'
assert currency == 'usd' # USDT maps to USD
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_get_ticker(self, mock_client_class, provider, mock_response):
"""Test getting ticker data."""
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
provider.connect()
ticker = provider.get_ticker('BTC/USD')
assert ticker['symbol'] == 'BTC/USD'
assert isinstance(ticker['last'], Decimal)
assert ticker['last'] > 0
assert 'timestamp' in ticker
@patch('src.data.providers.coingecko_provider.httpx.Client')
def test_get_ohlcv(self, mock_client_class, provider):
"""Test getting OHLCV data."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=[
[1609459200000, 50000, 51000, 49000, 50000],
])
mock_client = Mock()
mock_client.get = Mock(return_value=mock_response)
mock_client_class.return_value = mock_client
provider.connect()
ohlcv = provider.get_ohlcv('BTC/USD', '1h', limit=10)
assert len(ohlcv) > 0
# CoinGecko returns 5 elements, we add volume as 0
assert len(ohlcv[0]) == 6
def test_normalize_symbol(self, provider):
"""Test symbol normalization."""
normalized = provider.normalize_symbol('btc-usdt')
assert normalized == 'BTC/USDT'

View File

@@ -0,0 +1,120 @@
"""Unit tests for cache manager."""
import pytest
import time
from src.data.cache_manager import CacheManager, CacheEntry
class TestCacheEntry:
"""Tests for CacheEntry."""
def test_init(self):
"""Test cache entry initialization."""
entry = CacheEntry("test_data", 60.0)
assert entry.data == "test_data"
assert entry.expires_at > time.time()
assert entry.access_count == 0
def test_is_expired(self):
"""Test expiration checking."""
entry = CacheEntry("test_data", 0.01) # Very short TTL
assert not entry.is_expired()
time.sleep(0.02)
assert entry.is_expired()
def test_touch(self):
"""Test access tracking."""
entry = CacheEntry("test_data", 60.0)
initial_count = entry.access_count
entry.touch()
assert entry.access_count == initial_count + 1
class TestCacheManager:
"""Tests for CacheManager."""
@pytest.fixture
def cache(self):
"""Create a cache manager instance."""
return CacheManager(default_ttl=1.0, max_size=10)
def test_get_set(self, cache):
"""Test basic get and set operations."""
cache.set("key1", "value1")
assert cache.get("key1") == "value1"
def test_get_missing(self, cache):
"""Test getting non-existent key."""
assert cache.get("missing") is None
def test_expiration(self, cache):
"""Test cache entry expiration."""
cache.set("key1", "value1", ttl=0.1)
assert cache.get("key1") == "value1"
time.sleep(0.2)
assert cache.get("key1") is None
def test_lru_eviction(self, cache):
"""Test LRU eviction when max size reached."""
# Fill cache to max size
for i in range(10):
cache.set(f"key{i}", f"value{i}")
# Add one more - should evict oldest
cache.set("key10", "value10")
# Oldest key should be evicted
assert cache.get("key0") is None
assert cache.get("key10") == "value10"
def test_type_specific_ttl(self, cache):
"""Test type-specific TTL."""
cache.set("ticker1", {"price": 100}, cache_type='ticker')
cache.set("ohlcv1", [[1, 2, 3, 4, 5, 6]], cache_type='ohlcv')
# Both should be cached
assert cache.get("ticker1") is not None
assert cache.get("ohlcv1") is not None
def test_delete(self, cache):
"""Test cache entry deletion."""
cache.set("key1", "value1")
assert cache.get("key1") == "value1"
cache.delete("key1")
assert cache.get("key1") is None
def test_clear(self, cache):
"""Test cache clearing."""
cache.set("key1", "value1")
cache.set("key2", "value2")
cache.clear()
assert cache.get("key1") is None
assert cache.get("key2") is None
def test_stats(self, cache):
"""Test cache statistics."""
cache.set("key1", "value1")
cache.get("key1") # Hit
cache.get("missing") # Miss
stats = cache.get_stats()
assert stats['hits'] >= 1
assert stats['misses'] >= 1
assert stats['size'] == 1
assert 'hit_rate' in stats
def test_invalidate_pattern(self, cache):
"""Test pattern-based invalidation."""
cache.set("ticker:BTC/USD", "value1")
cache.set("ticker:ETH/USD", "value2")
cache.set("ohlcv:BTC/USD", "value3")
cache.invalidate_pattern("ticker:")
assert cache.get("ticker:BTC/USD") is None
assert cache.get("ticker:ETH/USD") is None
assert cache.get("ohlcv:BTC/USD") is not None

View File

@@ -0,0 +1,145 @@
"""Unit tests for health monitor."""
import pytest
from datetime import datetime, timedelta
from src.data.health_monitor import HealthMonitor, HealthMetrics, HealthStatus
class TestHealthMetrics:
"""Tests for HealthMetrics."""
def test_record_success(self):
"""Test recording successful operation."""
metrics = HealthMetrics()
metrics.record_success(0.5)
assert metrics.status == HealthStatus.HEALTHY
assert metrics.success_count == 1
assert metrics.consecutive_failures == 0
assert len(metrics.response_times) == 1
def test_record_failure(self):
"""Test recording failed operation."""
metrics = HealthMetrics()
metrics.record_failure()
assert metrics.failure_count == 1
assert metrics.consecutive_failures == 1
def test_circuit_breaker(self):
"""Test circuit breaker opening."""
metrics = HealthMetrics()
# Record 5 failures
for _ in range(5):
metrics.record_failure()
assert metrics.circuit_breaker_open is True
assert metrics.consecutive_failures == 5
def test_should_attempt(self):
"""Test should_attempt logic."""
metrics = HealthMetrics()
# Should attempt if circuit breaker not open
assert metrics.should_attempt() is True
# Open circuit breaker
for _ in range(5):
metrics.record_failure()
# Should not attempt immediately
assert metrics.should_attempt(circuit_breaker_timeout=60) is False
def test_get_avg_response_time(self):
"""Test average response time calculation."""
metrics = HealthMetrics()
metrics.response_times.extend([0.1, 0.2, 0.3])
avg = metrics.get_avg_response_time()
assert avg == 0.2
class TestHealthMonitor:
"""Tests for HealthMonitor."""
@pytest.fixture
def monitor(self):
"""Create a health monitor instance."""
return HealthMonitor()
def test_record_success(self, monitor):
"""Test recording success."""
monitor.record_success("provider1", 0.5)
metrics = monitor.get_metrics("provider1")
assert metrics is not None
assert metrics.status == HealthStatus.HEALTHY
assert metrics.success_count == 1
def test_record_failure(self, monitor):
"""Test recording failure."""
monitor.record_failure("provider1")
metrics = monitor.get_metrics("provider1")
assert metrics is not None
assert metrics.failure_count == 1
assert metrics.consecutive_failures == 1
def test_is_healthy(self, monitor):
"""Test health checking."""
# No metrics yet - assume healthy
assert monitor.is_healthy("provider1") is True
# Record success
monitor.record_success("provider1", 0.5)
assert monitor.is_healthy("provider1") is True
# Record many failures
for _ in range(10):
monitor.record_failure("provider1")
assert monitor.is_healthy("provider1") is False
def test_get_health_status(self, monitor):
"""Test getting health status."""
monitor.record_success("provider1", 0.5)
status = monitor.get_health_status("provider1")
assert status == HealthStatus.HEALTHY
def test_select_healthiest(self, monitor):
"""Test selecting healthiest provider."""
# Make provider1 healthy
monitor.record_success("provider1", 0.1)
monitor.record_success("provider1", 0.2)
# Make provider2 unhealthy
monitor.record_failure("provider2")
monitor.record_failure("provider2")
monitor.record_failure("provider2")
healthiest = monitor.select_healthiest(["provider1", "provider2"])
assert healthiest == "provider1"
def test_reset_circuit_breaker(self, monitor):
"""Test resetting circuit breaker."""
# Open circuit breaker
for _ in range(5):
monitor.record_failure("provider1")
assert monitor.get_metrics("provider1").circuit_breaker_open is True
monitor.reset_circuit_breaker("provider1")
metrics = monitor.get_metrics("provider1")
assert metrics.circuit_breaker_open is False
assert metrics.consecutive_failures == 0
def test_reset_metrics(self, monitor):
"""Test resetting metrics."""
monitor.record_success("provider1", 0.5)
assert monitor.get_metrics("provider1") is not None
monitor.reset_metrics("provider1")
assert monitor.get_metrics("provider1") is None

View File

@@ -0,0 +1,68 @@
"""Tests for technical indicators."""
import pytest
import pandas as pd
import numpy as np
from src.data.indicators import get_indicators, TechnicalIndicators
class TestTechnicalIndicators:
"""Tests for TechnicalIndicators."""
@pytest.fixture
def indicators(self):
"""Create indicators instance."""
return get_indicators()
@pytest.fixture
def sample_data(self):
"""Create sample price data."""
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
return pd.DataFrame({
'close': [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)],
'high': [101 + i * 0.1 for i in range(100)],
'low': [99 + i * 0.1 for i in range(100)],
'open': [100 + i * 0.1 for i in range(100)],
'volume': [1000.0] * 100
})
def test_sma(self, indicators, sample_data):
"""Test Simple Moving Average."""
sma = indicators.sma(sample_data['close'], period=20)
assert len(sma) == len(sample_data)
assert not sma.isna().all() # Should have some valid values
def test_ema(self, indicators, sample_data):
"""Test Exponential Moving Average."""
ema = indicators.ema(sample_data['close'], period=20)
assert len(ema) == len(sample_data)
def test_rsi(self, indicators, sample_data):
"""Test Relative Strength Index."""
rsi = indicators.rsi(sample_data['close'], period=14)
assert len(rsi) == len(sample_data)
# RSI should be between 0 and 100
valid_rsi = rsi.dropna()
if len(valid_rsi) > 0:
assert (valid_rsi >= 0).all()
assert (valid_rsi <= 100).all()
def test_macd(self, indicators, sample_data):
"""Test MACD."""
macd_result = indicators.macd(sample_data['close'], fast=12, slow=26, signal=9)
assert 'macd' in macd_result
assert 'signal' in macd_result
assert 'histogram' in macd_result
def test_bollinger_bands(self, indicators, sample_data):
"""Test Bollinger Bands."""
bb = indicators.bollinger_bands(sample_data['close'], period=20, std_dev=2)
assert 'upper' in bb
assert 'middle' in bb
assert 'lower' in bb
# Upper should be above middle, middle above lower
valid_data = bb.dropna()
if len(valid_data) > 0:
assert (valid_data['upper'] >= valid_data['middle']).all()
assert (valid_data['middle'] >= valid_data['lower']).all()

View File

@@ -0,0 +1,80 @@
"""Tests for divergence detection in indicators."""
import pytest
import pandas as pd
import numpy as np
from src.data.indicators import get_indicators
class TestDivergenceDetection:
"""Tests for divergence detection."""
@pytest.fixture
def indicators(self):
"""Create indicators instance."""
return get_indicators()
@pytest.fixture
def sample_data(self):
"""Create sample price data with clear trend."""
dates = pd.date_range(start='2025-01-01', periods=100, freq='1H')
# Create price data with trend
prices = [100 + i * 0.1 + np.random.randn() * 0.5 for i in range(100)]
return pd.Series(prices, index=dates)
def test_detect_divergence_insufficient_data(self, indicators):
"""Test divergence detection with insufficient data."""
prices = pd.Series([100, 101, 102])
indicator = pd.Series([50, 51, 52])
result = indicators.detect_divergence(prices, indicator, lookback=20)
assert result['type'] is None
assert result['confidence'] == 0.0
def test_detect_divergence_structure(self, indicators, sample_data):
"""Test divergence detection returns correct structure."""
# Create indicator data
indicator = pd.Series([50 + i * 0.1 for i in range(100)], index=sample_data.index)
result = indicators.detect_divergence(sample_data, indicator, lookback=20)
# Check structure
assert 'type' in result
assert 'confidence' in result
assert 'price_swing_high' in result
assert 'price_swing_low' in result
assert 'indicator_swing_high' in result
assert 'indicator_swing_low' in result
# Type should be None, 'bullish', or 'bearish'
assert result['type'] in [None, 'bullish', 'bearish']
# Confidence should be 0.0 to 1.0
assert 0.0 <= result['confidence'] <= 1.0
def test_detect_divergence_with_trend(self, indicators):
"""Test divergence detection with clear trend data."""
# Create price making lower lows
prices = pd.Series([100, 95, 90, 85, 80])
# Create indicator making higher lows (bullish divergence)
indicator = pd.Series([30, 32, 34, 36, 38])
# Need more data for lookback
prices_long = pd.concat([pd.Series([110] * 30), prices])
indicator_long = pd.concat([pd.Series([25] * 30), indicator])
result = indicators.detect_divergence(
prices_long,
indicator_long,
lookback=5,
min_swings=2
)
# Should detect bullish divergence (price down, indicator up)
# Note: This may not always detect due to swing detection logic
assert result is not None
assert 'type' in result
assert 'confidence' in result

View File

@@ -0,0 +1,135 @@
"""Unit tests for pricing service."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
from src.data.pricing_service import PricingService, get_pricing_service
from src.data.providers.base_provider import BasePricingProvider
class MockProvider(BasePricingProvider):
"""Mock provider for testing."""
@property
def name(self) -> str:
return "MockProvider"
@property
def supports_websocket(self) -> bool:
return False
def connect(self) -> bool:
self._connected = True
return True
def disconnect(self):
self._connected = False
def get_ticker(self, symbol: str):
return {
'symbol': symbol,
'bid': Decimal('50000'),
'ask': Decimal('50001'),
'last': Decimal('50000.5'),
'high': Decimal('51000'),
'low': Decimal('49000'),
'volume': Decimal('1000000'),
'timestamp': 1609459200000,
}
def get_ohlcv(self, symbol: str, timeframe: str = '1h', since=None, limit: int = 100):
return [[1609459200000, 50000, 51000, 49000, 50000, 1000]]
def subscribe_ticker(self, symbol: str, callback) -> bool:
if symbol not in self._subscribers:
self._subscribers[symbol] = []
self._subscribers[symbol].append(callback)
return True
@pytest.fixture
def mock_config():
"""Create a mock configuration."""
config = Mock()
config.get = Mock(side_effect=lambda key, default=None: {
"data_providers.primary": [
{"name": "mock", "enabled": True, "priority": 1}
],
"data_providers.fallback": {"enabled": True, "api_key": ""},
"data_providers.caching.ticker_ttl": 2,
"data_providers.caching.ohlcv_ttl": 60,
"data_providers.caching.max_cache_size": 1000,
}.get(key, default))
return config
class TestPricingService:
"""Tests for PricingService."""
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
@patch('src.data.providers.coingecko_provider.CoinGeckoProvider')
def test_init(self, mock_coingecko, mock_ccxt, mock_get_config, mock_config):
"""Test service initialization."""
mock_get_config.return_value = mock_config
mock_ccxt_instance = MockProvider()
mock_ccxt.return_value = mock_ccxt_instance
mock_coingecko_instance = MockProvider()
mock_coingecko.return_value = mock_coingecko_instance
service = PricingService()
assert service.cache is not None
assert service.health_monitor is not None
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_get_ticker(self, mock_ccxt, mock_get_config, mock_config):
"""Test getting ticker data."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
ticker = service.get_ticker("BTC/USD")
assert ticker['symbol'] == "BTC/USD"
assert isinstance(ticker['last'], Decimal)
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_get_ohlcv(self, mock_ccxt, mock_get_config, mock_config):
"""Test getting OHLCV data."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
ohlcv = service.get_ohlcv("BTC/USD", "1h", limit=10)
assert len(ohlcv) > 0
assert len(ohlcv[0]) == 6
@patch('src.data.pricing_service.get_config')
@patch('src.data.providers.ccxt_provider.CCXTProvider')
def test_subscribe_ticker(self, mock_ccxt, mock_get_config, mock_config):
"""Test subscribing to ticker updates."""
mock_get_config.return_value = mock_config
mock_provider = MockProvider()
mock_ccxt.return_value = mock_provider
service = PricingService()
service._providers["MockProvider"] = mock_provider
service._active_provider = "MockProvider"
callback = Mock()
result = service.subscribe_ticker("BTC/USD", callback)
assert result is True

View File

@@ -0,0 +1,118 @@
"""Tests for Redis cache."""
import pytest
from unittest.mock import Mock, patch, AsyncMock
class TestRedisCache:
"""Tests for RedisCache class."""
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ticker_cache_hit(self, mock_get_client):
"""Test getting cached ticker data."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = '{"price": 45000.0, "symbol": "BTC/USD"}'
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ticker("BTC/USD")
assert result is not None
assert result["price"] == 45000.0
mock_client.get.assert_called_once()
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ticker_cache_miss(self, mock_get_client):
"""Test ticker cache miss."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = None
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ticker("BTC/USD")
assert result is None
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_set_ticker(self, mock_get_client):
"""Test setting ticker cache."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.set_ticker("BTC/USD", {"price": 45000.0})
assert result is True
mock_client.setex.assert_called_once()
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_get_ohlcv(self, mock_get_client):
"""Test getting cached OHLCV data."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_client.get.return_value = '[[1700000000, 45000, 45500, 44500, 45200, 1000]]'
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
result = await cache.get_ohlcv("BTC/USD", "1h", 100)
assert result is not None
assert len(result) == 1
assert result[0][0] == 1700000000
@patch('src.data.redis_cache.get_redis_client')
@pytest.mark.asyncio
async def test_set_ohlcv(self, mock_get_client):
"""Test setting OHLCV cache."""
mock_redis = Mock()
mock_client = AsyncMock()
mock_redis.get_client.return_value = mock_client
mock_get_client.return_value = mock_redis
from src.data.redis_cache import RedisCache
cache = RedisCache()
ohlcv_data = [[1700000000, 45000, 45500, 44500, 45200, 1000]]
result = await cache.set_ohlcv("BTC/USD", "1h", ohlcv_data)
assert result is True
mock_client.setex.assert_called_once()
class TestGetRedisCache:
"""Tests for get_redis_cache singleton."""
@patch('src.data.redis_cache.get_redis_client')
def test_returns_singleton(self, mock_get_client):
"""Test that get_redis_cache returns same instance."""
mock_get_client.return_value = Mock()
# Reset the global
import src.data.redis_cache as cache_module
cache_module._redis_cache = None
from src.data.redis_cache import get_redis_cache
cache1 = get_redis_cache()
cache2 = get_redis_cache()
assert cache1 is cache2

View File

@@ -0,0 +1,2 @@
"""Unit tests for exchange adapters."""

View File

@@ -0,0 +1,24 @@
"""Tests for base exchange adapter."""
import pytest
from unittest.mock import Mock, AsyncMock
from src.exchanges.base import BaseExchange
class TestBaseExchange:
"""Tests for BaseExchange abstract class."""
def test_base_exchange_init(self):
"""Test base exchange initialization."""
# Can't instantiate abstract class, test through concrete implementation
from src.exchanges.coinbase import CoinbaseExchange
exchange = CoinbaseExchange(
name="test",
api_key="test_key",
secret_key="test_secret"
)
assert exchange.name == "test"
assert exchange.api_key == "test_key"
assert not exchange.is_connected

View File

@@ -0,0 +1,56 @@
"""Tests for Coinbase exchange adapter."""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from src.exchanges.coinbase import CoinbaseExchange
class TestCoinbaseExchange:
"""Tests for CoinbaseExchange."""
@pytest.fixture
def exchange(self):
"""Create Coinbase exchange instance."""
return CoinbaseExchange(
name="test_coinbase",
api_key="test_key",
secret_key="test_secret",
permissions="read_only"
)
@pytest.mark.asyncio
async def test_connect(self, exchange):
"""Test connection to Coinbase."""
with patch.object(exchange.exchange, 'load_markets', new_callable=AsyncMock):
await exchange.connect()
assert exchange.is_connected
@pytest.mark.asyncio
async def test_fetch_balance(self, exchange):
"""Test fetching balance."""
mock_balance = {'USD': {'free': 1000.0, 'used': 0.0, 'total': 1000.0}}
exchange.exchange.fetch_balance = AsyncMock(return_value=mock_balance)
exchange.is_connected = True
balance = await exchange.fetch_balance()
assert balance == mock_balance
@pytest.mark.asyncio
async def test_place_order_readonly(self, exchange):
"""Test placing order with read-only permissions."""
exchange.permissions = "read_only"
exchange.is_connected = True
with pytest.raises(PermissionError):
await exchange.place_order("BTC/USD", "buy", "market", 0.01)
@pytest.mark.asyncio
async def test_fetch_ohlcv(self, exchange):
"""Test fetching OHLCV data."""
mock_ohlcv = [[1609459200000, 29000, 29500, 28800, 29300, 1000]]
exchange.exchange.fetch_ohlcv = AsyncMock(return_value=mock_ohlcv)
exchange.is_connected = True
ohlcv = await exchange.fetch_ohlcv("BTC/USD", "1h")
assert ohlcv == mock_ohlcv

View File

@@ -0,0 +1,40 @@
"""Tests for exchange factory."""
import pytest
from unittest.mock import patch, Mock
from src.exchanges.factory import ExchangeFactory
from src.exchanges.coinbase import CoinbaseExchange
class TestExchangeFactory:
"""Tests for ExchangeFactory."""
def test_register_exchange(self):
"""Test exchange registration."""
ExchangeFactory.register_exchange("test_exchange", CoinbaseExchange)
assert "test_exchange" in ExchangeFactory.list_available()
def test_get_exchange(self):
"""Test getting exchange instance."""
with patch('src.exchanges.factory.get_key_manager') as mock_km:
mock_km.return_value.get_exchange_keys.return_value = {
'api_key': 'test_key',
'secret_key': 'test_secret',
'permissions': 'read_only'
}
exchange = ExchangeFactory.get_exchange("coinbase")
assert exchange is not None
assert isinstance(exchange, CoinbaseExchange)
def test_get_nonexistent_exchange(self):
"""Test getting non-existent exchange."""
with pytest.raises(ValueError, match="not registered"):
ExchangeFactory.get_exchange("nonexistent")
def test_list_available(self):
"""Test listing available exchanges."""
exchanges = ExchangeFactory.list_available()
assert isinstance(exchanges, list)
assert "coinbase" in exchanges

View File

@@ -0,0 +1,44 @@
"""Tests for WebSocket functionality."""
import pytest
from unittest.mock import Mock, patch
from src.exchanges.coinbase import CoinbaseAdapter
def test_subscribe_ticker():
"""Test ticker subscription."""
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
callback = Mock()
adapter.subscribe_ticker("BTC/USD", callback)
assert f'ticker_BTC/USD' in adapter._ws_callbacks
assert adapter._ws_callbacks[f'ticker_BTC/USD'] == callback
def test_subscribe_orderbook():
"""Test orderbook subscription."""
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
callback = Mock()
adapter.subscribe_orderbook("BTC/USD", callback)
assert f'orderbook_BTC/USD' in adapter._ws_callbacks
def test_subscribe_trades():
"""Test trades subscription."""
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
callback = Mock()
adapter.subscribe_trades("BTC/USD", callback)
assert f'trades_BTC/USD' in adapter._ws_callbacks
def test_normalize_symbol():
"""Test symbol normalization."""
adapter = CoinbaseAdapter("test_key", "test_secret", sandbox=True)
assert adapter.normalize_symbol("BTC/USD") == "BTC-USD"
assert adapter.normalize_symbol("ETH/USDT") == "ETH-USDT"

View File

@@ -0,0 +1,2 @@
"""Unit tests for optimization."""

View File

@@ -0,0 +1,44 @@
"""Tests for grid search optimization."""
import pytest
from src.optimization.grid_search import GridSearchOptimizer
class TestGridSearchOptimizer:
"""Tests for GridSearchOptimizer."""
@pytest.fixture
def optimizer(self):
"""Create grid search optimizer."""
return GridSearchOptimizer()
def test_optimize_maximize(self, optimizer):
"""Test optimization with maximize."""
param_grid = {
'param1': [1, 2, 3],
'param2': [10, 20]
}
def objective(params):
return params['param1'] * params['param2']
result = optimizer.optimize(param_grid, objective, maximize=True)
assert result['best_params'] is not None
assert result['best_score'] is not None
assert result['best_score'] > 0
def test_optimize_minimize(self, optimizer):
"""Test optimization with minimize."""
param_grid = {
'param1': [1, 2, 3]
}
def objective(params):
return params['param1'] * 10
result = optimizer.optimize(param_grid, objective, maximize=False)
assert result['best_params'] is not None
assert result['best_score'] is not None

View File

@@ -0,0 +1,2 @@
"""Unit tests for portfolio management."""

View File

@@ -0,0 +1,35 @@
"""Tests for portfolio analytics."""
import pytest
import pandas as pd
import numpy as np
from src.portfolio.analytics import get_portfolio_analytics, PortfolioAnalytics
class TestPortfolioAnalytics:
"""Tests for PortfolioAnalytics."""
@pytest.fixture
def analytics(self):
"""Create portfolio analytics instance."""
return get_portfolio_analytics()
def test_calculate_sharpe_ratio(self, analytics):
"""Test Sharpe ratio calculation."""
returns = pd.Series([0.01, -0.005, 0.02, -0.01, 0.015])
sharpe = analytics.calculate_sharpe_ratio(returns, risk_free_rate=0.0)
assert isinstance(sharpe, float)
def test_calculate_sortino_ratio(self, analytics):
"""Test Sortino ratio calculation."""
returns = pd.Series([0.01, -0.005, 0.02, -0.01, 0.015])
sortino = analytics.calculate_sortino_ratio(returns, risk_free_rate=0.0)
assert isinstance(sortino, float)
def test_calculate_max_drawdown(self, analytics):
"""Test max drawdown calculation."""
equity_curve = pd.Series([10000, 10500, 10200, 11000, 10800])
drawdown = analytics.calculate_max_drawdown(equity_curve)
assert isinstance(drawdown, float)
assert 0 <= drawdown <= 1

View File

@@ -0,0 +1,29 @@
"""Tests for portfolio tracker."""
import pytest
from src.portfolio.tracker import get_portfolio_tracker, PortfolioTracker
class TestPortfolioTracker:
"""Tests for PortfolioTracker."""
@pytest.fixture
def tracker(self):
"""Create portfolio tracker instance."""
return get_portfolio_tracker()
@pytest.mark.asyncio
async def test_get_current_portfolio(self, tracker):
"""Test getting current portfolio."""
portfolio = await tracker.get_current_portfolio(paper_trading=True)
assert portfolio is not None
assert "positions" in portfolio
assert "performance" in portfolio
@pytest.mark.asyncio
async def test_update_positions_prices(self, tracker):
"""Test updating position prices."""
prices = {"BTC/USD": 50000.0}
await tracker.update_positions_prices(prices, paper_trading=True)
# Should not raise exception

View File

@@ -0,0 +1,2 @@
"""Unit tests for reporting."""

View File

@@ -0,0 +1,35 @@
"""Tests for CSV exporter."""
import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
from src.reporting.csv_exporter import get_csv_exporter, CSVExporter
class TestCSVExporter:
"""Tests for CSVExporter."""
@pytest.fixture
def exporter(self):
"""Create CSV exporter instance."""
return get_csv_exporter()
def test_export_trades(self, exporter, mock_database):
"""Test exporting trades."""
engine, Session = mock_database
with TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / "trades.csv"
# Export (may be empty if no trades)
result = exporter.export_trades(filepath, paper_trading=True)
assert isinstance(result, bool)
def test_export_portfolio(self, exporter):
"""Test exporting portfolio."""
with TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / "portfolio.csv"
result = exporter.export_portfolio(filepath)
assert isinstance(result, bool)

View File

@@ -0,0 +1,2 @@
"""Unit tests for resilience."""

View File

@@ -0,0 +1,31 @@
"""Tests for state manager."""
import pytest
from src.resilience.state_manager import get_state_manager, StateManager
class TestStateManager:
"""Tests for StateManager."""
@pytest.fixture
def state_manager(self):
"""Create state manager instance."""
return get_state_manager()
@pytest.mark.asyncio
async def test_save_state(self, state_manager):
"""Test saving state."""
result = await state_manager.save_state("test_key", {"data": "value"})
assert result is True
@pytest.mark.asyncio
async def test_load_state(self, state_manager):
"""Test loading state."""
# Save first
await state_manager.save_state("test_key", {"data": "value"})
# Load
state = await state_manager.load_state("test_key")
assert state is not None
assert state.get("data") == "value"

View File

@@ -0,0 +1,2 @@
"""Unit tests for risk management."""

View File

@@ -0,0 +1,55 @@
"""Tests for risk manager."""
import pytest
from src.risk.manager import get_risk_manager, RiskManager
class TestRiskManager:
"""Tests for RiskManager."""
@pytest.fixture
def risk_manager(self):
"""Create risk manager instance."""
return get_risk_manager()
@pytest.mark.asyncio
async def test_check_trade_risk(self, risk_manager):
"""Test trade risk checking."""
# Test with valid trade
result = await risk_manager.check_trade_risk(
exchange_id=1,
strategy_id=1,
symbol="BTC/USD",
side="buy",
amount=0.01,
price=50000.0,
current_portfolio_value=10000.0
)
assert isinstance(result, bool)
@pytest.mark.asyncio
async def test_check_max_drawdown(self, risk_manager):
"""Test max drawdown check."""
result = await risk_manager.check_max_drawdown(
current_portfolio_value=9000.0,
peak_portfolio_value=10000.0
)
assert isinstance(result, bool)
@pytest.mark.asyncio
async def test_add_risk_limit(self, risk_manager, mock_database):
"""Test adding risk limit."""
engine, Session = mock_database
await risk_manager.add_risk_limit(
limit_type="max_drawdown",
value=0.10, # 10%
is_active=True
)
# Verify limit was added
await risk_manager.load_risk_limits()
assert len(risk_manager.risk_limits) > 0

View File

@@ -0,0 +1,41 @@
"""Tests for position sizing."""
import pytest
from src.risk.position_sizing import get_position_sizer, PositionSizing
class TestPositionSizing:
"""Tests for PositionSizing."""
@pytest.fixture
def position_sizer(self):
"""Create position sizer instance."""
return get_position_sizer()
def test_fixed_percentage(self, position_sizer):
"""Test fixed percentage sizing."""
size = position_sizer.fixed_percentage(10000.0, 0.02) # 2%
assert size == 200.0
def test_fixed_amount(self, position_sizer):
"""Test fixed amount sizing."""
size = position_sizer.fixed_amount(500.0)
assert size == 500.0
def test_volatility_based(self, position_sizer):
"""Test volatility-based sizing."""
size = position_sizer.volatility_based(
capital=10000.0,
atr=100.0,
risk_per_trade_percentage=0.01 # 1%
)
assert size > 0
def test_kelly_criterion(self, position_sizer):
"""Test Kelly Criterion."""
fraction = position_sizer.kelly_criterion(
win_probability=0.6,
payout_ratio=1.5
)
assert 0 <= fraction <= 1

View File

@@ -0,0 +1,122 @@
"""Tests for ATR-based stop loss."""
import pytest
import pandas as pd
from decimal import Decimal
from src.risk.stop_loss import StopLossManager
class TestATRStopLoss:
"""Tests for ATR-based stop loss functionality."""
@pytest.fixture
def stop_loss_manager(self):
"""Create stop loss manager instance."""
return StopLossManager()
@pytest.fixture
def sample_ohlcv_data(self):
"""Create sample OHLCV data."""
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
base_price = 50000
return pd.DataFrame({
'high': [base_price + 100 + i * 10 for i in range(50)],
'low': [base_price - 100 + i * 10 for i in range(50)],
'close': [base_price + i * 10 for i in range(50)],
'open': [base_price - 50 + i * 10 for i in range(50)],
'volume': [1000.0] * 50
}, index=dates)
def test_set_atr_stop_loss(self, stop_loss_manager, sample_ohlcv_data):
"""Test setting ATR-based stop loss."""
position_id = 1
entry_price = Decimal("50000")
stop_loss_manager.set_stop_loss(
position_id=position_id,
stop_price=entry_price,
use_atr=True,
atr_multiplier=Decimal('2.0'),
atr_period=14,
ohlcv_data=sample_ohlcv_data
)
assert position_id in stop_loss_manager.stop_losses
config = stop_loss_manager.stop_losses[position_id]
assert config['use_atr'] is True
assert config['atr_multiplier'] == Decimal('2.0')
assert config['atr_period'] == 14
assert 'stop_price' in config
def test_calculate_atr_stop(self, stop_loss_manager, sample_ohlcv_data):
"""Test calculating ATR stop price."""
entry_price = Decimal("50000")
stop_price_long = stop_loss_manager.calculate_atr_stop(
entry_price=entry_price,
is_long=True,
ohlcv_data=sample_ohlcv_data,
atr_multiplier=Decimal('2.0'),
atr_period=14
)
stop_price_short = stop_loss_manager.calculate_atr_stop(
entry_price=entry_price,
is_long=False,
ohlcv_data=sample_ohlcv_data,
atr_multiplier=Decimal('2.0'),
atr_period=14
)
# Long position: stop should be below entry
assert stop_price_long < entry_price
# Short position: stop should be above entry
assert stop_price_short > entry_price
def test_atr_trailing_stop(self, stop_loss_manager, sample_ohlcv_data):
"""Test ATR-based trailing stop."""
position_id = 1
entry_price = Decimal("50000")
stop_loss_manager.set_stop_loss(
position_id=position_id,
stop_price=entry_price,
trailing=True,
use_atr=True,
atr_multiplier=Decimal('2.0'),
atr_period=14,
ohlcv_data=sample_ohlcv_data
)
# Check stop loss with higher price (should update trailing stop)
current_price = Decimal("51000")
triggered = stop_loss_manager.check_stop_loss(
position_id=position_id,
current_price=current_price,
is_long=True,
ohlcv_data=sample_ohlcv_data
)
# Should not trigger at higher price
assert triggered is False
def test_atr_stop_insufficient_data(self, stop_loss_manager):
"""Test ATR stop with insufficient data falls back to percentage."""
entry_price = Decimal("50000")
insufficient_data = pd.DataFrame({
'high': [51000],
'low': [49000],
'close': [50000]
})
stop_price = stop_loss_manager.calculate_atr_stop(
entry_price=entry_price,
is_long=True,
ohlcv_data=insufficient_data,
atr_period=14
)
# Should fall back to percentage-based stop (2%)
assert stop_price == entry_price * Decimal('0.98')

View File

@@ -0,0 +1,2 @@
"""Unit tests for security."""

View File

@@ -0,0 +1,35 @@
"""Tests for encryption."""
import pytest
from src.security.encryption import get_encryption_manager, EncryptionManager
class TestEncryption:
"""Tests for encryption system."""
@pytest.fixture
def encryptor(self):
"""Create encryption manager."""
return get_encryption_manager()
def test_encrypt_decrypt(self, encryptor):
"""Test encryption and decryption."""
plaintext = "test_api_key_12345"
encrypted = encryptor.encrypt(plaintext)
assert encrypted != plaintext
assert len(encrypted) > 0
decrypted = encryptor.decrypt(encrypted)
assert decrypted == plaintext
def test_encrypt_different_values(self, encryptor):
"""Test that different values encrypt differently."""
value1 = "key1"
value2 = "key2"
encrypted1 = encryptor.encrypt(value1)
encrypted2 = encryptor.encrypt(value2)
assert encrypted1 != encrypted2

View File

@@ -0,0 +1,2 @@
"""Unit tests for strategy framework."""

View File

@@ -0,0 +1,89 @@
"""Tests for base strategy class."""
import pytest
import pandas as pd
from src.strategies.base import BaseStrategy, StrategyRegistry
class ConcreteStrategy(BaseStrategy):
"""Concrete strategy for testing."""
async def on_data(self, new_data: pd.DataFrame):
"""Handle new data."""
self.current_data = pd.concat([self.current_data, new_data]).tail(100)
async def generate_signal(self):
"""Generate signal."""
if len(self.current_data) > 0:
return {"signal": "hold", "price": self.current_data['close'].iloc[-1]}
return {"signal": "hold", "price": None}
async def calculate_position_size(self, capital: float, risk_percentage: float) -> float:
"""Calculate position size."""
return capital * risk_percentage
class TestBaseStrategy:
"""Tests for BaseStrategy."""
@pytest.fixture
def strategy(self):
"""Create strategy instance."""
return ConcreteStrategy(
strategy_id=1,
name="test_strategy",
symbol="BTC/USD",
timeframe="1h",
parameters={}
)
def test_strategy_initialization(self, strategy):
"""Test strategy initialization."""
assert strategy.strategy_id == 1
assert strategy.name == "test_strategy"
assert strategy.symbol == "BTC/USD"
assert strategy.timeframe == "1h"
assert not strategy.is_active
@pytest.mark.asyncio
async def test_strategy_start_stop(self, strategy):
"""Test strategy start and stop."""
await strategy.start()
assert strategy.is_active
await strategy.stop()
assert not strategy.is_active
@pytest.mark.asyncio
async def test_generate_signal(self, strategy):
"""Test signal generation."""
signal = await strategy.generate_signal()
assert "signal" in signal
assert signal["signal"] in ["buy", "sell", "hold"]
@pytest.mark.asyncio
async def test_calculate_position_size(self, strategy):
"""Test position size calculation."""
size = await strategy.calculate_position_size(1000.0, 0.01)
assert size == 10.0
class TestStrategyRegistry:
"""Tests for StrategyRegistry."""
def test_register_strategy(self):
"""Test strategy registration."""
StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
assert "test_strategy" in StrategyRegistry.list_available()
def test_get_strategy_class(self):
"""Test getting strategy class."""
StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
strategy_class = StrategyRegistry.get_strategy_class("test_strategy")
assert strategy_class == ConcreteStrategy
def test_get_nonexistent_strategy(self):
"""Test getting non-existent strategy."""
with pytest.raises(ValueError, match="not registered"):
StrategyRegistry.get_strategy_class("nonexistent")

View File

@@ -0,0 +1,55 @@
"""Tests for Bollinger Bands mean reversion strategy."""
import pytest
from decimal import Decimal
from src.strategies.technical.bollinger_mean_reversion import BollingerMeanReversionStrategy
from src.strategies.base import SignalType
class TestBollingerMeanReversionStrategy:
"""Tests for BollingerMeanReversionStrategy."""
@pytest.fixture
def strategy(self):
"""Create Bollinger mean reversion strategy instance."""
return BollingerMeanReversionStrategy(
name="test_bollinger_mr",
parameters={
'period': 20,
'std_dev': 2.0,
'trend_filter': True,
'trend_ma_period': 50,
'entry_threshold': 0.95,
'exit_threshold': 0.5
}
)
def test_initialization(self, strategy):
"""Test strategy initialization."""
assert strategy.period == 20
assert strategy.std_dev == 2.0
assert strategy.trend_filter is True
assert strategy.trend_ma_period == 50
assert strategy.entry_threshold == 0.95
assert strategy.exit_threshold == 0.5
def test_on_tick_insufficient_data(self, strategy):
"""Test that strategy returns None with insufficient data."""
signal = strategy.on_tick(
symbol="BTC/USD",
price=Decimal("50000"),
timeframe="1h",
data={'volume': 1000}
)
assert signal is None
def test_position_tracking(self, strategy):
"""Test position tracking."""
assert strategy._in_position is False
assert strategy._entry_price is None
def test_strategy_metadata(self, strategy):
"""Test strategy metadata."""
assert strategy.name == "test_bollinger_mr"
assert strategy.enabled is False

View File

@@ -0,0 +1,61 @@
"""Tests for Confirmed strategy."""
import pytest
from decimal import Decimal
from src.strategies.technical.confirmed_strategy import ConfirmedStrategy
from src.strategies.base import SignalType
class TestConfirmedStrategy:
"""Tests for ConfirmedStrategy."""
@pytest.fixture
def strategy(self):
"""Create Confirmed strategy instance."""
return ConfirmedStrategy(
name="test_confirmed",
parameters={
'rsi_period': 14,
'macd_fast': 12,
'macd_slow': 26,
'macd_signal': 9,
'ma_fast': 10,
'ma_slow': 30,
'min_confirmations': 2,
'require_rsi': True,
'require_macd': True,
'require_ma': True
}
)
def test_initialization(self, strategy):
"""Test strategy initialization."""
assert strategy.rsi_period == 14
assert strategy.macd_fast == 12
assert strategy.ma_fast == 10
assert strategy.min_confirmations == 2
def test_on_tick_insufficient_data(self, strategy):
"""Test that strategy returns None with insufficient data."""
signal = strategy.on_tick(
symbol="BTC/USD",
price=Decimal("50000"),
timeframe="1h",
data={'volume': 1000}
)
assert signal is None
def test_min_confirmations_requirement(self, strategy):
"""Test that signal requires minimum confirmations."""
# This would require actual price history to generate real signals
# For now, we test the structure
assert strategy.min_confirmations == 2
assert strategy.require_rsi is True
assert strategy.require_macd is True
assert strategy.require_ma is True
def test_strategy_metadata(self, strategy):
"""Test strategy metadata."""
assert strategy.name == "test_confirmed"
assert strategy.enabled is False

View File

@@ -0,0 +1,53 @@
"""Tests for Consensus (ensemble) strategy."""
import pytest
from decimal import Decimal
from src.strategies.ensemble.consensus_strategy import ConsensusStrategy
from src.strategies.base import SignalType
class TestConsensusStrategy:
"""Tests for ConsensusStrategy."""
@pytest.fixture
def strategy(self):
"""Create Consensus strategy instance."""
return ConsensusStrategy(
name="test_consensus",
parameters={
'strategy_names': ['rsi', 'macd'],
'min_consensus': 2,
'use_weights': True,
'min_weight': 0.3,
'exclude_strategies': []
}
)
def test_initialization(self, strategy):
"""Test strategy initialization."""
assert strategy.min_consensus == 2
assert strategy.use_weights is True
assert strategy.min_weight == 0.3
def test_on_tick_no_strategies(self, strategy):
"""Test that strategy handles empty strategy list."""
# Strategy should handle cases where no strategies are available
signal = strategy.on_tick(
symbol="BTC/USD",
price=Decimal("50000"),
timeframe="1h",
data={'volume': 1000}
)
# May return None if no strategies available or no consensus
assert signal is None or isinstance(signal, (type(None), object))
def test_strategy_metadata(self, strategy):
"""Test strategy metadata."""
assert strategy.name == "test_consensus"
assert strategy.enabled is False
def test_consensus_calculation(self, strategy):
"""Test consensus calculation parameters."""
assert strategy.min_consensus == 2
assert strategy.use_weights is True

View File

@@ -0,0 +1,52 @@
"""Tests for DCA strategy."""
import pytest
from decimal import Decimal
from datetime import datetime, timedelta
from src.strategies.dca.dca_strategy import DCAStrategy
def test_dca_strategy_initialization():
"""Test DCA strategy initializes correctly."""
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
assert strategy.name == "Test DCA"
assert strategy.amount == Decimal("10")
assert strategy.interval == "daily"
def test_dca_daily_interval():
"""Test DCA with daily interval."""
strategy = DCAStrategy("Daily DCA", {"amount": 10, "interval": "daily"})
assert strategy.interval_delta == timedelta(days=1)
def test_dca_weekly_interval():
"""Test DCA with weekly interval."""
strategy = DCAStrategy("Weekly DCA", {"amount": 10, "interval": "weekly"})
assert strategy.interval_delta == timedelta(weeks=1)
def test_dca_monthly_interval():
"""Test DCA with monthly interval."""
strategy = DCAStrategy("Monthly DCA", {"amount": 10, "interval": "monthly"})
assert strategy.interval_delta == timedelta(days=30)
def test_dca_signal_generation():
"""Test DCA generates buy signals."""
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
strategy.last_purchase_time = None
signal = strategy.on_tick("BTC/USD", Decimal("100"), "1h", {})
assert signal is not None
assert signal.signal_type.value == "buy"
assert signal.quantity == Decimal("0.1") # 10 / 100
def test_dca_interval_respect():
"""Test DCA respects interval timing."""
strategy = DCAStrategy("Test DCA", {"amount": 10, "interval": "daily"})
strategy.last_purchase_time = datetime.utcnow() - timedelta(hours=12)
signal = strategy.on_tick("BTC/USD", Decimal("100"), "1h", {})
assert signal is None # Should not generate signal yet

View File

@@ -0,0 +1,59 @@
"""Tests for Divergence strategy."""
import pytest
from decimal import Decimal
from src.strategies.technical.divergence_strategy import DivergenceStrategy
from src.strategies.base import SignalType
class TestDivergenceStrategy:
"""Tests for DivergenceStrategy."""
@pytest.fixture
def strategy(self):
"""Create Divergence strategy instance."""
return DivergenceStrategy(
name="test_divergence",
parameters={
'indicator_type': 'rsi',
'rsi_period': 14,
'lookback': 20,
'min_swings': 2,
'min_confidence': 0.5
}
)
def test_initialization(self, strategy):
"""Test strategy initialization."""
assert strategy.indicator_type == 'rsi'
assert strategy.rsi_period == 14
assert strategy.lookback == 20
assert strategy.min_swings == 2
assert strategy.min_confidence == 0.5
def test_on_tick_insufficient_data(self, strategy):
"""Test that strategy returns None with insufficient data."""
signal = strategy.on_tick(
symbol="BTC/USD",
price=Decimal("50000"),
timeframe="1h",
data={'volume': 1000}
)
assert signal is None
def test_indicator_type_selection(self, strategy):
"""Test indicator type selection."""
assert strategy.indicator_type == 'rsi'
# Test MACD indicator type
macd_strategy = DivergenceStrategy(
name="test_divergence_macd",
parameters={'indicator_type': 'macd'}
)
assert macd_strategy.indicator_type == 'macd'
def test_strategy_metadata(self, strategy):
"""Test strategy metadata."""
assert strategy.name == "test_divergence"
assert strategy.enabled is False

View File

@@ -0,0 +1,69 @@
"""Tests for Grid strategy."""
import pytest
from decimal import Decimal
from src.strategies.grid.grid_strategy import GridStrategy
def test_grid_strategy_initialization():
"""Test Grid strategy initializes correctly."""
strategy = GridStrategy("Test Grid", {
"grid_spacing": 1,
"num_levels": 10,
"profit_target": 2
})
assert strategy.name == "Test Grid"
assert strategy.grid_spacing == Decimal("0.01")
assert strategy.num_levels == 10
def test_grid_levels_calculation():
"""Test grid levels are calculated correctly."""
strategy = GridStrategy("Test Grid", {
"grid_spacing": 1,
"num_levels": 5,
"center_price": 100
})
strategy._update_grid_levels(Decimal("100"))
assert len(strategy.buy_levels) == 5
assert len(strategy.sell_levels) == 5
# Buy levels should be below center
assert all(level < Decimal("100") for level in strategy.buy_levels)
# Sell levels should be above center
assert all(level > Decimal("100") for level in strategy.sell_levels)
def test_grid_buy_signal():
"""Test grid generates buy signal at lower level."""
strategy = GridStrategy("Test Grid", {
"grid_spacing": 1,
"num_levels": 5,
"center_price": 100,
"position_size": Decimal("0.1")
})
# Price at buy level
signal = strategy.on_tick("BTC/USD", Decimal("99"), "1h", {})
assert signal is not None
assert signal.signal_type.value == "buy"
def test_grid_profit_taking():
"""Test grid takes profit at target."""
strategy = GridStrategy("Test Grid", {
"grid_spacing": 1,
"num_levels": 5,
"profit_target": 2
})
# Simulate position
entry_price = Decimal("100")
strategy.positions[entry_price] = Decimal("0.1")
# Price with profit
signal = strategy.on_tick("BTC/USD", Decimal("102"), "1h", {})
assert signal is not None
assert signal.signal_type.value == "sell"
assert entry_price not in strategy.positions # Position removed

View File

@@ -0,0 +1,45 @@
"""Tests for MACD strategy."""
import pytest
import pandas as pd
from src.strategies.technical.macd_strategy import MACDStrategy
class TestMACDStrategy:
"""Tests for MACDStrategy."""
@pytest.fixture
def strategy(self):
"""Create MACD strategy instance."""
return MACDStrategy(
strategy_id=1,
name="test_macd",
symbol="BTC/USD",
timeframe="1h",
parameters={
"fast_period": 12,
"slow_period": 26,
"signal_period": 9
}
)
@pytest.mark.asyncio
async def test_macd_strategy_initialization(self, strategy):
"""Test MACD strategy initialization."""
assert strategy.fast_period == 12
assert strategy.slow_period == 26
assert strategy.signal_period == 9
@pytest.mark.asyncio
async def test_generate_signal(self, strategy):
"""Test signal generation."""
# Create minimal data
data = pd.DataFrame({
'close': [100 + i * 0.1 for i in range(50)]
})
strategy.current_data = data
signal = await strategy.generate_signal()
assert "signal" in signal
assert signal["signal"] in ["buy", "sell", "hold"]

View File

@@ -0,0 +1,72 @@
"""Tests for Momentum strategy."""
import pytest
from decimal import Decimal
import pandas as pd
from src.strategies.momentum.momentum_strategy import MomentumStrategy
def test_momentum_strategy_initialization():
"""Test Momentum strategy initializes correctly."""
strategy = MomentumStrategy("Test Momentum", {
"lookback_period": 20,
"momentum_threshold": 0.05
})
assert strategy.name == "Test Momentum"
assert strategy.lookback_period == 20
assert strategy.momentum_threshold == Decimal("0.05")
def test_momentum_calculation():
"""Test momentum calculation."""
strategy = MomentumStrategy("Test", {"lookback_period": 5})
# Create price history
prices = pd.Series([100, 101, 102, 103, 104, 105])
momentum = strategy._calculate_momentum(prices)
# Should be positive (price increased)
assert momentum > 0
assert momentum == 0.05 # (105 - 100) / 100
def test_momentum_entry_signal():
"""Test momentum generates entry signal."""
strategy = MomentumStrategy("Test", {
"lookback_period": 5,
"momentum_threshold": 0.05,
"volume_threshold": 1.0
})
# Build price history with momentum
for i in range(10):
price = 100 + i * 2 # Strong upward momentum
volume = 1000 * (1.5 if i >= 5 else 1.0) # Volume increase
strategy.on_tick("BTC/USD", Decimal(str(price)), "1h", {"volume": volume})
# Should generate buy signal
signal = strategy.on_tick("BTC/USD", Decimal("120"), "1h", {"volume": 2000})
assert signal is not None
assert signal.signal_type.value == "buy"
assert strategy._in_position == True
def test_momentum_exit_signal():
"""Test momentum generates exit signal on reversal."""
strategy = MomentumStrategy("Test", {
"lookback_period": 5,
"exit_threshold": -0.02
})
strategy._in_position = True
strategy._entry_price = Decimal("100")
# Build history with reversal
for i in range(10):
price = 100 - i # Downward momentum
strategy.on_tick("BTC/USD", Decimal(str(price)), "1h", {"volume": 1000})
signal = strategy.on_tick("BTC/USD", Decimal("90"), "1h", {"volume": 1000})
assert signal is not None
assert signal.signal_type.value == "sell"
assert strategy._in_position == False

View File

@@ -0,0 +1,45 @@
"""Tests for Moving Average strategy."""
import pytest
import pandas as pd
from src.strategies.technical.moving_avg_strategy import MovingAverageStrategy
class TestMovingAverageStrategy:
"""Tests for MovingAverageStrategy."""
@pytest.fixture
def strategy(self):
"""Create Moving Average strategy instance."""
return MovingAverageStrategy(
strategy_id=1,
name="test_ma",
symbol="BTC/USD",
timeframe="1h",
parameters={
"short_period": 10,
"long_period": 30,
"ma_type": "SMA"
}
)
@pytest.mark.asyncio
async def test_ma_strategy_initialization(self, strategy):
"""Test Moving Average strategy initialization."""
assert strategy.short_period == 10
assert strategy.long_period == 30
assert strategy.ma_type == "SMA"
@pytest.mark.asyncio
async def test_generate_signal(self, strategy):
"""Test signal generation."""
# Create data with trend
data = pd.DataFrame({
'close': [100 + i * 0.5 for i in range(50)]
})
strategy.current_data = data
signal = await strategy.generate_signal()
assert "signal" in signal
assert signal["signal"] in ["buy", "sell", "hold"]

View File

@@ -0,0 +1,89 @@
import pytest
import pandas as pd
import numpy as np
from unittest.mock import MagicMock, AsyncMock, patch
from decimal import Decimal
from src.strategies.technical.pairs_trading import PairsTradingStrategy
from src.strategies.base import SignalType
@pytest.fixture
def mock_pricing_service():
service = MagicMock()
service.get_ohlcv = MagicMock()
return service
@pytest.fixture
def strategy(mock_pricing_service):
with patch('src.strategies.technical.pairs_trading.get_pricing_service', return_value=mock_pricing_service):
params = {
'second_symbol': 'AVAX/USD',
'lookback_period': 5,
'z_score_threshold': 1.5,
'symbol': 'SOL/USD'
}
strat = PairsTradingStrategy("test_pairs", params)
strat.enabled = True
return strat
@pytest.mark.asyncio
async def test_pairs_trading_short_spread_signal(strategy, mock_pricing_service):
# Setup Data
# Scenario: SOL (A) pumps relative to AVAX (B) -> Spread widens -> Z-Score High -> Sell A / Buy B
# Prices for A (SOL): 100, 100, 100, 100, 120 (Pump)
ohlcv_a = [
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 120, 120, 120, 120, 1000],
]
# Prices for B (AVAX): 25, 25, 25, 25, 25 (Flat)
ohlcv_b = [
[0, 25, 25, 25, 25, 1000],
[0, 25, 25, 25, 25, 1000],
[0, 25, 25, 25, 25, 1000],
[0, 25, 25, 25, 25, 1000],
[0, 25, 25, 25, 25, 1000],
]
# Spread: 4, 4, 4, 4, 4.8
# Mean: 4.16, StdDev: approx small but let's see.
# Actually StdDev will be non-zero because of the last value.
mock_pricing_service.get_ohlcv.side_effect = [ohlcv_a, ohlcv_b]
# Execute
signal = await strategy.on_tick("SOL/USD", Decimal(120), "1h", {})
# Verify
assert signal is not None
assert signal.signal_type == SignalType.SELL # Sell Primary (SOL)
assert signal.metadata['secondary_action'] == 'buy' # Buy Secondary (AVAX)
assert signal.metadata['z_score'] > 1.5
@pytest.mark.asyncio
async def test_pairs_trading_long_spread_signal(strategy, mock_pricing_service):
# Scenario: SOL (A) dumps -> Spread drops -> Z-Score Low -> Buy A / Sell B
ohlcv_a = [
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 100, 100, 100, 100, 1000],
[0, 80, 80, 80, 80, 1000], # Dump
]
ohlcv_b = [
[0, 25, 25, 25, 25, 1000] for _ in range(5)
]
mock_pricing_service.get_ohlcv.side_effect = [ohlcv_a, ohlcv_b]
signal = await strategy.on_tick("SOL/USD", Decimal(80), "1h", {})
assert signal is not None
assert signal.signal_type == SignalType.BUY # Buy Primary
assert signal.metadata['secondary_action'] == 'sell' # Sell Secondary
assert signal.metadata['z_score'] < -1.5

View File

@@ -0,0 +1,67 @@
"""Tests for RSI strategy."""
import pytest
import pandas as pd
from src.strategies.technical.rsi_strategy import RSIStrategy
class TestRSIStrategy:
"""Tests for RSIStrategy."""
@pytest.fixture
def strategy(self):
"""Create RSI strategy instance."""
return RSIStrategy(
strategy_id=1,
name="test_rsi",
symbol="BTC/USD",
timeframe="1h",
parameters={
"rsi_period": 14,
"overbought": 70,
"oversold": 30
}
)
@pytest.fixture
def sample_data(self):
"""Create sample price data."""
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
# Create data with clear trend for RSI calculation
prices = [100 - i * 0.5 for i in range(50)] # Downward trend
return pd.DataFrame({
'timestamp': dates,
'open': prices,
'high': [p + 1 for p in prices],
'low': [p - 1 for p in prices],
'close': prices,
'volume': [1000.0] * 50
})
@pytest.mark.asyncio
async def test_rsi_strategy_initialization(self, strategy):
"""Test RSI strategy initialization."""
assert strategy.rsi_period == 14
assert strategy.overbought == 70
assert strategy.oversold == 30
@pytest.mark.asyncio
async def test_on_data(self, strategy, sample_data):
"""Test on_data method."""
await strategy.on_data(sample_data)
assert len(strategy.current_data) > 0
@pytest.mark.asyncio
async def test_generate_signal_oversold(self, strategy, sample_data):
"""Test signal generation for oversold condition."""
await strategy.on_data(sample_data)
# Calculate RSI - should be low for downward trend
from src.data.indicators import get_indicators
indicators = get_indicators()
rsi = indicators.rsi(strategy.current_data['close'], period=14)
# If RSI is low, should generate buy signal
signal = await strategy.generate_signal()
assert "signal" in signal
assert signal["signal"] in ["buy", "sell", "hold"]

View File

@@ -0,0 +1,88 @@
"""Tests for trend filter functionality."""
import pytest
import pandas as pd
from decimal import Decimal
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
from src.strategies.technical.rsi_strategy import RSIStrategy
class TestTrendFilter:
"""Tests for trend filter in BaseStrategy."""
@pytest.fixture
def strategy(self):
"""Create strategy instance with trend filter enabled."""
strategy = RSIStrategy(
name="test_rsi_with_filter",
parameters={'use_trend_filter': True}
)
return strategy
@pytest.fixture
def ohlcv_data(self):
"""Create sample OHLCV data."""
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
base_price = 50000
return pd.DataFrame({
'high': [base_price + 100 + i * 10 for i in range(50)],
'low': [base_price - 100 + i * 10 for i in range(50)],
'close': [base_price + i * 10 for i in range(50)],
'open': [base_price - 50 + i * 10 for i in range(50)],
'volume': [1000.0] * 50
}, index=dates)
def test_trend_filter_method_exists(self, strategy):
"""Test that apply_trend_filter method exists."""
assert hasattr(strategy, 'apply_trend_filter')
assert callable(getattr(strategy, 'apply_trend_filter'))
def test_trend_filter_insufficient_data(self, strategy):
"""Test trend filter with insufficient data."""
signal = StrategySignal(
signal_type=SignalType.BUY,
symbol="BTC/USD",
strength=0.8,
price=Decimal("50000")
)
insufficient_data = pd.DataFrame({
'high': [51000],
'low': [49000],
'close': [50000]
})
# Should allow signal when insufficient data
result = strategy.apply_trend_filter(signal, insufficient_data)
assert result is not None
def test_trend_filter_none_data(self, strategy):
"""Test trend filter with None data."""
signal = StrategySignal(
signal_type=SignalType.BUY,
symbol="BTC/USD",
strength=0.8,
price=Decimal("50000")
)
# Should allow signal when no data provided
result = strategy.apply_trend_filter(signal, None)
assert result is not None
def test_trend_filter_when_disabled(self, strategy):
"""Test that trend filter doesn't filter when disabled."""
strategy_no_filter = RSIStrategy(
name="test_rsi_no_filter",
parameters={'use_trend_filter': False}
)
signal = StrategySignal(
signal_type=SignalType.BUY,
symbol="BTC/USD",
strength=0.8,
price=Decimal("50000")
)
result = strategy_no_filter.apply_trend_filter(signal, None)
assert result == signal

View File

@@ -0,0 +1,284 @@
"""Tests for autopilot model training functionality."""
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
"""Test client fixture."""
return TestClient(app)
@pytest.fixture
def mock_train_task():
"""Mock Celery train_model_task."""
with patch('backend.api.autopilot.train_model_task') as mock:
mock_result = Mock()
mock_result.id = "test-task-id-12345"
mock.delay.return_value = mock_result
yield mock
@pytest.fixture
def mock_async_result():
"""Mock Celery AsyncResult."""
with patch('backend.api.autopilot.AsyncResult') as mock:
yield mock
@pytest.fixture
def mock_strategy_selector():
"""Mock StrategySelector."""
selector = Mock()
selector.model = Mock()
selector.model.is_trained = True
selector.model.model_type = "classifier"
selector.model.feature_names = ["rsi", "macd", "sma_20"]
selector.model.training_metadata = {
"trained_at": "2024-01-01T00:00:00",
"metrics": {"test_accuracy": 0.85}
}
selector.get_model_info.return_value = {
"is_trained": True,
"model_type": "classifier",
"available_strategies": ["rsi", "macd", "momentum"],
"feature_count": 54,
"training_metadata": {
"trained_at": "2024-01-01T00:00:00",
"metrics": {"test_accuracy": 0.85},
"training_symbols": ["BTC/USD", "ETH/USD"]
}
}
return selector
class TestBootstrapConfig:
"""Tests for bootstrap configuration endpoints."""
def test_get_bootstrap_config(self, client):
"""Test getting bootstrap configuration."""
response = client.get("/api/autopilot/bootstrap-config")
assert response.status_code == 200
data = response.json()
# Verify required fields exist
assert "days" in data
assert "timeframe" in data
assert "min_samples_per_strategy" in data
assert "symbols" in data
# Verify types
assert isinstance(data["days"], int)
assert isinstance(data["timeframe"], str)
assert isinstance(data["min_samples_per_strategy"], int)
assert isinstance(data["symbols"], list)
def test_update_bootstrap_config(self, client):
"""Test updating bootstrap configuration."""
new_config = {
"days": 365,
"timeframe": "4h",
"min_samples_per_strategy": 50,
"symbols": ["BTC/USD", "ETH/USD", "SOL/USD"]
}
response = client.put(
"/api/autopilot/bootstrap-config",
json=new_config
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
# Verify the config was updated
response = client.get("/api/autopilot/bootstrap-config")
data = response.json()
assert data["days"] == 365
assert data["timeframe"] == "4h"
assert data["min_samples_per_strategy"] == 50
assert "SOL/USD" in data["symbols"]
class TestModelTraining:
"""Tests for model training endpoints."""
def test_trigger_retrain(self, client, mock_train_task):
"""Test triggering model retraining."""
response = client.post("/api/autopilot/intelligent/retrain?force=true")
assert response.status_code == 200
data = response.json()
assert data["status"] == "queued"
assert "task_id" in data
assert data["task_id"] == "test-task-id-12345"
# Verify task was called with correct parameters
mock_train_task.delay.assert_called_once()
call_kwargs = mock_train_task.delay.call_args.kwargs
assert call_kwargs["force_retrain"] is True
assert call_kwargs["bootstrap"] is True
assert "symbols" in call_kwargs
assert "days" in call_kwargs
assert "timeframe" in call_kwargs
assert "min_samples_per_strategy" in call_kwargs
def test_get_task_status_pending(self, client, mock_async_result):
"""Test getting status of a pending task."""
mock_result = Mock()
mock_result.status = "PENDING"
mock_result.result = None
mock_async_result.return_value = mock_result
response = client.get("/api/autopilot/tasks/test-task-id")
assert response.status_code == 200
data = response.json()
assert data["status"] == "PENDING"
def test_get_task_status_progress(self, client, mock_async_result):
"""Test getting status of a task in progress."""
mock_result = Mock()
mock_result.status = "PROGRESS"
mock_result.result = None
mock_result.info = {
"step": "fetching",
"progress": 50,
"message": "Fetching BTC/USD data..."
}
mock_async_result.return_value = mock_result
response = client.get("/api/autopilot/tasks/test-task-id")
assert response.status_code == 200
data = response.json()
assert data["status"] == "PROGRESS"
assert data["meta"]["progress"] == 50
def test_get_task_status_success(self, client, mock_async_result):
"""Test getting status of a successful task."""
mock_result = Mock()
mock_result.status = "SUCCESS"
mock_result.result = {
"train_accuracy": 0.85,
"test_accuracy": 0.78,
"n_samples": 1000,
"best_model": "xgboost"
}
mock_async_result.return_value = mock_result
response = client.get("/api/autopilot/tasks/test-task-id")
assert response.status_code == 200
data = response.json()
assert data["status"] == "SUCCESS"
assert data["result"]["best_model"] == "xgboost"
def test_get_task_status_failure(self, client, mock_async_result):
"""Test getting status of a failed task."""
mock_result = Mock()
mock_result.status = "FAILURE"
mock_result.result = Exception("Training failed: insufficient data")
mock_result.traceback = "Traceback (most recent call last)..."
mock_async_result.return_value = mock_result
response = client.get("/api/autopilot/tasks/test-task-id")
assert response.status_code == 200
data = response.json()
assert data["status"] == "FAILURE"
assert "error" in data["result"]
class TestModelInfo:
"""Tests for model info endpoint."""
@patch('backend.api.autopilot.get_strategy_selector')
def test_get_model_info_trained(self, mock_get_selector, client, mock_strategy_selector):
"""Test getting info for a trained model."""
mock_get_selector.return_value = mock_strategy_selector
response = client.get("/api/autopilot/intelligent/model-info")
assert response.status_code == 200
data = response.json()
assert data["is_trained"] is True
assert "available_strategies" in data
assert "feature_count" in data
@patch('backend.api.autopilot.get_strategy_selector')
def test_get_model_info_untrained(self, mock_get_selector, client):
"""Test getting info for an untrained model."""
mock_selector = Mock()
mock_selector.get_model_info.return_value = {
"is_trained": False,
"model_type": "classifier",
"available_strategies": ["rsi", "macd"],
"feature_count": 0
}
mock_get_selector.return_value = mock_selector
response = client.get("/api/autopilot/intelligent/model-info")
assert response.status_code == 200
data = response.json()
assert data["is_trained"] is False
assert data["feature_count"] == 0
class TestModelReset:
"""Tests for model reset endpoint."""
@patch('backend.api.autopilot.get_strategy_selector')
def test_reset_model(self, mock_get_selector, client):
"""Test resetting the model."""
mock_selector = Mock()
mock_selector.reset_model = AsyncMock(return_value={"status": "success"})
mock_get_selector.return_value = mock_selector
response = client.post("/api/autopilot/intelligent/reset")
assert response.status_code == 200
@pytest.mark.integration
class TestTrainingWorkflow:
"""Integration tests for the complete training workflow."""
@patch('backend.api.autopilot.train_model_task')
def test_config_and_retrain_workflow(self, mock_train_task, client):
"""Test configure -> train workflow passes config correctly."""
# Setup mock
mock_task_result = Mock()
mock_task_result.id = "test-task-123"
mock_train_task.delay.return_value = mock_task_result
# 1. Configure bootstrap settings with specific values
config = {
"days": 180,
"timeframe": "4h",
"min_samples_per_strategy": 25,
"symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
}
response = client.put("/api/autopilot/bootstrap-config", json=config)
assert response.status_code == 200
# 2. Trigger retraining
response = client.post("/api/autopilot/intelligent/retrain?force=true")
assert response.status_code == 200
# 3. Verify the task was called with the correct config
mock_train_task.delay.assert_called_once()
call_kwargs = mock_train_task.delay.call_args.kwargs
# All config should be passed to the task
assert call_kwargs["days"] == 180
assert call_kwargs["timeframe"] == "4h"
assert call_kwargs["min_samples_per_strategy"] == 25
assert call_kwargs["symbols"] == ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
assert call_kwargs["force_retrain"] is True
assert call_kwargs["bootstrap"] is True

View File

@@ -0,0 +1,2 @@
"""Unit tests for trading engine."""

View File

@@ -0,0 +1,88 @@
"""Tests for trading engine."""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from decimal import Decimal
from src.trading.engine import get_trading_engine, TradingEngine
from src.core.database import OrderSide, OrderType, OrderStatus
@pytest.mark.asyncio
class TestTradingEngine:
@pytest.fixture(autouse=True)
async def setup_data(self, db_session):
"""Setup test data."""
from src.core.database import Exchange
from sqlalchemy import select
# Check if exchange exists
result = await db_session.execute(select(Exchange).where(Exchange.id == 1))
if not result.scalar_one_or_none():
try:
# Try enabled (new schema)
exchange = Exchange(id=1, name="coinbase", enabled=True)
db_session.add(exchange)
await db_session.commit()
except Exception:
# Fallback if I was wrong about schema, but I checked it.
await db_session.rollback()
exchange = Exchange(id=1, name="coinbase")
db_session.add(exchange)
print("Exchange created in setup_data")
async def test_execute_order_paper_trading(self, mock_exchange_adapter):
"""Test executing paper trading order."""
engine = TradingEngine()
# Mock dependencies
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
engine.paper_trading.get_balance = Mock(return_value=Decimal("10000"))
engine.paper_trading.execute_order = Mock(return_value=True)
# Mock logger
engine.logger = Mock()
order = await engine.execute_order(
exchange_id=1,
strategy_id=None,
symbol="BTC/USD",
side=OrderSide.BUY,
order_type=OrderType.LIMIT,
quantity=Decimal("0.1"),
price=Decimal("50000"),
paper_trading=True
)
# Check if error occurred
if order is None:
engine.logger.error.assert_called()
call_args = engine.logger.error.call_args
print(f"\nCaught exception in engine: {call_args}")
assert order is not None
assert order.symbol == "BTC/USD"
assert order.quantity == Decimal("0.1")
# Status might be PENDING/OPEN depending on implementation
assert order.status in [OrderStatus.PENDING, OrderStatus.OPEN]
async def test_execute_order_live_trading(self, mock_exchange_adapter):
"""Test executing live trading order."""
engine = TradingEngine()
# Mock dependencies
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
order = await engine.execute_order(
exchange_id=1,
strategy_id=None,
symbol="BTC/USD",
side=OrderSide.BUY,
order_type=OrderType.MARKET,
quantity=Decimal("0.1"),
paper_trading=False
)
assert order is not None
assert order.paper_trading is False
mock_exchange_adapter.place_order.assert_called_once()

View File

@@ -0,0 +1,161 @@
"""Tests for fee calculator functionality."""
import pytest
from decimal import Decimal
from unittest.mock import Mock, patch
from src.trading.fee_calculator import FeeCalculator, get_fee_calculator
from src.core.database import OrderType
class TestFeeCalculator:
"""Tests for FeeCalculator class."""
@pytest.fixture
def calculator(self):
"""Create fee calculator instance."""
return FeeCalculator()
def test_get_fee_calculator_singleton(self):
"""Test that get_fee_calculator returns singleton."""
calc1 = get_fee_calculator()
calc2 = get_fee_calculator()
assert calc1 is calc2
def test_calculate_fee_basic(self, calculator):
"""Test basic fee calculation."""
fee = calculator.calculate_fee(
quantity=Decimal('1.0'),
price=Decimal('100.0'),
order_type=OrderType.MARKET
)
assert fee > 0
assert isinstance(fee, Decimal)
def test_calculate_fee_zero_quantity(self, calculator):
"""Test fee calculation with zero quantity."""
fee = calculator.calculate_fee(
quantity=Decimal('0'),
price=Decimal('100.0'),
order_type=OrderType.MARKET
)
assert fee == Decimal('0')
def test_calculate_fee_maker_vs_taker(self, calculator):
"""Test maker fees are typically lower than taker fees."""
maker_fee = calculator.calculate_fee(
quantity=Decimal('1.0'),
price=Decimal('1000.0'),
order_type=OrderType.LIMIT,
is_maker=True
)
taker_fee = calculator.calculate_fee(
quantity=Decimal('1.0'),
price=Decimal('1000.0'),
order_type=OrderType.MARKET,
is_maker=False
)
# Maker fees should be <= taker fees
assert maker_fee <= taker_fee
class TestFeeCalculatorPaperTrading:
"""Tests for paper trading fee calculation."""
@pytest.fixture
def calculator(self):
"""Create fee calculator instance."""
return FeeCalculator()
def test_get_fee_structure_by_exchange_name_coinbase(self, calculator):
"""Test getting Coinbase fee structure."""
with patch.object(calculator, 'config') as mock_config:
mock_config.get.side_effect = lambda key, default=None: {
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006}
}.get(key, default)
fees = calculator.get_fee_structure_by_exchange_name('coinbase')
assert fees['maker'] == 0.004
assert fees['taker'] == 0.006
def test_get_fee_structure_by_exchange_name_unknown(self, calculator):
"""Test getting fee structure for unknown exchange returns defaults."""
with patch.object(calculator, 'config') as mock_config:
mock_config.get.side_effect = lambda key, default=None: {
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
'trading.exchanges.unknown.fees': None
}.get(key, default)
fees = calculator.get_fee_structure_by_exchange_name('unknown')
assert fees['maker'] == 0.001
assert fees['taker'] == 0.001
def test_calculate_fee_for_paper_trading(self, calculator):
"""Test paper trading fee calculation."""
with patch.object(calculator, 'config') as mock_config:
mock_config.get.side_effect = lambda key, default=None: {
'paper_trading.fee_exchange': 'coinbase',
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006},
'trading.default_fees': {'maker': 0.001, 'taker': 0.001}
}.get(key, default)
fee = calculator.calculate_fee_for_paper_trading(
quantity=Decimal('1.0'),
price=Decimal('1000.0'),
order_type=OrderType.MARKET,
is_maker=False
)
# Taker fee at 0.6% of $1000 = $6
expected = Decimal('1000.0') * Decimal('0.006')
assert fee == expected
def test_calculate_fee_for_paper_trading_zero(self, calculator):
"""Test paper trading fee with zero values."""
fee = calculator.calculate_fee_for_paper_trading(
quantity=Decimal('0'),
price=Decimal('100.0'),
order_type=OrderType.MARKET
)
assert fee == Decimal('0')
class TestRoundTripFees:
"""Tests for round-trip fee calculations."""
@pytest.fixture
def calculator(self):
"""Create fee calculator instance."""
return FeeCalculator()
def test_estimate_round_trip_fee(self, calculator):
"""Test round-trip fee estimation."""
fee = calculator.estimate_round_trip_fee(
quantity=Decimal('1.0'),
price=Decimal('1000.0')
)
# Round-trip should include buy and sell fees
assert fee > 0
single_fee = calculator.calculate_fee(
quantity=Decimal('1.0'),
price=Decimal('1000.0'),
order_type=OrderType.MARKET
)
# Round-trip should be approximately 2x single fee
assert fee >= single_fee
def test_get_minimum_profit_threshold(self, calculator):
"""Test minimum profit threshold calculation."""
threshold = calculator.get_minimum_profit_threshold(
quantity=Decimal('1.0'),
price=Decimal('1000.0'),
multiplier=2.0
)
assert threshold > 0
round_trip_fee = calculator.estimate_round_trip_fee(
quantity=Decimal('1.0'),
price=Decimal('1000.0')
)
# Threshold should be 2x round-trip fee
assert threshold == round_trip_fee * Decimal('2.0')

View File

@@ -0,0 +1,70 @@
"""Tests for order manager."""
import pytest
from unittest.mock import Mock, patch
from src.trading.order_manager import get_order_manager, OrderManager
class TestOrderManager:
"""Tests for OrderManager."""
@pytest.fixture
def order_manager(self, mock_database):
"""Create order manager instance."""
return get_order_manager()
@pytest.mark.asyncio
async def test_create_order(self, order_manager, mock_database):
"""Test order creation."""
engine, Session = mock_database
order = await order_manager.create_order(
exchange_id=1,
strategy_id=1,
symbol="BTC/USD",
side="buy",
order_type="market",
amount=0.01,
price=50000.0,
is_paper_trade=True
)
assert order is not None
assert order.symbol == "BTC/USD"
assert order.side == "buy"
assert order.status == "pending"
@pytest.mark.asyncio
async def test_update_order_status(self, order_manager, mock_database):
"""Test order status update."""
engine, Session = mock_database
# Create order first
order = await order_manager.create_order(
exchange_id=1,
strategy_id=1,
symbol="BTC/USD",
side="buy",
order_type="market",
amount=0.01,
is_paper_trade=True
)
# Update status
updated = await order_manager.update_order_status(
client_order_id=order.client_order_id,
new_status="filled",
filled_amount=0.01,
cost=500.0
)
assert updated is not None
assert updated.status == "filled"
def test_get_order(self, order_manager, mock_database):
"""Test getting order."""
# This would require creating an order first
# Simplified test
order = order_manager.get_order(client_order_id="nonexistent")
assert order is None

View File

@@ -0,0 +1,32 @@
"""Tests for paper trading simulator."""
import pytest
from decimal import Decimal
from unittest.mock import Mock
from src.trading.paper_trading import get_paper_trading, PaperTradingSimulator
from src.core.database import Order, OrderSide, OrderType
class TestPaperTradingSimulator:
"""Tests for PaperTradingSimulator."""
@pytest.fixture
def simulator(self):
"""Create paper trading simulator."""
return PaperTradingSimulator(initial_capital=Decimal('1000.0'))
def test_initialization(self, simulator):
"""Test simulator initialization."""
assert simulator.initial_capital == Decimal('1000.0')
assert simulator.cash == Decimal('1000.0')
def test_get_balance(self, simulator):
"""Test getting balance."""
balance = simulator.get_balance()
assert balance == Decimal('1000.0')
def test_get_positions(self, simulator):
"""Test getting positions."""
positions = simulator.get_positions()
assert isinstance(positions, list)

Some files were not shown because too many files have changed in this diff Show More