123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
"""Tests for ATR-based stop loss."""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
from decimal import Decimal
|
|
from src.risk.stop_loss import StopLossManager
|
|
|
|
|
|
class TestATRStopLoss:
    """Tests for ATR-based stop loss functionality.

    The ``sample_ohlcv_data`` fixture rises by 10 per bar and every bar spans
    200 from low to high, while the gaps to the previous close are only 110 and
    90. The true range is therefore a constant 200 on every bar.
    """

    # Shared inputs so every test exercises the same configuration.
    ENTRY_PRICE = Decimal("50000")
    ATR_MULTIPLIER = Decimal("2.0")
    ATR_PERIOD = 14

    @pytest.fixture
    def stop_loss_manager(self):
        """Create a fresh stop loss manager instance for each test."""
        return StopLossManager()

    @pytest.fixture
    def sample_ohlcv_data(self):
        """Create 50 hourly bars of steadily rising OHLCV data."""
        periods = 50
        base_price = 50000
        steps = range(periods)
        # Lowercase "h": the uppercase "H" alias is deprecated since pandas 2.2
        # and emits a FutureWarning; older pandas upper-cases aliases itself.
        dates = pd.date_range(start='2025-01-01', periods=periods, freq='1h')
        return pd.DataFrame(
            {
                'high': [base_price + 100 + i * 10 for i in steps],
                'low': [base_price - 100 + i * 10 for i in steps],
                'close': [base_price + i * 10 for i in steps],
                'open': [base_price - 50 + i * 10 for i in steps],
                'volume': [1000.0] * periods,
            },
            index=dates,
        )

    def test_set_atr_stop_loss(self, stop_loss_manager, sample_ohlcv_data):
        """Test that setting an ATR-based stop loss stores its configuration."""
        position_id = 1

        stop_loss_manager.set_stop_loss(
            position_id=position_id,
            stop_price=self.ENTRY_PRICE,
            use_atr=True,
            atr_multiplier=self.ATR_MULTIPLIER,
            atr_period=self.ATR_PERIOD,
            ohlcv_data=sample_ohlcv_data,
        )

        assert position_id in stop_loss_manager.stop_losses
        config = stop_loss_manager.stop_losses[position_id]
        assert config['use_atr'] is True
        assert config['atr_multiplier'] == self.ATR_MULTIPLIER
        assert config['atr_period'] == self.ATR_PERIOD
        assert 'stop_price' in config

    def test_calculate_atr_stop(self, stop_loss_manager, sample_ohlcv_data):
        """Test that ATR stops land on the protective side of the entry."""
        stop_price_long = stop_loss_manager.calculate_atr_stop(
            entry_price=self.ENTRY_PRICE,
            is_long=True,
            ohlcv_data=sample_ohlcv_data,
            atr_multiplier=self.ATR_MULTIPLIER,
            atr_period=self.ATR_PERIOD,
        )
        stop_price_short = stop_loss_manager.calculate_atr_stop(
            entry_price=self.ENTRY_PRICE,
            is_long=False,
            ohlcv_data=sample_ohlcv_data,
            atr_multiplier=self.ATR_MULTIPLIER,
            atr_period=self.ATR_PERIOD,
        )

        # Long position: stop should be below entry
        assert stop_price_long < self.ENTRY_PRICE
        # Short position: stop should be above entry
        assert stop_price_short > self.ENTRY_PRICE

    def test_atr_trailing_stop(self, stop_loss_manager, sample_ohlcv_data):
        """Test that an ATR trailing stop neither triggers nor loosens on a rally."""
        position_id = 1

        stop_loss_manager.set_stop_loss(
            position_id=position_id,
            stop_price=self.ENTRY_PRICE,
            trailing=True,
            use_atr=True,
            atr_multiplier=self.ATR_MULTIPLIER,
            atr_period=self.ATR_PERIOD,
            ohlcv_data=sample_ohlcv_data,
        )
        initial_stop = stop_loss_manager.stop_losses[position_id]['stop_price']

        # Check the stop at a price above entry; this is where a trailing stop may ratchet up.
        current_price = Decimal("51000")
        triggered = stop_loss_manager.check_stop_loss(
            position_id=position_id,
            current_price=current_price,
            is_long=True,
            ohlcv_data=sample_ohlcv_data,
        )

        # Should not trigger at higher price
        assert triggered is False
        # A long trailing stop may only tighten, never move down, after a favourable move.
        updated_stop = stop_loss_manager.stop_losses[position_id]['stop_price']
        assert updated_stop >= initial_stop

    def test_atr_stop_insufficient_data(self, stop_loss_manager):
        """Test ATR stop with insufficient data falls back to percentage."""
        # A single bar is far fewer rows than ATR_PERIOD requires.
        insufficient_data = pd.DataFrame({
            'high': [51000],
            'low': [49000],
            'close': [50000],
        })

        stop_price = stop_loss_manager.calculate_atr_stop(
            entry_price=self.ENTRY_PRICE,
            is_long=True,
            ohlcv_data=insufficient_data,
            atr_period=self.ATR_PERIOD,
        )

        # Should fall back to percentage-based stop (2% below entry for a long)
        assert stop_price == self.ENTRY_PRICE * Decimal('0.98')
|
|
|