"""Tests for autopilot model training functionality."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
|
||
|
|
from backend.main import app
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def client():
    """Provide a FastAPI test client bound to the application under test."""
    return TestClient(app)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_train_task():
    """Patch the Celery train_model_task so no real work is queued.

    Yields the patched task; its ``delay`` returns a stub result carrying
    a fixed task id so endpoints can echo it back.
    """
    with patch('backend.api.autopilot.train_model_task') as task_mock:
        queued = Mock()
        queued.id = "test-task-id-12345"
        task_mock.delay.return_value = queued
        yield task_mock
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_async_result():
    """Patch Celery's AsyncResult as used by the task-status endpoint."""
    with patch('backend.api.autopilot.AsyncResult') as async_result_mock:
        yield async_result_mock
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_strategy_selector():
    """Build a Mock StrategySelector that reports a fully trained model."""
    # Metadata attached directly to the model object.
    metadata = {
        "trained_at": "2024-01-01T00:00:00",
        "metrics": {"test_accuracy": 0.85},
    }
    # Payload returned by get_model_info(); extends the metadata with the
    # symbols the model was trained on.
    info = {
        "is_trained": True,
        "model_type": "classifier",
        "available_strategies": ["rsi", "macd", "momentum"],
        "feature_count": 54,
        "training_metadata": {
            **metadata,
            "training_symbols": ["BTC/USD", "ETH/USD"],
        },
    }

    selector = Mock()
    selector.model = Mock()
    selector.model.is_trained = True
    selector.model.model_type = "classifier"
    selector.model.feature_names = ["rsi", "macd", "sma_20"]
    selector.model.training_metadata = metadata
    selector.get_model_info.return_value = info
    return selector
|
||
|
|
|
||
|
|
|
||
|
|
class TestBootstrapConfig:
    """Tests for bootstrap configuration endpoints."""

    def test_get_bootstrap_config(self, client):
        """GET should return a config whose fields exist and are well-typed."""
        resp = client.get("/api/autopilot/bootstrap-config")
        assert resp.status_code == 200
        body = resp.json()

        # Each required field must be present with the expected type.
        for field, expected_type in (
            ("days", int),
            ("timeframe", str),
            ("min_samples_per_strategy", int),
            ("symbols", list),
        ):
            assert field in body
            assert isinstance(body[field], expected_type)

    def test_update_bootstrap_config(self, client):
        """PUT should persist the new config so a subsequent GET reflects it."""
        payload = {
            "days": 365,
            "timeframe": "4h",
            "min_samples_per_strategy": 50,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD"]
        }

        resp = client.put("/api/autopilot/bootstrap-config", json=payload)
        assert resp.status_code == 200
        assert resp.json()["status"] == "success"

        # Read the config back and confirm every updated value stuck.
        body = client.get("/api/autopilot/bootstrap-config").json()
        assert body["days"] == 365
        assert body["timeframe"] == "4h"
        assert body["min_samples_per_strategy"] == 50
        assert "SOL/USD" in body["symbols"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestModelTraining:
    """Tests for model training endpoints."""

    @staticmethod
    def _celery_result(status, result=None):
        """Return a Mock shaped like a Celery AsyncResult in *status*."""
        fake = Mock()
        fake.status = status
        fake.result = result
        return fake

    def test_trigger_retrain(self, client, mock_train_task):
        """POST retrain should queue the task and echo its id."""
        resp = client.post("/api/autopilot/intelligent/retrain?force=true")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "queued"
        assert "task_id" in body
        assert body["task_id"] == "test-task-id-12345"

        # The queued task must receive the full bootstrap configuration.
        mock_train_task.delay.assert_called_once()
        kwargs = mock_train_task.delay.call_args.kwargs
        assert kwargs["force_retrain"] is True
        assert kwargs["bootstrap"] is True
        for expected_key in (
            "symbols",
            "days",
            "timeframe",
            "min_samples_per_strategy",
        ):
            assert expected_key in kwargs

    def test_get_task_status_pending(self, client, mock_async_result):
        """A not-yet-started task reports status PENDING."""
        mock_async_result.return_value = self._celery_result("PENDING")

        resp = client.get("/api/autopilot/tasks/test-task-id")

        assert resp.status_code == 200
        assert resp.json()["status"] == "PENDING"

    def test_get_task_status_progress(self, client, mock_async_result):
        """An in-progress task exposes its progress metadata."""
        fake = self._celery_result("PROGRESS")
        fake.info = {
            "step": "fetching",
            "progress": 50,
            "message": "Fetching BTC/USD data..."
        }
        mock_async_result.return_value = fake

        resp = client.get("/api/autopilot/tasks/test-task-id")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "PROGRESS"
        assert body["meta"]["progress"] == 50

    def test_get_task_status_success(self, client, mock_async_result):
        """A completed task returns its training metrics."""
        outcome = {
            "train_accuracy": 0.85,
            "test_accuracy": 0.78,
            "n_samples": 1000,
            "best_model": "xgboost"
        }
        mock_async_result.return_value = self._celery_result("SUCCESS", outcome)

        resp = client.get("/api/autopilot/tasks/test-task-id")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "SUCCESS"
        assert body["result"]["best_model"] == "xgboost"

    def test_get_task_status_failure(self, client, mock_async_result):
        """A failed task surfaces an error entry in the result."""
        fake = self._celery_result(
            "FAILURE", Exception("Training failed: insufficient data")
        )
        fake.traceback = "Traceback (most recent call last)..."
        mock_async_result.return_value = fake

        resp = client.get("/api/autopilot/tasks/test-task-id")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "FAILURE"
        assert "error" in body["result"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestModelInfo:
    """Tests for model info endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_trained(self, mock_get_selector, client, mock_strategy_selector):
        """A trained model reports its strategies and feature count."""
        mock_get_selector.return_value = mock_strategy_selector

        resp = client.get("/api/autopilot/intelligent/model-info")

        assert resp.status_code == 200
        body = resp.json()
        assert body["is_trained"] is True
        assert "available_strategies" in body
        assert "feature_count" in body

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_untrained(self, mock_get_selector, client):
        """An untrained model reports is_trained=False and zero features."""
        untrained = Mock()
        untrained.get_model_info.return_value = {
            "is_trained": False,
            "model_type": "classifier",
            "available_strategies": ["rsi", "macd"],
            "feature_count": 0
        }
        mock_get_selector.return_value = untrained

        resp = client.get("/api/autopilot/intelligent/model-info")

        assert resp.status_code == 200
        body = resp.json()
        assert body["is_trained"] is False
        assert body["feature_count"] == 0
|
||
|
|
|
||
|
|
|
||
|
|
class TestModelReset:
    """Tests for model reset endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_reset_model(self, mock_get_selector, client):
        """POST reset should succeed when the selector's async reset works."""
        fake_selector = Mock()
        # reset_model is awaited by the endpoint, so it must be an AsyncMock.
        fake_selector.reset_model = AsyncMock(return_value={"status": "success"})
        mock_get_selector.return_value = fake_selector

        resp = client.post("/api/autopilot/intelligent/reset")

        assert resp.status_code == 200
|
||
|
|
|
||
|
|
@pytest.mark.integration
class TestTrainingWorkflow:
    """Integration tests for the complete training workflow."""

    @patch('backend.api.autopilot.train_model_task')
    def test_config_and_retrain_workflow(self, mock_train_task, client):
        """Configure then retrain; the queued task must see the new config."""
        queued = Mock()
        queued.id = "test-task-123"
        mock_train_task.delay.return_value = queued

        # 1. Persist a bootstrap config with distinctive values.
        desired = {
            "days": 180,
            "timeframe": "4h",
            "min_samples_per_strategy": 25,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
        }
        put_resp = client.put("/api/autopilot/bootstrap-config", json=desired)
        assert put_resp.status_code == 200

        # 2. Kick off a forced retrain.
        post_resp = client.post("/api/autopilot/intelligent/retrain?force=true")
        assert post_resp.status_code == 200

        # 3. The Celery task must be queued exactly once with that config.
        mock_train_task.delay.assert_called_once()
        kwargs = mock_train_task.delay.call_args.kwargs
        assert kwargs["days"] == 180
        assert kwargs["timeframe"] == "4h"
        assert kwargs["min_samples_per_strategy"] == 25
        assert kwargs["symbols"] == ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
        assert kwargs["force_retrain"] is True
        assert kwargs["bootstrap"] is True