90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
|
|
"""Tests for base strategy class."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import pandas as pd
|
||
|
|
from src.strategies.base import BaseStrategy, StrategyRegistry
|
||
|
|
|
||
|
|
|
||
|
|
class ConcreteStrategy(BaseStrategy):
    """Minimal concrete BaseStrategy used as a test double.

    Keeps a rolling window of at most the 100 most recent rows, always emits a
    "hold" signal, and sizes positions as a flat fraction of capital.
    """

    async def on_data(self, new_data: pd.DataFrame):
        """Append *new_data* to ``current_data``, keeping only the last 100 rows."""
        combined = pd.concat([self.current_data, new_data])
        self.current_data = combined.tail(100)

    async def generate_signal(self):
        """Return a "hold" signal priced at the latest close (None when there is no data)."""
        latest_price = None
        if len(self.current_data):
            latest_price = self.current_data['close'].iloc[-1]
        return {"signal": "hold", "price": latest_price}

    async def calculate_position_size(self, capital: float, risk_percentage: float) -> float:
        """Return ``capital`` scaled by the fraction ``risk_percentage``."""
        return capital * risk_percentage
|
||
|
|
|
||
|
|
|
||
|
|
class TestBaseStrategy:
    """Tests for BaseStrategy, exercised through ConcreteStrategy."""

    @pytest.fixture
    def strategy(self):
        """Create a fresh ConcreteStrategy instance for each test."""
        return ConcreteStrategy(
            strategy_id=1,
            name="test_strategy",
            symbol="BTC/USD",
            timeframe="1h",
            parameters={},
        )

    def test_strategy_initialization(self, strategy):
        """Constructor arguments are exposed as attributes and the strategy starts inactive."""
        assert strategy.strategy_id == 1
        assert strategy.name == "test_strategy"
        assert strategy.symbol == "BTC/USD"
        assert strategy.timeframe == "1h"
        assert not strategy.is_active

    @pytest.mark.asyncio
    async def test_strategy_start_stop(self, strategy):
        """start() activates the strategy and stop() deactivates it again."""
        await strategy.start()
        assert strategy.is_active

        await strategy.stop()
        assert not strategy.is_active

    @pytest.mark.asyncio
    async def test_generate_signal(self, strategy):
        """generate_signal() returns a dict carrying a recognised signal value."""
        signal = await strategy.generate_signal()
        assert "signal" in signal
        assert signal["signal"] in {"buy", "sell", "hold"}

    @pytest.mark.asyncio
    async def test_generate_signal_uses_latest_close(self, strategy):
        """After on_data(), the signal is priced at the most recent 'close' value."""
        await strategy.on_data(pd.DataFrame({"close": [100.0, 101.5]}))
        signal = await strategy.generate_signal()
        assert signal["signal"] == "hold"
        assert signal["price"] == pytest.approx(101.5)

    @pytest.mark.asyncio
    async def test_calculate_position_size(self, strategy):
        """Position size is the capital scaled by the risk fraction."""
        size = await strategy.calculate_position_size(1000.0, 0.01)
        # Compare with a tolerance: a product of binary floats need not be exact.
        assert size == pytest.approx(10.0)
|
||
|
|
|
||
|
|
|
||
|
|
class TestStrategyRegistry:
    """Tests for StrategyRegistry.

    NOTE(review): these tests register strategies via class-level calls on
    StrategyRegistry, so registrations presumably persist across tests.
    Consider snapshotting/restoring the registry in a fixture to keep the
    tests order-independent — confirm against StrategyRegistry's internals.
    """

    def test_register_strategy(self):
        """A registered strategy name appears in list_available()."""
        StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
        assert "test_strategy" in StrategyRegistry.list_available()

    def test_get_strategy_class(self):
        """get_strategy_class() returns the exact class that was registered."""
        StrategyRegistry.register_strategy("test_strategy", ConcreteStrategy)
        strategy_class = StrategyRegistry.get_strategy_class("test_strategy")
        # Classes are compared by identity.
        assert strategy_class is ConcreteStrategy

    def test_get_nonexistent_strategy(self):
        """Looking up an unregistered name raises ValueError mentioning 'not registered'."""
        with pytest.raises(ValueError, match="not registered"):
            StrategyRegistry.get_strategy_class("nonexistent")
|
||
|
|
|