Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
2
tests/unit/risk/__init__.py
Normal file
2
tests/unit/risk/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for risk management."""
|
||||
|
||||
55
tests/unit/risk/test_manager.py
Normal file
55
tests/unit/risk/test_manager.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Tests for risk manager."""
|
||||
|
||||
import pytest
|
||||
from src.risk.manager import get_risk_manager, RiskManager
|
||||
|
||||
|
||||
class TestRiskManager:
|
||||
"""Tests for RiskManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def risk_manager(self):
|
||||
"""Create risk manager instance."""
|
||||
return get_risk_manager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_trade_risk(self, risk_manager):
|
||||
"""Test trade risk checking."""
|
||||
# Test with valid trade
|
||||
result = await risk_manager.check_trade_risk(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
amount=0.01,
|
||||
price=50000.0,
|
||||
current_portfolio_value=10000.0
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_max_drawdown(self, risk_manager):
|
||||
"""Test max drawdown check."""
|
||||
result = await risk_manager.check_max_drawdown(
|
||||
current_portfolio_value=9000.0,
|
||||
peak_portfolio_value=10000.0
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_risk_limit(self, risk_manager, mock_database):
|
||||
"""Test adding risk limit."""
|
||||
engine, Session = mock_database
|
||||
|
||||
await risk_manager.add_risk_limit(
|
||||
limit_type="max_drawdown",
|
||||
value=0.10, # 10%
|
||||
is_active=True
|
||||
)
|
||||
|
||||
# Verify limit was added
|
||||
await risk_manager.load_risk_limits()
|
||||
assert len(risk_manager.risk_limits) > 0
|
||||
|
||||
41
tests/unit/risk/test_position_sizing.py
Normal file
41
tests/unit/risk/test_position_sizing.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Tests for position sizing."""
|
||||
|
||||
import pytest
|
||||
from src.risk.position_sizing import get_position_sizer, PositionSizing
|
||||
|
||||
|
||||
class TestPositionSizing:
|
||||
"""Tests for PositionSizing."""
|
||||
|
||||
@pytest.fixture
|
||||
def position_sizer(self):
|
||||
"""Create position sizer instance."""
|
||||
return get_position_sizer()
|
||||
|
||||
def test_fixed_percentage(self, position_sizer):
|
||||
"""Test fixed percentage sizing."""
|
||||
size = position_sizer.fixed_percentage(10000.0, 0.02) # 2%
|
||||
assert size == 200.0
|
||||
|
||||
def test_fixed_amount(self, position_sizer):
|
||||
"""Test fixed amount sizing."""
|
||||
size = position_sizer.fixed_amount(500.0)
|
||||
assert size == 500.0
|
||||
|
||||
def test_volatility_based(self, position_sizer):
|
||||
"""Test volatility-based sizing."""
|
||||
size = position_sizer.volatility_based(
|
||||
capital=10000.0,
|
||||
atr=100.0,
|
||||
risk_per_trade_percentage=0.01 # 1%
|
||||
)
|
||||
assert size > 0
|
||||
|
||||
def test_kelly_criterion(self, position_sizer):
|
||||
"""Test Kelly Criterion."""
|
||||
fraction = position_sizer.kelly_criterion(
|
||||
win_probability=0.6,
|
||||
payout_ratio=1.5
|
||||
)
|
||||
assert 0 <= fraction <= 1
|
||||
|
||||
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for ATR-based stop loss."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from src.risk.stop_loss import StopLossManager
|
||||
|
||||
|
||||
class TestATRStopLoss:
|
||||
"""Tests for ATR-based stop loss functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def stop_loss_manager(self):
|
||||
"""Create stop loss manager instance."""
|
||||
return StopLossManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ohlcv_data(self):
|
||||
"""Create sample OHLCV data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
|
||||
base_price = 50000
|
||||
return pd.DataFrame({
|
||||
'high': [base_price + 100 + i * 10 for i in range(50)],
|
||||
'low': [base_price - 100 + i * 10 for i in range(50)],
|
||||
'close': [base_price + i * 10 for i in range(50)],
|
||||
'open': [base_price - 50 + i * 10 for i in range(50)],
|
||||
'volume': [1000.0] * 50
|
||||
}, index=dates)
|
||||
|
||||
def test_set_atr_stop_loss(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test setting ATR-based stop loss."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
assert position_id in stop_loss_manager.stop_losses
|
||||
config = stop_loss_manager.stop_losses[position_id]
|
||||
assert config['use_atr'] is True
|
||||
assert config['atr_multiplier'] == Decimal('2.0')
|
||||
assert config['atr_period'] == 14
|
||||
assert 'stop_price' in config
|
||||
|
||||
def test_calculate_atr_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test calculating ATR stop price."""
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_price_long = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
stop_price_short = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=False,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Long position: stop should be below entry
|
||||
assert stop_price_long < entry_price
|
||||
|
||||
# Short position: stop should be above entry
|
||||
assert stop_price_short > entry_price
|
||||
|
||||
def test_atr_trailing_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test ATR-based trailing stop."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
trailing=True,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Check stop loss with higher price (should update trailing stop)
|
||||
current_price = Decimal("51000")
|
||||
triggered = stop_loss_manager.check_stop_loss(
|
||||
position_id=position_id,
|
||||
current_price=current_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Should not trigger at higher price
|
||||
assert triggered is False
|
||||
|
||||
def test_atr_stop_insufficient_data(self, stop_loss_manager):
|
||||
"""Test ATR stop with insufficient data falls back to percentage."""
|
||||
entry_price = Decimal("50000")
|
||||
insufficient_data = pd.DataFrame({
|
||||
'high': [51000],
|
||||
'low': [49000],
|
||||
'close': [50000]
|
||||
})
|
||||
|
||||
stop_price = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=insufficient_data,
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Should fall back to percentage-based stop (2%)
|
||||
assert stop_price == entry_price * Decimal('0.98')
|
||||
|
||||
Reference in New Issue
Block a user