Files
crypto_trader/tests/unit/test_autopilot_training.py

285 lines
9.9 KiB
Python
Raw Permalink Normal View History

"""Tests for autopilot model training functionality."""
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from fastapi.testclient import TestClient
from backend.main import app
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` wrapping the application under test.

    The client is created without entering it as a context manager, so
    (per Starlette's TestClient documentation) the app's lifespan/startup
    handlers are not triggered by this fixture.
    """
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def mock_train_task():
    """Patch the Celery ``train_model_task`` referenced by the autopilot API.

    Yields the patched task. Its ``delay()`` returns a stand-in result whose
    ``id`` is ``"test-task-id-12345"``, so tests can assert on the task id
    the endpoint reports.
    """
    with patch('backend.api.autopilot.train_model_task') as patched_task:
        queued = Mock()
        queued.id = "test-task-id-12345"
        patched_task.delay.return_value = queued
        yield patched_task
@pytest.fixture
def mock_async_result():
    """Patch ``AsyncResult`` in ``backend.api.autopilot`` for one test.

    Yields the patched class; tests set its ``return_value`` to a fake task
    result that the status endpoint will then inspect.
    """
    patcher = patch('backend.api.autopilot.AsyncResult')
    with patcher as fake_async_result_cls:
        yield fake_async_result_cls
@pytest.fixture
def mock_strategy_selector():
    """Return a ``Mock`` standing in for a trained ``StrategySelector``.

    ``selector.model`` carries the attributes of a trained classifier, and
    ``selector.get_model_info()`` returns a canned info dict. That dict reports
    three strategies, 54 features, and training metadata for BTC/USD and
    ETH/USD.
    """
    model = Mock()
    model.is_trained = True
    model.model_type = "classifier"
    model.feature_names = ["rsi", "macd", "sma_20"]
    model.training_metadata = {
        "trained_at": "2024-01-01T00:00:00",
        "metrics": {"test_accuracy": 0.85},
    }

    canned_info = {
        "is_trained": True,
        "model_type": "classifier",
        "available_strategies": ["rsi", "macd", "momentum"],
        "feature_count": 54,
        "training_metadata": {
            "trained_at": "2024-01-01T00:00:00",
            "metrics": {"test_accuracy": 0.85},
            "training_symbols": ["BTC/USD", "ETH/USD"],
        },
    }

    selector = Mock()
    selector.model = model
    selector.get_model_info.return_value = canned_info
    return selector
class TestBootstrapConfig:
    """Tests for bootstrap configuration endpoints."""

    # Keys sent in the PUT payload below; used to rebuild a restorable
    # payload from a GET snapshot taken before the config is mutated.
    _CONFIG_FIELDS = ("days", "timeframe", "min_samples_per_strategy", "symbols")

    def test_get_bootstrap_config(self, client):
        """Test getting bootstrap configuration."""
        response = client.get("/api/autopilot/bootstrap-config")
        assert response.status_code == 200
        data = response.json()
        # Verify required fields exist
        assert "days" in data
        assert "timeframe" in data
        assert "min_samples_per_strategy" in data
        assert "symbols" in data
        # Verify types
        assert isinstance(data["days"], int)
        assert isinstance(data["timeframe"], str)
        assert isinstance(data["min_samples_per_strategy"], int)
        assert isinstance(data["symbols"], list)

    def test_update_bootstrap_config(self, client):
        """Test updating bootstrap configuration.

        The PUT changes config that the application keeps across requests
        (the follow-up GET observes it). It may also be persisted, which is
        not visible from this file. The pre-test values are therefore
        restored in ``finally`` so other tests do not depend on run order.
        """
        original = client.get("/api/autopilot/bootstrap-config").json()
        restore_payload = {field: original[field] for field in self._CONFIG_FIELDS}

        new_config = {
            "days": 365,
            "timeframe": "4h",
            "min_samples_per_strategy": 50,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD"],
        }
        try:
            response = client.put(
                "/api/autopilot/bootstrap-config",
                json=new_config,
            )
            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "success"

            # Verify the config was updated
            response = client.get("/api/autopilot/bootstrap-config")
            data = response.json()
            assert data["days"] == 365
            assert data["timeframe"] == "4h"
            assert data["min_samples_per_strategy"] == 50
            assert "SOL/USD" in data["symbols"]
        finally:
            # Undo the mutation even when an assertion above failed.
            client.put("/api/autopilot/bootstrap-config", json=restore_payload)
class TestModelTraining:
    """Tests for model training endpoints."""

    @staticmethod
    def _fake_async_result(status, result, **extra_attrs):
        """Build a stand-in for a Celery ``AsyncResult``.

        On a real Celery ``AsyncResult``, ``info`` is an alias of ``result``:
        both hold the return value, the progress meta, or the raised
        exception. The fake therefore sets both to the same object, so the
        tests stay faithful whichever attribute the endpoint reads.

        Args:
            status: Value exposed as ``.status`` (e.g. ``"PENDING"``).
            result: Value exposed as both ``.result`` and ``.info``.
            **extra_attrs: Further attributes to set (e.g. ``traceback``).

        Returns:
            A configured ``Mock``.
        """
        fake = Mock()
        fake.status = status
        fake.result = result
        fake.info = result
        for name, value in extra_attrs.items():
            setattr(fake, name, value)
        return fake

    def test_trigger_retrain(self, client, mock_train_task):
        """Test triggering model retraining."""
        response = client.post("/api/autopilot/intelligent/retrain?force=true")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "queued"
        assert "task_id" in data
        assert data["task_id"] == "test-task-id-12345"
        # Verify task was called with correct parameters
        mock_train_task.delay.assert_called_once()
        call_kwargs = mock_train_task.delay.call_args.kwargs
        assert call_kwargs["force_retrain"] is True
        assert call_kwargs["bootstrap"] is True
        assert "symbols" in call_kwargs
        assert "days" in call_kwargs
        assert "timeframe" in call_kwargs
        assert "min_samples_per_strategy" in call_kwargs

    def test_get_task_status_pending(self, client, mock_async_result):
        """Test getting status of a pending task."""
        mock_async_result.return_value = self._fake_async_result("PENDING", None)
        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "PENDING"

    def test_get_task_status_progress(self, client, mock_async_result):
        """Test getting status of a task in progress."""
        progress_meta = {
            "step": "fetching",
            "progress": 50,
            "message": "Fetching BTC/USD data...",
        }
        mock_async_result.return_value = self._fake_async_result("PROGRESS", progress_meta)
        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "PROGRESS"
        assert data["meta"]["progress"] == 50

    def test_get_task_status_success(self, client, mock_async_result):
        """Test getting status of a successful task."""
        training_result = {
            "train_accuracy": 0.85,
            "test_accuracy": 0.78,
            "n_samples": 1000,
            "best_model": "xgboost",
        }
        mock_async_result.return_value = self._fake_async_result("SUCCESS", training_result)
        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "SUCCESS"
        assert data["result"]["best_model"] == "xgboost"

    def test_get_task_status_failure(self, client, mock_async_result):
        """Test getting status of a failed task."""
        mock_async_result.return_value = self._fake_async_result(
            "FAILURE",
            Exception("Training failed: insufficient data"),
            traceback="Traceback (most recent call last)...",
        )
        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "FAILURE"
        assert "error" in data["result"]
class TestModelInfo:
    """Tests for the ``/api/autopilot/intelligent/model-info`` endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_trained(self, mock_get_selector, client, mock_strategy_selector):
        """A trained selector is reported as trained, with strategies and a feature count."""
        mock_get_selector.return_value = mock_strategy_selector

        resp = client.get("/api/autopilot/intelligent/model-info")

        assert resp.status_code == 200
        body = resp.json()
        assert body["is_trained"] is True
        assert "available_strategies" in body
        assert "feature_count" in body

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_untrained(self, mock_get_selector, client):
        """An untrained selector is reported as untrained with zero features."""
        untrained_info = {
            "is_trained": False,
            "model_type": "classifier",
            "available_strategies": ["rsi", "macd"],
            "feature_count": 0,
        }
        stub_selector = Mock()
        stub_selector.get_model_info.return_value = untrained_info
        mock_get_selector.return_value = stub_selector

        resp = client.get("/api/autopilot/intelligent/model-info")

        assert resp.status_code == 200
        body = resp.json()
        assert body["is_trained"] is False
        assert body["feature_count"] == 0
class TestModelReset:
    """Tests for model reset endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_reset_model(self, mock_get_selector, client):
        """Test resetting the model.

        Besides the HTTP 200, this verifies the endpoint awaited the
        selector's ``reset_model`` exactly once. Without that check, a
        no-op endpoint would also pass.
        """
        mock_selector = Mock()
        mock_selector.reset_model = AsyncMock(return_value={"status": "success"})
        mock_get_selector.return_value = mock_selector
        response = client.post("/api/autopilot/intelligent/reset")
        assert response.status_code == 200
        mock_selector.reset_model.assert_awaited_once()
@pytest.mark.integration
class TestTrainingWorkflow:
    """Integration tests for the complete training workflow."""

    @patch('backend.api.autopilot.train_model_task')
    def test_config_and_retrain_workflow(self, mock_train_task, client):
        """Test configure -> train workflow passes config correctly.

        The application keeps the bootstrap config written here across
        requests, since the retrain call below reads it. The pre-test values
        are therefore restored in ``finally`` so other tests do not depend on
        run order.
        """
        # Setup mock: delay() returns a fake task result with a fixed id.
        mock_task_result = Mock()
        mock_task_result.id = "test-task-123"
        mock_train_task.delay.return_value = mock_task_result

        # Snapshot the current config so it can be restored afterwards.
        original = client.get("/api/autopilot/bootstrap-config").json()
        restore_payload = {
            field: original[field]
            for field in ("days", "timeframe", "min_samples_per_strategy", "symbols")
        }

        # 1. Configure bootstrap settings with specific values
        config = {
            "days": 180,
            "timeframe": "4h",
            "min_samples_per_strategy": 25,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"],
        }
        try:
            response = client.put("/api/autopilot/bootstrap-config", json=config)
            assert response.status_code == 200

            # 2. Trigger retraining
            response = client.post("/api/autopilot/intelligent/retrain?force=true")
            assert response.status_code == 200

            # 3. Verify the task was called with the correct config
            mock_train_task.delay.assert_called_once()
            call_kwargs = mock_train_task.delay.call_args.kwargs
            # All config should be passed to the task
            assert call_kwargs["days"] == 180
            assert call_kwargs["timeframe"] == "4h"
            assert call_kwargs["min_samples_per_strategy"] == 25
            assert call_kwargs["symbols"] == ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
            assert call_kwargs["force_retrain"] is True
            assert call_kwargs["bootstrap"] is True
        finally:
            # Undo the config mutation even when an assertion above failed.
            client.put("/api/autopilot/bootstrap-config", json=restore_payload)