Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
2
tests/unit/backend/__init__.py
Normal file
2
tests/unit/backend/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Backend API tests."""
|
||||
|
||||
2
tests/unit/backend/api/__init__.py
Normal file
2
tests/unit/backend/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Backend API endpoint tests."""
|
||||
|
||||
379
tests/unit/backend/api/test_autopilot.py
Normal file
379
tests/unit/backend/api/test_autopilot.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Tests for autopilot API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_autopilot():
|
||||
"""Mock autopilot instance."""
|
||||
autopilot = Mock()
|
||||
autopilot.symbol = "BTC/USD"
|
||||
autopilot.is_running = False
|
||||
autopilot.get_status.return_value = {
|
||||
"symbol": "BTC/USD",
|
||||
"running": False,
|
||||
"interval": 60.0,
|
||||
"has_market_data": True,
|
||||
"market_data_length": 100,
|
||||
"headlines_count": 5,
|
||||
"last_sentiment_score": 0.5,
|
||||
"last_pattern": "head_and_shoulders",
|
||||
"last_signal": None,
|
||||
}
|
||||
autopilot.last_signal = None
|
||||
autopilot.analyze_once.return_value = None
|
||||
return autopilot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_intelligent_autopilot():
|
||||
"""Mock intelligent autopilot instance."""
|
||||
autopilot = Mock()
|
||||
autopilot.symbol = "BTC/USD"
|
||||
autopilot.is_running = False
|
||||
autopilot.enable_auto_execution = False
|
||||
autopilot.get_status.return_value = {
|
||||
"symbol": "BTC/USD",
|
||||
"timeframe": "1h",
|
||||
"running": False,
|
||||
"selected_strategy": None,
|
||||
"trades_today": 0,
|
||||
"max_trades_per_day": 10,
|
||||
"min_confidence_threshold": 0.75,
|
||||
"enable_auto_execution": False,
|
||||
"last_analysis": None,
|
||||
"model_info": {},
|
||||
}
|
||||
return autopilot
|
||||
|
||||
|
||||
class TestUnifiedAutopilotEndpoints:
|
||||
"""Tests for unified autopilot endpoints."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot_mode_info')
|
||||
def test_get_modes(self, mock_get_mode_info, client):
|
||||
"""Test getting autopilot mode information."""
|
||||
mock_get_mode_info.return_value = {
|
||||
"modes": {
|
||||
"pattern": {"name": "Pattern-Based Autopilot"},
|
||||
"intelligent": {"name": "ML-Based Autopilot"},
|
||||
},
|
||||
"comparison": {},
|
||||
}
|
||||
|
||||
response = client.get("/api/autopilot/modes")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "modes" in data
|
||||
assert "pattern" in data["modes"]
|
||||
assert "intelligent" in data["modes"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
@patch('backend.api.autopilot.run_autopilot_loop')
|
||||
def test_start_unified_pattern_mode(
|
||||
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test starting unified autopilot in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "pattern",
|
||||
"auto_execute": False,
|
||||
"interval": 60.0,
|
||||
"pattern_order": 5,
|
||||
"auto_fetch_news": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["mode"] == "pattern"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["auto_execute"] is False
|
||||
mock_get_autopilot.assert_called_once()
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_start_unified_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test starting unified autopilot in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": True,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
"interval": 60.0,
|
||||
"paper_trading": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["mode"] == "intelligent"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["auto_execute"] is True
|
||||
assert mock_intelligent_autopilot.enable_auto_execution is True
|
||||
mock_get_intelligent.assert_called_once()
|
||||
|
||||
def test_start_unified_invalid_mode(self, client):
|
||||
"""Test starting unified autopilot with invalid mode."""
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "invalid_mode",
|
||||
"auto_execute": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
def test_stop_unified_pattern_mode(
|
||||
self, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test stopping unified autopilot in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=pattern"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "pattern"
|
||||
mock_autopilot.stop.assert_called_once()
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_stop_unified_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test stopping unified autopilot in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=intelligent&timeframe=1h"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "stopped"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "intelligent"
|
||||
mock_intelligent_autopilot.stop.assert_called_once()
|
||||
|
||||
def test_stop_unified_invalid_mode(self, client):
|
||||
"""Test stopping unified autopilot with invalid mode."""
|
||||
response = client.post(
|
||||
"/api/autopilot/stop-unified?symbol=BTC/USD&mode=invalid_mode"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
def test_get_unified_status_pattern_mode(
|
||||
self, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test getting unified autopilot status in pattern mode."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=pattern"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "pattern"
|
||||
assert "running" in data
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_get_unified_status_intelligent_mode(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test getting unified autopilot status in intelligent mode."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=intelligent&timeframe=1h"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["mode"] == "intelligent"
|
||||
assert "running" in data
|
||||
|
||||
def test_get_unified_status_invalid_mode(self, client):
|
||||
"""Test getting unified autopilot status with invalid mode."""
|
||||
response = client.get(
|
||||
"/api/autopilot/status-unified/BTC/USD?mode=invalid_mode"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid mode" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestModeSelection:
|
||||
"""Tests for mode selection logic."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot_mode_info')
|
||||
def test_mode_info_structure(self, mock_get_mode_info):
|
||||
"""Test that mode info has correct structure."""
|
||||
from src.autopilot import get_autopilot_mode_info
|
||||
|
||||
mode_info = get_autopilot_mode_info()
|
||||
|
||||
assert "modes" in mode_info
|
||||
assert "pattern" in mode_info["modes"]
|
||||
assert "intelligent" in mode_info["modes"]
|
||||
assert "comparison" in mode_info
|
||||
|
||||
# Check pattern mode structure
|
||||
pattern = mode_info["modes"]["pattern"]
|
||||
assert "name" in pattern
|
||||
assert "description" in pattern
|
||||
assert "how_it_works" in pattern
|
||||
assert "best_for" in pattern
|
||||
assert "tradeoffs" in pattern
|
||||
assert "features" in pattern
|
||||
assert "requirements" in pattern
|
||||
|
||||
# Check intelligent mode structure
|
||||
intelligent = mode_info["modes"]["intelligent"]
|
||||
assert "name" in intelligent
|
||||
assert "description" in intelligent
|
||||
assert "how_it_works" in intelligent
|
||||
assert "best_for" in intelligent
|
||||
assert "tradeoffs" in intelligent
|
||||
assert "features" in intelligent
|
||||
assert "requirements" in intelligent
|
||||
|
||||
# Check comparison structure
|
||||
comparison = mode_info["comparison"]
|
||||
assert "transparency" in comparison
|
||||
assert "adaptability" in comparison
|
||||
assert "setup_time" in comparison
|
||||
assert "resource_usage" in comparison
|
||||
|
||||
|
||||
class TestAutoExecution:
|
||||
"""Tests for auto-execution functionality."""
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_auto_execute_enabled(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that auto-execute is set when enabled."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": True,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_intelligent_autopilot.enable_auto_execution is True
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_auto_execute_disabled(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that auto-execute is not set when disabled."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start-unified",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"mode": "intelligent",
|
||||
"auto_execute": False,
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Note: enable_auto_execution may have a default value, so we check it's not True
|
||||
# The actual behavior depends on the implementation
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Tests for backward compatibility with old endpoints."""
|
||||
|
||||
@patch('backend.api.autopilot.get_autopilot')
|
||||
@patch('backend.api.autopilot.run_autopilot_loop')
|
||||
def test_old_start_endpoint_still_works(
|
||||
self, mock_run_loop, mock_get_autopilot, client, mock_autopilot
|
||||
):
|
||||
"""Test that old /start endpoint still works (deprecated but functional)."""
|
||||
mock_get_autopilot.return_value = mock_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/start",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"interval": 60.0,
|
||||
"pattern_order": 5,
|
||||
"auto_fetch_news": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
|
||||
@patch('backend.api.autopilot.get_intelligent_autopilot')
|
||||
def test_old_intelligent_start_endpoint_still_works(
|
||||
self, mock_get_intelligent, client, mock_intelligent_autopilot
|
||||
):
|
||||
"""Test that old /intelligent/start endpoint still works (deprecated but functional)."""
|
||||
mock_get_intelligent.return_value = mock_intelligent_autopilot
|
||||
|
||||
response = client.post(
|
||||
"/api/autopilot/intelligent/start",
|
||||
json={
|
||||
"symbol": "BTC/USD",
|
||||
"exchange_id": 1,
|
||||
"timeframe": "1h",
|
||||
"interval": 60.0,
|
||||
"paper_trading": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "started"
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
|
||||
81
tests/unit/backend/api/test_exchanges.py
Normal file
81
tests/unit/backend/api/test_exchanges.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for exchanges API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
from src.core.database import Exchange
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_exchange():
|
||||
"""Mock exchange object."""
|
||||
exchange = Mock(spec=Exchange)
|
||||
exchange.id = 1
|
||||
exchange.name = "coinbase"
|
||||
exchange.is_enabled = True
|
||||
exchange.api_permissions = "read_only"
|
||||
return exchange
|
||||
|
||||
|
||||
class TestListExchanges:
|
||||
"""Tests for GET /api/exchanges."""
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_list_exchanges_success(self, mock_get_db, client):
|
||||
"""Test listing exchanges."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.all.return_value = []
|
||||
|
||||
response = client.get("/api/exchanges")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestGetExchange:
|
||||
"""Tests for GET /api/exchanges/{exchange_id}."""
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_get_exchange_success(self, mock_get_db, client, mock_exchange):
|
||||
"""Test getting exchange by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_exchange
|
||||
|
||||
response = client.get("/api/exchanges/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.exchanges.get_db')
|
||||
def test_get_exchange_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent exchange."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/exchanges/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Exchange not found" in response.json()["detail"]
|
||||
|
||||
65
tests/unit/backend/api/test_portfolio.py
Normal file
65
tests/unit/backend/api/test_portfolio.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Tests for portfolio API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_portfolio_tracker():
|
||||
"""Mock portfolio tracker."""
|
||||
tracker = Mock()
|
||||
tracker.get_current_portfolio.return_value = {
|
||||
"positions": [],
|
||||
"performance": {"total_return": 0.1},
|
||||
"timestamp": "2025-01-01T00:00:00"
|
||||
}
|
||||
tracker.get_portfolio_history.return_value = {
|
||||
"dates": ["2025-01-01"],
|
||||
"values": [1000.0],
|
||||
"pnl": [0.0]
|
||||
}
|
||||
return tracker
|
||||
|
||||
|
||||
class TestGetCurrentPortfolio:
|
||||
"""Tests for GET /api/portfolio/current."""
|
||||
|
||||
@patch('backend.api.portfolio.get_portfolio_tracker')
|
||||
def test_get_current_portfolio_success(self, mock_get_tracker, client, mock_portfolio_tracker):
|
||||
"""Test getting current portfolio."""
|
||||
mock_get_tracker.return_value = mock_portfolio_tracker
|
||||
|
||||
response = client.get("/api/portfolio/current?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "positions" in data
|
||||
assert "performance" in data
|
||||
assert "timestamp" in data
|
||||
|
||||
|
||||
class TestGetPortfolioHistory:
|
||||
"""Tests for GET /api/portfolio/history."""
|
||||
|
||||
@patch('backend.api.portfolio.get_portfolio_tracker')
|
||||
def test_get_portfolio_history_success(self, mock_get_tracker, client, mock_portfolio_tracker):
|
||||
"""Test getting portfolio history."""
|
||||
mock_get_tracker.return_value = mock_portfolio_tracker
|
||||
|
||||
response = client.get("/api/portfolio/history?paper_trading=true&days=30")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dates" in data
|
||||
assert "values" in data
|
||||
assert "pnl" in data
|
||||
|
||||
108
tests/unit/backend/api/test_strategies.py
Normal file
108
tests/unit/backend/api/test_strategies.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Tests for strategies API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
from src.core.database import Strategy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_strategy_registry():
|
||||
"""Mock strategy registry."""
|
||||
registry = Mock()
|
||||
registry.list_available.return_value = ["RSIStrategy", "MACDStrategy"]
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_strategy():
|
||||
"""Mock strategy object."""
|
||||
strategy = Mock(spec=Strategy)
|
||||
strategy.id = 1
|
||||
strategy.name = "Test Strategy"
|
||||
strategy.type = "RSIStrategy"
|
||||
strategy.status = "active" # Use string instead of enum
|
||||
strategy.symbol = "BTC/USD"
|
||||
strategy.params = {}
|
||||
return strategy
|
||||
|
||||
|
||||
|
||||
class TestListStrategies:
|
||||
"""Tests for GET /api/strategies."""
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_list_strategies_success(self, mock_get_db, client):
|
||||
"""Test listing strategies."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.all.return_value = []
|
||||
|
||||
response = client.get("/api/strategies")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
"""Tests for GET /api/strategies/{strategy_id}."""
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_get_strategy_success(self, mock_get_db, client, mock_strategy):
|
||||
"""Test getting strategy by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_strategy
|
||||
|
||||
response = client.get("/api/strategies/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.strategies.get_db')
|
||||
def test_get_strategy_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent strategy."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/strategies/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Strategy not found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestListAvailableStrategyTypes:
|
||||
"""Tests for GET /api/strategies/types."""
|
||||
|
||||
@patch('backend.api.strategies.get_strategy_registry')
|
||||
def test_list_strategy_types_success(self, mock_get_registry, client, mock_strategy_registry):
|
||||
"""Test listing available strategy types."""
|
||||
mock_get_registry.return_value = mock_strategy_registry
|
||||
|
||||
response = client.get("/api/strategies/types")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert "RSIStrategy" in data
|
||||
|
||||
293
tests/unit/backend/api/test_trading.py
Normal file
293
tests/unit/backend/api/test_trading.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Tests for trading API endpoints."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
|
||||
from backend.main import app
|
||||
from backend.core.schemas import OrderCreate, OrderSide, OrderType
|
||||
from src.core.database import Order, OrderStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trading_engine():
|
||||
"""Mock trading engine."""
|
||||
engine = Mock()
|
||||
engine.order_manager = Mock()
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_order():
|
||||
"""Mock order object."""
|
||||
order = Mock()
|
||||
order.id = 1
|
||||
order.exchange_id = 1
|
||||
order.strategy_id = None
|
||||
order.symbol = "BTC/USD"
|
||||
order.order_type = OrderType.MARKET
|
||||
order.side = OrderSide.BUY
|
||||
order.status = OrderStatus.FILLED
|
||||
order.quantity = Decimal("0.1")
|
||||
order.price = Decimal("50000")
|
||||
order.filled_quantity = Decimal("0.1")
|
||||
order.average_fill_price = Decimal("50000")
|
||||
order.fee = Decimal("5")
|
||||
order.paper_trading = True
|
||||
order.created_at = datetime.now()
|
||||
order.updated_at = datetime.now()
|
||||
order.filled_at = datetime.now()
|
||||
return order
|
||||
|
||||
|
||||
class TestCreateOrder:
|
||||
"""Tests for POST /api/trading/orders."""
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_success(self, mock_get_engine, client, mock_trading_engine, mock_order):
|
||||
"""Test successful order creation."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.execute_order.return_value = mock_order
|
||||
|
||||
order_data = {
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1",
|
||||
"paper_trading": True
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
assert data["symbol"] == "BTC/USD"
|
||||
assert data["side"] == "buy"
|
||||
mock_trading_engine.execute_order.assert_called_once()
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_execution_failed(self, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test order creation when execution fails."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.execute_order.return_value = None
|
||||
|
||||
order_data = {
|
||||
"exchange_id": 1,
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1",
|
||||
"paper_trading": True
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Order execution failed" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
def test_create_order_invalid_data(self, client):
|
||||
"""Test order creation with invalid data."""
|
||||
order_data = {
|
||||
"exchange_id": "invalid",
|
||||
"symbol": "BTC/USD",
|
||||
"side": "buy",
|
||||
"order_type": "market",
|
||||
"quantity": "0.1"
|
||||
}
|
||||
|
||||
response = client.post("/api/trading/orders", json=order_data)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
class TestGetOrders:
|
||||
"""Tests for GET /api/trading/orders."""
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_orders_success(self, mock_get_db, client, mock_database):
|
||||
"""Test getting orders."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
# Create mock orders
|
||||
mock_order1 = Mock(spec=Order)
|
||||
mock_order1.id = 1
|
||||
mock_order1.symbol = "BTC/USD"
|
||||
mock_order1.paper_trading = True
|
||||
|
||||
mock_order2 = Mock(spec=Order)
|
||||
mock_order2.id = 2
|
||||
mock_order2.symbol = "ETH/USD"
|
||||
mock_order2.paper_trading = True
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_order1, mock_order2]
|
||||
|
||||
response = client.get("/api/trading/orders?paper_trading=true&limit=10")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
|
||||
class TestGetOrder:
|
||||
"""Tests for GET /api/trading/orders/{order_id}."""
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_order_success(self, mock_get_db, client, mock_database):
|
||||
"""Test getting order by ID."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.symbol = "BTC/USD"
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.get("/api/trading/orders/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_get_order_not_found(self, mock_get_db, client):
|
||||
"""Test getting non-existent order."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.get("/api/trading/orders/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Order not found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestCancelOrder:
|
||||
"""Tests for POST /api/trading/orders/{order_id}/cancel."""
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_success(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test successful order cancellation."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
mock_trading_engine.order_manager.cancel_order.return_value = True
|
||||
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.status = OrderStatus.OPEN
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.post("/api/trading/orders/1/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "cancelled"
|
||||
assert data["order_id"] == 1
|
||||
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_not_found(self, mock_get_db, client):
|
||||
"""Test cancelling non-existent order."""
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
response = client.post("/api/trading/orders/999/cancel")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Order not found" in response.json()["detail"]
|
||||
|
||||
@patch('backend.api.trading.get_trading_engine')
|
||||
@patch('backend.api.trading.get_db')
|
||||
def test_cancel_order_already_filled(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
|
||||
"""Test cancelling already filled order."""
|
||||
mock_get_engine.return_value = mock_trading_engine
|
||||
|
||||
mock_db = Mock()
|
||||
mock_session = Mock()
|
||||
mock_db.get_session.return_value = mock_session
|
||||
mock_get_db.return_value = mock_db
|
||||
|
||||
mock_order = Mock(spec=Order)
|
||||
mock_order.id = 1
|
||||
mock_order.status = OrderStatus.FILLED
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order
|
||||
|
||||
response = client.post("/api/trading/orders/1/cancel")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "cannot be cancelled" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestGetPositions:
|
||||
"""Tests for GET /api/trading/positions."""
|
||||
|
||||
@patch('backend.api.trading.get_paper_trading')
|
||||
def test_get_positions_paper_trading(self, mock_get_paper, client):
|
||||
"""Test getting positions for paper trading."""
|
||||
mock_paper = Mock()
|
||||
mock_position = Mock()
|
||||
mock_position.symbol = "BTC/USD"
|
||||
mock_position.quantity = Decimal("0.1")
|
||||
mock_position.entry_price = Decimal("50000")
|
||||
mock_position.current_price = Decimal("51000")
|
||||
mock_position.unrealized_pnl = Decimal("100")
|
||||
mock_position.realized_pnl = Decimal("0")
|
||||
mock_paper.get_positions.return_value = [mock_position]
|
||||
mock_get_paper.return_value = mock_paper
|
||||
|
||||
response = client.get("/api/trading/positions?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
if len(data) > 0:
|
||||
assert data[0]["symbol"] == "BTC/USD"
|
||||
|
||||
|
||||
class TestGetBalance:
|
||||
"""Tests for GET /api/trading/balance."""
|
||||
|
||||
@patch('backend.api.trading.get_paper_trading')
|
||||
def test_get_balance_paper_trading(self, mock_get_paper, client):
|
||||
"""Test getting balance for paper trading."""
|
||||
mock_paper = Mock()
|
||||
mock_paper.get_balance.return_value = Decimal("1000")
|
||||
mock_paper.get_performance.return_value = {"total_return": 0.1}
|
||||
mock_get_paper.return_value = mock_paper
|
||||
|
||||
response = client.get("/api/trading/balance?paper_trading=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "balance" in data
|
||||
assert "performance" in data
|
||||
assert data["balance"] == 1000.0
|
||||
|
||||
77
tests/unit/backend/core/test_dependencies.py
Normal file
77
tests/unit/backend/core/test_dependencies.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Tests for backend dependencies."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock
|
||||
from backend.core.dependencies import (
|
||||
get_database, get_trading_engine, get_portfolio_tracker,
|
||||
get_strategy_registry, get_backtesting_engine, get_exchange_factory
|
||||
)
|
||||
|
||||
|
||||
class TestGetDatabase:
|
||||
"""Tests for get_database dependency."""
|
||||
|
||||
def test_get_database_singleton(self):
|
||||
"""Test that get_database returns same instance."""
|
||||
db1 = get_database()
|
||||
db2 = get_database()
|
||||
# Should be cached/same instance due to lru_cache
|
||||
assert db1 is db2
|
||||
|
||||
|
||||
class TestGetTradingEngine:
|
||||
"""Tests for get_trading_engine dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_trading_engine')
|
||||
def test_get_trading_engine(self, mock_get_engine):
|
||||
"""Test getting trading engine."""
|
||||
mock_engine = Mock()
|
||||
mock_get_engine.return_value = mock_engine
|
||||
engine = get_trading_engine()
|
||||
assert engine is not None
|
||||
|
||||
|
||||
class TestGetPortfolioTracker:
|
||||
"""Tests for get_portfolio_tracker dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_portfolio_tracker')
|
||||
def test_get_portfolio_tracker(self, mock_get_tracker):
|
||||
"""Test getting portfolio tracker."""
|
||||
mock_tracker = Mock()
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
tracker = get_portfolio_tracker()
|
||||
assert tracker is not None
|
||||
|
||||
|
||||
class TestGetStrategyRegistry:
|
||||
"""Tests for get_strategy_registry dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_strategy_registry')
|
||||
def test_get_strategy_registry(self, mock_get_registry):
|
||||
"""Test getting strategy registry."""
|
||||
mock_registry = Mock()
|
||||
mock_get_registry.return_value = mock_registry
|
||||
registry = get_strategy_registry()
|
||||
assert registry is not None
|
||||
|
||||
|
||||
class TestGetBacktestingEngine:
|
||||
"""Tests for get_backtesting_engine dependency."""
|
||||
|
||||
@patch('backend.core.dependencies.get_backtesting_engine')
|
||||
def test_get_backtesting_engine(self, mock_get_engine):
|
||||
"""Test getting backtesting engine."""
|
||||
mock_engine = Mock()
|
||||
mock_get_engine.return_value = mock_engine
|
||||
engine = get_backtesting_engine()
|
||||
assert engine is not None
|
||||
|
||||
|
||||
class TestGetExchangeFactory:
|
||||
"""Tests for get_exchange_factory dependency."""
|
||||
|
||||
def test_get_exchange_factory(self):
|
||||
"""Test getting exchange factory."""
|
||||
factory = get_exchange_factory()
|
||||
assert factory is not None
|
||||
|
||||
128
tests/unit/backend/core/test_schemas.py
Normal file
128
tests/unit/backend/core/test_schemas.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for Pydantic schemas."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.core.schemas import (
|
||||
OrderCreate, OrderResponse, OrderSide, OrderType, OrderStatus,
|
||||
PositionResponse, PortfolioResponse, PortfolioHistoryResponse
|
||||
)
|
||||
|
||||
|
||||
class TestOrderCreate:
|
||||
"""Tests for OrderCreate schema."""
|
||||
|
||||
def test_order_create_valid(self):
|
||||
"""Test valid order creation."""
|
||||
order = OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1"),
|
||||
paper_trading=True
|
||||
)
|
||||
assert order.exchange_id == 1
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.side == OrderSide.BUY
|
||||
|
||||
def test_order_create_with_price(self):
|
||||
"""Test order creation with price."""
|
||||
order = OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=Decimal("0.1"),
|
||||
price=Decimal("50000"),
|
||||
paper_trading=True
|
||||
)
|
||||
assert order.price == Decimal("50000")
|
||||
|
||||
def test_order_create_invalid_side(self):
|
||||
"""Test order creation with invalid side."""
|
||||
with pytest.raises(ValidationError):
|
||||
OrderCreate(
|
||||
exchange_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="invalid",
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1")
|
||||
)
|
||||
|
||||
|
||||
class TestOrderResponse:
|
||||
"""Tests for OrderResponse schema."""
|
||||
|
||||
def test_order_response_from_dict(self):
|
||||
"""Test creating OrderResponse from dictionary."""
|
||||
order_data = {
|
||||
"id": 1,
|
||||
"exchange_id": 1,
|
||||
"strategy_id": None,
|
||||
"symbol": "BTC/USD",
|
||||
"order_type": OrderType.MARKET,
|
||||
"side": OrderSide.BUY,
|
||||
"status": OrderStatus.FILLED,
|
||||
"quantity": Decimal("0.1"),
|
||||
"price": Decimal("50000"),
|
||||
"filled_quantity": Decimal("0.1"),
|
||||
"average_fill_price": Decimal("50000"),
|
||||
"fee": Decimal("5"),
|
||||
"paper_trading": True,
|
||||
"created_at": datetime.now(),
|
||||
"updated_at": datetime.now(),
|
||||
"filled_at": datetime.now()
|
||||
}
|
||||
|
||||
order = OrderResponse(**order_data)
|
||||
assert order.id == 1
|
||||
assert order.symbol == "BTC/USD"
|
||||
|
||||
|
||||
class TestPositionResponse:
|
||||
"""Tests for PositionResponse schema."""
|
||||
|
||||
def test_position_response_valid(self):
|
||||
"""Test valid position response."""
|
||||
position = PositionResponse(
|
||||
symbol="BTC/USD",
|
||||
quantity=Decimal("0.1"),
|
||||
entry_price=Decimal("50000"),
|
||||
current_price=Decimal("51000"),
|
||||
unrealized_pnl=Decimal("100"),
|
||||
realized_pnl=Decimal("0")
|
||||
)
|
||||
assert position.symbol == "BTC/USD"
|
||||
assert position.unrealized_pnl == Decimal("100")
|
||||
|
||||
|
||||
class TestPortfolioResponse:
|
||||
"""Tests for PortfolioResponse schema."""
|
||||
|
||||
def test_portfolio_response_valid(self):
|
||||
"""Test valid portfolio response."""
|
||||
portfolio = PortfolioResponse(
|
||||
positions=[],
|
||||
performance={"total_return": 0.1},
|
||||
timestamp="2025-01-01T00:00:00"
|
||||
)
|
||||
assert portfolio.positions == []
|
||||
assert portfolio.performance["total_return"] == 0.1
|
||||
|
||||
|
||||
class TestPortfolioHistoryResponse:
|
||||
"""Tests for PortfolioHistoryResponse schema."""
|
||||
|
||||
def test_portfolio_history_response_valid(self):
|
||||
"""Test valid portfolio history response."""
|
||||
history = PortfolioHistoryResponse(
|
||||
dates=["2025-01-01", "2025-01-02"],
|
||||
values=[1000.0, 1100.0],
|
||||
pnl=[0.0, 100.0]
|
||||
)
|
||||
assert len(history.dates) == 2
|
||||
assert len(history.values) == 2
|
||||
|
||||
Reference in New Issue
Block a user