Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
1
tests/unit/worker/__init__.py
Normal file
1
tests/unit/worker/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
178
tests/unit/worker/test_tasks.py
Normal file
178
tests/unit/worker/test_tasks.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for Celery tasks."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
class TestAsyncToSync:
|
||||
"""Tests for async_to_sync helper."""
|
||||
|
||||
def test_runs_awaitable(self):
|
||||
"""Test that async_to_sync runs awaitable and returns result."""
|
||||
from src.worker.tasks import async_to_sync
|
||||
|
||||
async def async_func():
|
||||
return "test_result"
|
||||
|
||||
result = async_to_sync(async_func())
|
||||
assert result == "test_result"
|
||||
|
||||
def test_handles_exception(self):
|
||||
"""Test that async_to_sync propagates exceptions."""
|
||||
from src.worker.tasks import async_to_sync
|
||||
|
||||
async def async_error():
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
async_to_sync(async_error())
|
||||
|
||||
|
||||
class TestTrainModelTask:
|
||||
"""Tests for train_model_task."""
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_train_model_basic(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test basic model training task."""
|
||||
# Setup mocks
|
||||
mock_selector = Mock()
|
||||
mock_selector.bootstrap_symbols = ["BTC/USD"]
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
mock_async_to_sync.side_effect = [
|
||||
{"X": [1, 2, 3]}, # prepare_training_data result
|
||||
{"accuracy": 0.9} # train_model result
|
||||
]
|
||||
|
||||
from src.worker.tasks import train_model_task
|
||||
|
||||
# Call the task directly - Celery will bind self automatically
|
||||
# For testing, we need to access the underlying function
|
||||
result = train_model_task.run(force_retrain=True, bootstrap=False)
|
||||
|
||||
assert result == {"accuracy": 0.9}
|
||||
mock_get_selector.assert_called_once()
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_train_model_with_bootstrap(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test model training with bootstrapping."""
|
||||
mock_selector = Mock()
|
||||
mock_selector.bootstrap_symbols = ["BTC/USD", "ETH/USD"]
|
||||
mock_get_selector.return_value = mock_selector
|
||||
|
||||
# First call returns empty data, triggering bootstrap
|
||||
mock_async_to_sync.side_effect = [
|
||||
{"X": []}, # Empty training data
|
||||
{"total_samples": 100}, # First symbol bootstrap
|
||||
{"total_samples": 50}, # Second symbol bootstrap
|
||||
{"accuracy": 0.85} # Final training
|
||||
]
|
||||
|
||||
from src.worker.tasks import train_model_task
|
||||
|
||||
result = train_model_task.run(force_retrain=False, bootstrap=True)
|
||||
|
||||
assert result == {"accuracy": 0.85}
|
||||
|
||||
|
||||
class TestBootstrapTask:
|
||||
"""Tests for bootstrap_task."""
|
||||
|
||||
@patch('src.worker.tasks.get_strategy_selector')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_bootstrap_basic(self, mock_async_to_sync, mock_get_selector):
|
||||
"""Test basic bootstrap task."""
|
||||
mock_selector = Mock()
|
||||
mock_get_selector.return_value = mock_selector
|
||||
mock_async_to_sync.return_value = {"total_samples": 200}
|
||||
|
||||
from src.worker.tasks import bootstrap_task
|
||||
|
||||
result = bootstrap_task.run(days=90, timeframe="1h")
|
||||
|
||||
assert result == {"total_samples": 200}
|
||||
|
||||
|
||||
class TestGenerateReportTask:
|
||||
"""Tests for generate_report_task."""
|
||||
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_generate_report_unknown_type(self, mock_async_to_sync):
|
||||
"""Test report generation with unknown type."""
|
||||
from src.worker.tasks import generate_report_task
|
||||
|
||||
result = generate_report_task.run("unknown", {})
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "Unknown report type" in result["message"]
|
||||
|
||||
|
||||
class TestOptimizeStrategyTask:
|
||||
"""Tests for optimize_strategy_task."""
|
||||
|
||||
@patch('src.optimization.genetic.GeneticOptimizer')
|
||||
def test_optimize_genetic_basic(self, mock_optimizer_class):
|
||||
"""Test basic genetic optimization."""
|
||||
from src.worker.tasks import optimize_strategy_task
|
||||
|
||||
mock_optimizer = Mock()
|
||||
mock_optimizer.optimize.return_value = {
|
||||
"best_params": {"period": 14},
|
||||
"best_score": 0.85
|
||||
}
|
||||
mock_optimizer_class.return_value = mock_optimizer
|
||||
|
||||
result = optimize_strategy_task.run(
|
||||
strategy_type="rsi",
|
||||
symbol="BTC/USD",
|
||||
param_ranges={"period": (5, 50)},
|
||||
method="genetic",
|
||||
population_size=10,
|
||||
generations=5
|
||||
)
|
||||
|
||||
assert result["best_params"] == {"period": 14}
|
||||
assert result["best_score"] == 0.85
|
||||
|
||||
def test_optimize_unknown_method(self):
|
||||
"""Test optimization with unknown method."""
|
||||
from src.worker.tasks import optimize_strategy_task
|
||||
|
||||
result = optimize_strategy_task.run(
|
||||
strategy_type="rsi",
|
||||
symbol="BTC/USD",
|
||||
param_ranges={"period": (5, 50)},
|
||||
method="unknown_method"
|
||||
)
|
||||
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestExportDataTask:
|
||||
"""Tests for export_data_task."""
|
||||
|
||||
@patch('src.reporting.csv_exporter.get_csv_exporter')
|
||||
@patch('src.worker.tasks.async_to_sync')
|
||||
def test_export_orders(self, mock_async_to_sync, mock_exporter_func):
|
||||
"""Test order export."""
|
||||
mock_exporter = Mock()
|
||||
mock_exporter.export_orders.return_value = True
|
||||
mock_exporter_func.return_value = mock_exporter
|
||||
mock_async_to_sync.return_value = [] # Empty orders list
|
||||
|
||||
from src.worker.tasks import export_data_task
|
||||
|
||||
result = export_data_task.run("orders", {})
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["export_type"] == "orders"
|
||||
|
||||
def test_export_unknown_type(self):
|
||||
"""Test export with unknown type."""
|
||||
from src.worker.tasks import export_data_task
|
||||
|
||||
result = export_data_task.run("unknown", {})
|
||||
|
||||
assert result["status"] == "error"
|
||||
Reference in New Issue
Block a user