Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
1
tests/unit/core/__init__.py
Normal file
1
tests/unit/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test init file."""
|
||||
44
tests/unit/core/test_config.py
Normal file
44
tests/unit/core/test_config.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for configuration system."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from src.core.config import Config, get_config
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""Tests for Config class."""
|
||||
|
||||
def test_config_initialization(self, tmp_path):
|
||||
"""Test config initialization."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
assert config is not None
|
||||
assert config.config_dir is not None
|
||||
assert config.data_dir is not None
|
||||
|
||||
def test_config_get(self, tmp_path):
|
||||
"""Test config get method."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
# Test nested key access
|
||||
value = config.get('paper_trading.default_capital')
|
||||
assert value is not None
|
||||
|
||||
def test_config_set(self, tmp_path):
|
||||
"""Test config set method."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
config.set('paper_trading.default_capital', 200.0)
|
||||
value = config.get('paper_trading.default_capital')
|
||||
assert value == 200.0
|
||||
|
||||
def test_config_defaults(self, tmp_path):
|
||||
"""Test default configuration values."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config = Config(config_file=str(config_file))
|
||||
assert config.get('paper_trading.default_capital') == 100.0
|
||||
assert config.get('database.type') == 'postgresql'
|
||||
|
||||
97
tests/unit/core/test_database.py
Normal file
97
tests/unit/core/test_database.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for database system."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from src.core.database import (
|
||||
get_database, Base, Exchange, Strategy, Trade, Position, Order
|
||||
)
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""Tests for database operations."""
|
||||
|
||||
def test_database_initialization(self):
|
||||
"""Test database initialization."""
|
||||
db = get_database()
|
||||
assert db is not None
|
||||
assert db.engine is not None
|
||||
|
||||
def test_table_creation(self, mock_database):
|
||||
"""Test table creation."""
|
||||
engine, Session = mock_database
|
||||
# Verify tables exist
|
||||
assert Base.metadata.tables.get('exchanges') is not None
|
||||
assert Base.metadata.tables.get('strategies') is not None
|
||||
assert Base.metadata.tables.get('trades') is not None
|
||||
|
||||
def test_exchange_model(self, mock_database):
|
||||
"""Test Exchange model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
exchange = Exchange(
|
||||
name="test_exchange",
|
||||
api_key="encrypted_key",
|
||||
secret_key="encrypted_secret",
|
||||
api_permissions="read_only",
|
||||
is_enabled=True
|
||||
)
|
||||
session.add(exchange)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Exchange).filter_by(name="test_exchange").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "test_exchange"
|
||||
assert retrieved.api_permissions == "read_only"
|
||||
|
||||
session.close()
|
||||
|
||||
def test_strategy_model(self, mock_database):
|
||||
"""Test Strategy model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
strategy = Strategy(
|
||||
name="test_strategy",
|
||||
strategy_type="RSI",
|
||||
parameters='{"rsi_period": 14}',
|
||||
is_enabled=True,
|
||||
is_paper_trading=True
|
||||
)
|
||||
session.add(strategy)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Strategy).filter_by(name="test_strategy").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.strategy_type == "RSI"
|
||||
|
||||
session.close()
|
||||
|
||||
def test_trade_model(self, mock_database):
|
||||
"""Test Trade model."""
|
||||
engine, Session = mock_database
|
||||
session = Session()
|
||||
|
||||
trade = Trade(
|
||||
order_id="test_order_123",
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
type="market",
|
||||
price=50000.0,
|
||||
amount=0.01,
|
||||
cost=500.0,
|
||||
fee=0.5,
|
||||
status="filled",
|
||||
is_paper_trade=True
|
||||
)
|
||||
session.add(trade)
|
||||
session.commit()
|
||||
|
||||
retrieved = session.query(Trade).filter_by(order_id="test_order_123").first()
|
||||
assert retrieved is not None
|
||||
assert retrieved.symbol == "BTC/USD"
|
||||
assert retrieved.status == "filled"
|
||||
|
||||
session.close()
|
||||
|
||||
43
tests/unit/core/test_logger.py
Normal file
43
tests/unit/core/test_logger.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Tests for logging system."""
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from src.core.logger import setup_logging, get_logger
|
||||
|
||||
|
||||
class TestLogger:
|
||||
"""Tests for logging system."""
|
||||
|
||||
def test_logger_setup(self, test_log_dir):
|
||||
"""Test logger setup."""
|
||||
with patch('src.core.logger.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'logging.dir': str(test_log_dir),
|
||||
'logging.retention_days': 30,
|
||||
'logging.level': 'INFO'
|
||||
}.get(key, default)
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger('test')
|
||||
assert logger is not None
|
||||
assert isinstance(logger, logging.Logger)
|
||||
|
||||
def test_logger_get(self):
|
||||
"""Test getting logger instance."""
|
||||
logger = get_logger('test_module')
|
||||
assert logger is not None
|
||||
assert logger.name == 'test_module'
|
||||
|
||||
def test_logger_levels(self):
|
||||
"""Test different log levels."""
|
||||
logger = get_logger('test')
|
||||
# Should not raise exceptions
|
||||
logger.debug("Debug message")
|
||||
logger.info("Info message")
|
||||
logger.warning("Warning message")
|
||||
logger.error("Error message")
|
||||
logger.critical("Critical message")
|
||||
|
||||
92
tests/unit/core/test_redis.py
Normal file
92
tests/unit/core/test_redis.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for Redis client wrapper."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
class TestRedisClient:
|
||||
"""Tests for RedisClient class."""
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_get_client_creates_connection(self, mock_config):
|
||||
"""Test that get_client creates a Redis connection."""
|
||||
# Setup mock config
|
||||
mock_config.return_value.get.return_value = {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"db": 0,
|
||||
"password": None,
|
||||
"socket_connect_timeout": 5
|
||||
}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
|
||||
# Should not have connected yet
|
||||
assert client._client is None
|
||||
|
||||
# Get client should trigger connection
|
||||
with patch('src.core.redis.redis.ConnectionPool') as mock_pool:
|
||||
with patch('src.core.redis.redis.Redis') as mock_redis:
|
||||
redis_client = client.get_client()
|
||||
|
||||
mock_pool.assert_called_once()
|
||||
mock_redis.assert_called_once()
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_get_client_reuses_existing(self, mock_config):
|
||||
"""Test that get_client reuses existing connection."""
|
||||
mock_config.return_value.get.return_value = {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"db": 0,
|
||||
}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
|
||||
# Pre-set a mock client
|
||||
mock_redis = Mock()
|
||||
client._client = mock_redis
|
||||
|
||||
# Should return existing
|
||||
result = client.get_client()
|
||||
assert result is mock_redis
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_connection(self, mock_config):
|
||||
"""Test closing Redis connection."""
|
||||
mock_config.return_value.get.return_value = {"host": "localhost"}
|
||||
|
||||
from src.core.redis import RedisClient
|
||||
|
||||
client = RedisClient()
|
||||
mock_redis = AsyncMock()
|
||||
client._client = mock_redis
|
||||
|
||||
await client.close()
|
||||
|
||||
mock_redis.aclose.assert_called_once()
|
||||
|
||||
|
||||
class TestGetRedisClient:
|
||||
"""Tests for get_redis_client singleton."""
|
||||
|
||||
@patch('src.core.redis.get_config')
|
||||
def test_returns_singleton(self, mock_config):
|
||||
"""Test that get_redis_client returns same instance."""
|
||||
mock_config.return_value.get.return_value = {"host": "localhost"}
|
||||
|
||||
# Reset the global
|
||||
import src.core.redis as redis_module
|
||||
redis_module._redis_client = None
|
||||
|
||||
from src.core.redis import get_redis_client
|
||||
|
||||
client1 = get_redis_client()
|
||||
client2 = get_redis_client()
|
||||
|
||||
assert client1 is client2
|
||||
Reference in New Issue
Block a user