179 lines
6.1 KiB
Python
179 lines
6.1 KiB
Python
|
|
"""Tests for Celery tasks."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||
|
|
|
||
|
|
|
||
|
|
class TestAsyncToSync:
    """Unit tests for the ``async_to_sync`` bridge in ``src.worker.tasks``."""

    def test_runs_awaitable(self):
        """A coroutine passed to async_to_sync is run and its value handed back."""
        from src.worker.tasks import async_to_sync

        async def produce():
            return "test_result"

        coroutine = produce()
        assert async_to_sync(coroutine) == "test_result"

    def test_handles_exception(self):
        """An exception raised inside the coroutine propagates to the caller."""
        from src.worker.tasks import async_to_sync

        async def explode():
            raise ValueError("test error")

        with pytest.raises(ValueError, match="test error"):
            async_to_sync(explode())
|
||
|
|
|
||
|
|
|
||
|
|
class TestTrainModelTask:
    """Tests for train_model_task.

    ``async_to_sync`` is patched, so each awaitable the task would run is
    replaced by the next value in ``side_effect``. The order of that list
    therefore encodes the sequence of async steps the task is expected to
    perform. A call beyond the end of the list raises ``StopIteration``.
    """

    @patch('src.worker.tasks.get_strategy_selector')
    @patch('src.worker.tasks.async_to_sync')
    def test_train_model_basic(self, mock_async_to_sync, mock_get_selector):
        """Test basic model training task."""
        # Setup mocks
        mock_selector = Mock()
        mock_selector.bootstrap_symbols = ["BTC/USD"]
        mock_get_selector.return_value = mock_selector

        mock_async_to_sync.side_effect = [
            {"X": [1, 2, 3]},  # prepare_training_data result
            {"accuracy": 0.9},  # train_model result
        ]

        from src.worker.tasks import train_model_task

        # ``.run`` invokes the task body directly (no broker involved); for a
        # bound task Celery supplies ``self`` itself.
        result = train_model_task.run(force_retrain=True, bootstrap=False)

        assert result == {"accuracy": 0.9}
        mock_get_selector.assert_called_once()
        # Exactly the two expected async steps (data preparation, training);
        # no bootstrap step should have run.
        assert mock_async_to_sync.call_count == 2

    @patch('src.worker.tasks.get_strategy_selector')
    @patch('src.worker.tasks.async_to_sync')
    def test_train_model_with_bootstrap(self, mock_async_to_sync, mock_get_selector):
        """Test model training with bootstrapping."""
        mock_selector = Mock()
        mock_selector.bootstrap_symbols = ["BTC/USD", "ETH/USD"]
        mock_get_selector.return_value = mock_selector

        # First call returns empty data, triggering bootstrap
        mock_async_to_sync.side_effect = [
            {"X": []},  # Empty training data
            {"total_samples": 100},  # First symbol bootstrap
            {"total_samples": 50},  # Second symbol bootstrap
            {"accuracy": 0.85},  # Final training
        ]

        from src.worker.tasks import train_model_task

        result = train_model_task.run(force_retrain=False, bootstrap=True)

        assert result == {"accuracy": 0.85}
        # The final result alone does not prove bootstrapping happened; also
        # check the step count: one data-preparation call, one bootstrap call
        # per configured symbol, and one training call.
        assert mock_async_to_sync.call_count == 2 + len(mock_selector.bootstrap_symbols)
|
||
|
|
|
||
|
|
|
||
|
|
class TestBootstrapTask:
    """Tests covering bootstrap_task."""

    @patch('src.worker.tasks.get_strategy_selector')
    @patch('src.worker.tasks.async_to_sync')
    def test_bootstrap_basic(self, mock_async_to_sync, mock_get_selector):
        """bootstrap_task's result equals the value yielded by the patched async_to_sync."""
        from src.worker.tasks import bootstrap_task

        mock_get_selector.return_value = Mock()
        mock_async_to_sync.return_value = {"total_samples": 200}

        outcome = bootstrap_task.run(days=90, timeframe="1h")
        assert outcome == {"total_samples": 200}
|
||
|
|
|
||
|
|
|
||
|
|
class TestGenerateReportTask:
    """Tests covering generate_report_task."""

    # async_to_sync is patched so no real awaitable is executed if the task
    # happens to reach it on this path.
    @patch('src.worker.tasks.async_to_sync')
    def test_generate_report_unknown_type(self, mock_async_to_sync):
        """An unrecognised report type yields an error payload instead of raising."""
        from src.worker.tasks import generate_report_task

        outcome = generate_report_task.run("unknown", {})

        assert outcome["status"] == "error"
        assert "Unknown report type" in outcome["message"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestOptimizeStrategyTask:
    """Tests covering optimize_strategy_task."""

    # NOTE(review): GeneticOptimizer is patched at its definition site. That
    # only takes effect if tasks.py looks the class up at call time (e.g. via
    # a function-local import) — confirm against src/worker/tasks.py.
    @patch('src.optimization.genetic.GeneticOptimizer')
    def test_optimize_genetic_basic(self, mock_optimizer_class):
        """The genetic method surfaces the optimizer's best params and score."""
        from src.worker.tasks import optimize_strategy_task

        canned = {
            "best_params": {"period": 14},
            "best_score": 0.85
        }
        mock_optimizer_class.return_value = Mock(**{"optimize.return_value": canned})

        task_kwargs = dict(
            strategy_type="rsi",
            symbol="BTC/USD",
            param_ranges={"period": (5, 50)},
            method="genetic",
            population_size=10,
            generations=5,
        )
        result = optimize_strategy_task.run(**task_kwargs)

        assert result["best_params"] == {"period": 14}
        assert result["best_score"] == 0.85

    def test_optimize_unknown_method(self):
        """The result for an unsupported optimization method contains 'error'."""
        from src.worker.tasks import optimize_strategy_task

        result = optimize_strategy_task.run(
            strategy_type="rsi",
            symbol="BTC/USD",
            param_ranges={"period": (5, 50)},
            method="unknown_method",
        )

        assert "error" in result
|
||
|
|
|
||
|
|
|
||
|
|
class TestExportDataTask:
    """Tests covering export_data_task."""

    @patch('src.reporting.csv_exporter.get_csv_exporter')
    @patch('src.worker.tasks.async_to_sync')
    def test_export_orders(self, mock_async_to_sync, mock_exporter_func):
        """Exporting orders through a stub exporter reports success for 'orders'."""
        from src.worker.tasks import export_data_task

        # Every async_to_sync call yields [] (an empty orders list).
        mock_async_to_sync.return_value = []
        mock_exporter_func.return_value = Mock(**{"export_orders.return_value": True})

        result = export_data_task.run("orders", {})

        assert result["status"] == "success"
        assert result["export_type"] == "orders"

    def test_export_unknown_type(self):
        """An unsupported export type yields status 'error'."""
        from src.worker.tasks import export_data_task

        outcome = export_data_task.run("unknown", {})
        assert outcome["status"] == "error"
|