"""Tests for autopilot model training functionality."""
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from fastapi.testclient import TestClient

from backend.main import app


# Keys sent to PUT /api/autopilot/bootstrap-config by the tests below. Only
# these are snapshotted/restored so the restore payload matches what the
# endpoint is known to accept.
_BOOTSTRAP_CONFIG_KEYS = ("days", "timeframe", "min_samples_per_strategy", "symbols")


@pytest.fixture
def client():
    """Test client fixture."""
    return TestClient(app)


@pytest.fixture
def restore_bootstrap_config(client):
    """Snapshot the bootstrap config before a test and restore it afterwards.

    The PUT endpoint changes config that later GETs observe, presumably held
    in app/module-level state shared by every TestClient(app). Without a
    restore, a test that mutates it would leak its values into subsequent
    tests.
    """
    response = client.get("/api/autopilot/bootstrap-config")
    assert response.status_code == 200
    original = response.json()
    snapshot = {key: original[key] for key in _BOOTSTRAP_CONFIG_KEYS}
    yield snapshot
    restore = client.put("/api/autopilot/bootstrap-config", json=snapshot)
    assert restore.status_code == 200


@pytest.fixture
def mock_train_task():
    """Mock Celery train_model_task; .delay() returns a result with a fixed id."""
    with patch('backend.api.autopilot.train_model_task') as mock:
        mock_result = Mock()
        mock_result.id = "test-task-id-12345"
        mock.delay.return_value = mock_result
        yield mock


@pytest.fixture
def mock_async_result():
    """Mock Celery AsyncResult as looked up by backend.api.autopilot."""
    with patch('backend.api.autopilot.AsyncResult') as mock:
        yield mock


@pytest.fixture
def mock_strategy_selector():
    """Mock StrategySelector whose model reports itself as trained."""
    selector = Mock()
    selector.model = Mock()
    selector.model.is_trained = True
    selector.model.model_type = "classifier"
    selector.model.feature_names = ["rsi", "macd", "sma_20"]
    selector.model.training_metadata = {
        "trained_at": "2024-01-01T00:00:00",
        "metrics": {"test_accuracy": 0.85}
    }
    selector.get_model_info.return_value = {
        "is_trained": True,
        "model_type": "classifier",
        "available_strategies": ["rsi", "macd", "momentum"],
        "feature_count": 54,
        "training_metadata": {
            "trained_at": "2024-01-01T00:00:00",
            "metrics": {"test_accuracy": 0.85},
            "training_symbols": ["BTC/USD", "ETH/USD"]
        }
    }
    return selector


class TestBootstrapConfig:
    """Tests for bootstrap configuration endpoints."""

    def test_get_bootstrap_config(self, client):
        """Test getting bootstrap configuration."""
        response = client.get("/api/autopilot/bootstrap-config")
        assert response.status_code == 200
        data = response.json()

        # Verify required fields exist
        assert "days" in data
        assert "timeframe" in data
        assert "min_samples_per_strategy" in data
        assert "symbols" in data

        # Verify types
        assert isinstance(data["days"], int)
        assert isinstance(data["timeframe"], str)
        assert isinstance(data["min_samples_per_strategy"], int)
        assert isinstance(data["symbols"], list)

    @pytest.mark.usefixtures("restore_bootstrap_config")
    def test_update_bootstrap_config(self, client):
        """Test updating bootstrap configuration."""
        new_config = {
            "days": 365,
            "timeframe": "4h",
            "min_samples_per_strategy": 50,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD"]
        }
        response = client.put(
            "/api/autopilot/bootstrap-config",
            json=new_config
        )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

        # Verify the config was updated
        response = client.get("/api/autopilot/bootstrap-config")
        assert response.status_code == 200
        data = response.json()
        assert data["days"] == 365
        assert data["timeframe"] == "4h"
        assert data["min_samples_per_strategy"] == 50
        assert "SOL/USD" in data["symbols"]


class TestModelTraining:
    """Tests for model training endpoints."""

    def test_trigger_retrain(self, client, mock_train_task):
        """Test triggering model retraining."""
        response = client.post("/api/autopilot/intelligent/retrain?force=true")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "queued"
        assert "task_id" in data
        assert data["task_id"] == "test-task-id-12345"

        # Verify task was called with correct parameters
        mock_train_task.delay.assert_called_once()
        call_kwargs = mock_train_task.delay.call_args.kwargs
        assert call_kwargs["force_retrain"] is True
        assert call_kwargs["bootstrap"] is True
        assert "symbols" in call_kwargs
        assert "days" in call_kwargs
        assert "timeframe" in call_kwargs
        assert "min_samples_per_strategy" in call_kwargs

    def test_get_task_status_pending(self, client, mock_async_result):
        """Test getting status of a pending task."""
        mock_result = Mock()
        mock_result.status = "PENDING"
        mock_result.result = None
        mock_async_result.return_value = mock_result

        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "PENDING"

    def test_get_task_status_progress(self, client, mock_async_result):
        """Test getting status of a task in progress."""
        mock_result = Mock()
        mock_result.status = "PROGRESS"
        mock_result.result = None
        # Progress details are provided via .info, which the endpoint is
        # expected to surface under "meta".
        mock_result.info = {
            "step": "fetching",
            "progress": 50,
            "message": "Fetching BTC/USD data..."
        }
        mock_async_result.return_value = mock_result

        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "PROGRESS"
        assert data["meta"]["progress"] == 50

    def test_get_task_status_success(self, client, mock_async_result):
        """Test getting status of a successful task."""
        mock_result = Mock()
        mock_result.status = "SUCCESS"
        mock_result.result = {
            "train_accuracy": 0.85,
            "test_accuracy": 0.78,
            "n_samples": 1000,
            "best_model": "xgboost"
        }
        mock_async_result.return_value = mock_result

        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "SUCCESS"
        assert data["result"]["best_model"] == "xgboost"

    def test_get_task_status_failure(self, client, mock_async_result):
        """Test getting status of a failed task."""
        mock_result = Mock()
        mock_result.status = "FAILURE"
        mock_result.result = Exception("Training failed: insufficient data")
        mock_result.traceback = "Traceback (most recent call last)..."
        mock_async_result.return_value = mock_result

        response = client.get("/api/autopilot/tasks/test-task-id")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "FAILURE"
        assert "error" in data["result"]


class TestModelInfo:
    """Tests for model info endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_trained(self, mock_get_selector, client, mock_strategy_selector):
        """Test getting info for a trained model."""
        mock_get_selector.return_value = mock_strategy_selector

        response = client.get("/api/autopilot/intelligent/model-info")
        assert response.status_code == 200
        data = response.json()
        assert data["is_trained"] is True
        assert "available_strategies" in data
        assert "feature_count" in data

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_get_model_info_untrained(self, mock_get_selector, client):
        """Test getting info for an untrained model."""
        mock_selector = Mock()
        mock_selector.get_model_info.return_value = {
            "is_trained": False,
            "model_type": "classifier",
            "available_strategies": ["rsi", "macd"],
            "feature_count": 0
        }
        mock_get_selector.return_value = mock_selector

        response = client.get("/api/autopilot/intelligent/model-info")
        assert response.status_code == 200
        data = response.json()
        assert data["is_trained"] is False
        assert data["feature_count"] == 0


class TestModelReset:
    """Tests for model reset endpoint."""

    @patch('backend.api.autopilot.get_strategy_selector')
    def test_reset_model(self, mock_get_selector, client):
        """Test resetting the model."""
        mock_selector = Mock()
        # reset_model is mocked as a coroutine, presumably because the
        # endpoint awaits it — confirm against backend.api.autopilot.
        mock_selector.reset_model = AsyncMock(return_value={"status": "success"})
        mock_get_selector.return_value = mock_selector

        response = client.post("/api/autopilot/intelligent/reset")
        assert response.status_code == 200


@pytest.mark.integration
class TestTrainingWorkflow:
    """Integration tests for the complete training workflow."""

    @pytest.mark.usefixtures("restore_bootstrap_config")
    def test_config_and_retrain_workflow(self, mock_train_task, client):
        """Test configure -> train workflow passes config correctly."""
        # 1. Configure bootstrap settings with specific values
        config = {
            "days": 180,
            "timeframe": "4h",
            "min_samples_per_strategy": 25,
            "symbols": ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
        }
        response = client.put("/api/autopilot/bootstrap-config", json=config)
        assert response.status_code == 200

        # 2. Trigger retraining
        response = client.post("/api/autopilot/intelligent/retrain?force=true")
        assert response.status_code == 200
        assert response.json()["task_id"] == "test-task-id-12345"

        # 3. Verify the task was called with the correct config
        mock_train_task.delay.assert_called_once()
        call_kwargs = mock_train_task.delay.call_args.kwargs

        # All config should be passed to the task
        assert call_kwargs["days"] == 180
        assert call_kwargs["timeframe"] == "4h"
        assert call_kwargs["min_samples_per_strategy"] == 25
        assert call_kwargs["symbols"] == ["BTC/USD", "ETH/USD", "SOL/USD", "XRP/USD"]
        assert call_kwargs["force_retrain"] is True
        assert call_kwargs["bootstrap"] is True