Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
122
tests/unit/risk/test_stop_loss_atr.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for ATR-based stop loss."""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from src.risk.stop_loss import StopLossManager
|
||||
|
||||
|
||||
class TestATRStopLoss:
|
||||
"""Tests for ATR-based stop loss functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def stop_loss_manager(self):
|
||||
"""Create stop loss manager instance."""
|
||||
return StopLossManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ohlcv_data(self):
|
||||
"""Create sample OHLCV data."""
|
||||
dates = pd.date_range(start='2025-01-01', periods=50, freq='1H')
|
||||
base_price = 50000
|
||||
return pd.DataFrame({
|
||||
'high': [base_price + 100 + i * 10 for i in range(50)],
|
||||
'low': [base_price - 100 + i * 10 for i in range(50)],
|
||||
'close': [base_price + i * 10 for i in range(50)],
|
||||
'open': [base_price - 50 + i * 10 for i in range(50)],
|
||||
'volume': [1000.0] * 50
|
||||
}, index=dates)
|
||||
|
||||
def test_set_atr_stop_loss(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test setting ATR-based stop loss."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
assert position_id in stop_loss_manager.stop_losses
|
||||
config = stop_loss_manager.stop_losses[position_id]
|
||||
assert config['use_atr'] is True
|
||||
assert config['atr_multiplier'] == Decimal('2.0')
|
||||
assert config['atr_period'] == 14
|
||||
assert 'stop_price' in config
|
||||
|
||||
def test_calculate_atr_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test calculating ATR stop price."""
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_price_long = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
stop_price_short = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=False,
|
||||
ohlcv_data=sample_ohlcv_data,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Long position: stop should be below entry
|
||||
assert stop_price_long < entry_price
|
||||
|
||||
# Short position: stop should be above entry
|
||||
assert stop_price_short > entry_price
|
||||
|
||||
def test_atr_trailing_stop(self, stop_loss_manager, sample_ohlcv_data):
|
||||
"""Test ATR-based trailing stop."""
|
||||
position_id = 1
|
||||
entry_price = Decimal("50000")
|
||||
|
||||
stop_loss_manager.set_stop_loss(
|
||||
position_id=position_id,
|
||||
stop_price=entry_price,
|
||||
trailing=True,
|
||||
use_atr=True,
|
||||
atr_multiplier=Decimal('2.0'),
|
||||
atr_period=14,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Check stop loss with higher price (should update trailing stop)
|
||||
current_price = Decimal("51000")
|
||||
triggered = stop_loss_manager.check_stop_loss(
|
||||
position_id=position_id,
|
||||
current_price=current_price,
|
||||
is_long=True,
|
||||
ohlcv_data=sample_ohlcv_data
|
||||
)
|
||||
|
||||
# Should not trigger at higher price
|
||||
assert triggered is False
|
||||
|
||||
def test_atr_stop_insufficient_data(self, stop_loss_manager):
|
||||
"""Test ATR stop with insufficient data falls back to percentage."""
|
||||
entry_price = Decimal("50000")
|
||||
insufficient_data = pd.DataFrame({
|
||||
'high': [51000],
|
||||
'low': [49000],
|
||||
'close': [50000]
|
||||
})
|
||||
|
||||
stop_price = stop_loss_manager.calculate_atr_stop(
|
||||
entry_price=entry_price,
|
||||
is_long=True,
|
||||
ohlcv_data=insufficient_data,
|
||||
atr_period=14
|
||||
)
|
||||
|
||||
# Should fall back to percentage-based stop (2%)
|
||||
assert stop_price == entry_price * Decimal('0.98')
|
||||
|
||||
Reference in New Issue
Block a user