Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
2
tests/unit/trading/__init__.py
Normal file
2
tests/unit/trading/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Unit tests for trading engine."""
|
||||
|
||||
88
tests/unit/trading/test_engine.py
Normal file
88
tests/unit/trading/test_engine.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Tests for trading engine."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from decimal import Decimal
|
||||
from src.trading.engine import get_trading_engine, TradingEngine
|
||||
from src.core.database import OrderSide, OrderType, OrderStatus
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTradingEngine:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_data(self, db_session):
|
||||
"""Setup test data."""
|
||||
from src.core.database import Exchange
|
||||
from sqlalchemy import select
|
||||
|
||||
# Check if exchange exists
|
||||
result = await db_session.execute(select(Exchange).where(Exchange.id == 1))
|
||||
if not result.scalar_one_or_none():
|
||||
try:
|
||||
# Try enabled (new schema)
|
||||
exchange = Exchange(id=1, name="coinbase", enabled=True)
|
||||
db_session.add(exchange)
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
# Fallback if I was wrong about schema, but I checked it.
|
||||
await db_session.rollback()
|
||||
exchange = Exchange(id=1, name="coinbase")
|
||||
db_session.add(exchange)
|
||||
print("Exchange created in setup_data")
|
||||
|
||||
async def test_execute_order_paper_trading(self, mock_exchange_adapter):
|
||||
"""Test executing paper trading order."""
|
||||
engine = TradingEngine()
|
||||
|
||||
# Mock dependencies
|
||||
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
|
||||
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
|
||||
engine.paper_trading.get_balance = Mock(return_value=Decimal("10000"))
|
||||
engine.paper_trading.execute_order = Mock(return_value=True)
|
||||
# Mock logger
|
||||
engine.logger = Mock()
|
||||
|
||||
order = await engine.execute_order(
|
||||
exchange_id=1,
|
||||
strategy_id=None,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=Decimal("0.1"),
|
||||
price=Decimal("50000"),
|
||||
paper_trading=True
|
||||
)
|
||||
|
||||
# Check if error occurred
|
||||
if order is None:
|
||||
engine.logger.error.assert_called()
|
||||
call_args = engine.logger.error.call_args
|
||||
print(f"\nCaught exception in engine: {call_args}")
|
||||
|
||||
assert order is not None
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.quantity == Decimal("0.1")
|
||||
# Status might be PENDING/OPEN depending on implementation
|
||||
assert order.status in [OrderStatus.PENDING, OrderStatus.OPEN]
|
||||
|
||||
async def test_execute_order_live_trading(self, mock_exchange_adapter):
|
||||
"""Test executing live trading order."""
|
||||
engine = TradingEngine()
|
||||
|
||||
# Mock dependencies
|
||||
engine.get_exchange_adapter = AsyncMock(return_value=mock_exchange_adapter)
|
||||
engine.risk_manager.check_order_risk = Mock(return_value=(True, ""))
|
||||
|
||||
order = await engine.execute_order(
|
||||
exchange_id=1,
|
||||
strategy_id=None,
|
||||
symbol="BTC/USD",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.MARKET,
|
||||
quantity=Decimal("0.1"),
|
||||
paper_trading=False
|
||||
)
|
||||
|
||||
assert order is not None
|
||||
assert order.paper_trading is False
|
||||
mock_exchange_adapter.place_order.assert_called_once()
|
||||
161
tests/unit/trading/test_fee_calculator.py
Normal file
161
tests/unit/trading/test_fee_calculator.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for fee calculator functionality."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
from src.trading.fee_calculator import FeeCalculator, get_fee_calculator
|
||||
from src.core.database import OrderType
|
||||
|
||||
|
||||
class TestFeeCalculator:
|
||||
"""Tests for FeeCalculator class."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_get_fee_calculator_singleton(self):
|
||||
"""Test that get_fee_calculator returns singleton."""
|
||||
calc1 = get_fee_calculator()
|
||||
calc2 = get_fee_calculator()
|
||||
assert calc1 is calc2
|
||||
|
||||
def test_calculate_fee_basic(self, calculator):
|
||||
"""Test basic fee calculation."""
|
||||
fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee > 0
|
||||
assert isinstance(fee, Decimal)
|
||||
|
||||
def test_calculate_fee_zero_quantity(self, calculator):
|
||||
"""Test fee calculation with zero quantity."""
|
||||
fee = calculator.calculate_fee(
|
||||
quantity=Decimal('0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee == Decimal('0')
|
||||
|
||||
def test_calculate_fee_maker_vs_taker(self, calculator):
|
||||
"""Test maker fees are typically lower than taker fees."""
|
||||
maker_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.LIMIT,
|
||||
is_maker=True
|
||||
)
|
||||
taker_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET,
|
||||
is_maker=False
|
||||
)
|
||||
# Maker fees should be <= taker fees
|
||||
assert maker_fee <= taker_fee
|
||||
|
||||
|
||||
class TestFeeCalculatorPaperTrading:
|
||||
"""Tests for paper trading fee calculation."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_get_fee_structure_by_exchange_name_coinbase(self, calculator):
|
||||
"""Test getting Coinbase fee structure."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
|
||||
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006}
|
||||
}.get(key, default)
|
||||
|
||||
fees = calculator.get_fee_structure_by_exchange_name('coinbase')
|
||||
assert fees['maker'] == 0.004
|
||||
assert fees['taker'] == 0.006
|
||||
|
||||
def test_get_fee_structure_by_exchange_name_unknown(self, calculator):
|
||||
"""Test getting fee structure for unknown exchange returns defaults."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
|
||||
'trading.exchanges.unknown.fees': None
|
||||
}.get(key, default)
|
||||
|
||||
fees = calculator.get_fee_structure_by_exchange_name('unknown')
|
||||
assert fees['maker'] == 0.001
|
||||
assert fees['taker'] == 0.001
|
||||
|
||||
def test_calculate_fee_for_paper_trading(self, calculator):
|
||||
"""Test paper trading fee calculation."""
|
||||
with patch.object(calculator, 'config') as mock_config:
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
'paper_trading.fee_exchange': 'coinbase',
|
||||
'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006},
|
||||
'trading.default_fees': {'maker': 0.001, 'taker': 0.001}
|
||||
}.get(key, default)
|
||||
|
||||
fee = calculator.calculate_fee_for_paper_trading(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET,
|
||||
is_maker=False
|
||||
)
|
||||
# Taker fee at 0.6% of $1000 = $6
|
||||
expected = Decimal('1000.0') * Decimal('0.006')
|
||||
assert fee == expected
|
||||
|
||||
def test_calculate_fee_for_paper_trading_zero(self, calculator):
|
||||
"""Test paper trading fee with zero values."""
|
||||
fee = calculator.calculate_fee_for_paper_trading(
|
||||
quantity=Decimal('0'),
|
||||
price=Decimal('100.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
assert fee == Decimal('0')
|
||||
|
||||
|
||||
class TestRoundTripFees:
|
||||
"""Tests for round-trip fee calculations."""
|
||||
|
||||
@pytest.fixture
|
||||
def calculator(self):
|
||||
"""Create fee calculator instance."""
|
||||
return FeeCalculator()
|
||||
|
||||
def test_estimate_round_trip_fee(self, calculator):
|
||||
"""Test round-trip fee estimation."""
|
||||
fee = calculator.estimate_round_trip_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0')
|
||||
)
|
||||
# Round-trip should include buy and sell fees
|
||||
assert fee > 0
|
||||
|
||||
single_fee = calculator.calculate_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
order_type=OrderType.MARKET
|
||||
)
|
||||
# Round-trip should be approximately 2x single fee
|
||||
assert fee >= single_fee
|
||||
|
||||
def test_get_minimum_profit_threshold(self, calculator):
|
||||
"""Test minimum profit threshold calculation."""
|
||||
threshold = calculator.get_minimum_profit_threshold(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0'),
|
||||
multiplier=2.0
|
||||
)
|
||||
assert threshold > 0
|
||||
|
||||
round_trip_fee = calculator.estimate_round_trip_fee(
|
||||
quantity=Decimal('1.0'),
|
||||
price=Decimal('1000.0')
|
||||
)
|
||||
# Threshold should be 2x round-trip fee
|
||||
assert threshold == round_trip_fee * Decimal('2.0')
|
||||
70
tests/unit/trading/test_order_manager.py
Normal file
70
tests/unit/trading/test_order_manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tests for order manager."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from src.trading.order_manager import get_order_manager, OrderManager
|
||||
|
||||
|
||||
class TestOrderManager:
|
||||
"""Tests for OrderManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def order_manager(self, mock_database):
|
||||
"""Create order manager instance."""
|
||||
return get_order_manager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_order(self, order_manager, mock_database):
|
||||
"""Test order creation."""
|
||||
engine, Session = mock_database
|
||||
|
||||
order = await order_manager.create_order(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
order_type="market",
|
||||
amount=0.01,
|
||||
price=50000.0,
|
||||
is_paper_trade=True
|
||||
)
|
||||
|
||||
assert order is not None
|
||||
assert order.symbol == "BTC/USD"
|
||||
assert order.side == "buy"
|
||||
assert order.status == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_order_status(self, order_manager, mock_database):
|
||||
"""Test order status update."""
|
||||
engine, Session = mock_database
|
||||
|
||||
# Create order first
|
||||
order = await order_manager.create_order(
|
||||
exchange_id=1,
|
||||
strategy_id=1,
|
||||
symbol="BTC/USD",
|
||||
side="buy",
|
||||
order_type="market",
|
||||
amount=0.01,
|
||||
is_paper_trade=True
|
||||
)
|
||||
|
||||
# Update status
|
||||
updated = await order_manager.update_order_status(
|
||||
client_order_id=order.client_order_id,
|
||||
new_status="filled",
|
||||
filled_amount=0.01,
|
||||
cost=500.0
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.status == "filled"
|
||||
|
||||
def test_get_order(self, order_manager, mock_database):
|
||||
"""Test getting order."""
|
||||
# This would require creating an order first
|
||||
# Simplified test
|
||||
order = order_manager.get_order(client_order_id="nonexistent")
|
||||
assert order is None
|
||||
|
||||
32
tests/unit/trading/test_paper_trading.py
Normal file
32
tests/unit/trading/test_paper_trading.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Tests for paper trading simulator."""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock
|
||||
from src.trading.paper_trading import get_paper_trading, PaperTradingSimulator
|
||||
from src.core.database import Order, OrderSide, OrderType
|
||||
|
||||
|
||||
class TestPaperTradingSimulator:
|
||||
"""Tests for PaperTradingSimulator."""
|
||||
|
||||
@pytest.fixture
|
||||
def simulator(self):
|
||||
"""Create paper trading simulator."""
|
||||
return PaperTradingSimulator(initial_capital=Decimal('1000.0'))
|
||||
|
||||
def test_initialization(self, simulator):
|
||||
"""Test simulator initialization."""
|
||||
assert simulator.initial_capital == Decimal('1000.0')
|
||||
assert simulator.cash == Decimal('1000.0')
|
||||
|
||||
def test_get_balance(self, simulator):
|
||||
"""Test getting balance."""
|
||||
balance = simulator.get_balance()
|
||||
assert balance == Decimal('1000.0')
|
||||
|
||||
def test_get_positions(self, simulator):
|
||||
"""Test getting positions."""
|
||||
positions = simulator.get_positions()
|
||||
assert isinstance(positions, list)
|
||||
|
||||
Reference in New Issue
Block a user