162 lines
5.9 KiB
Python
162 lines
5.9 KiB
Python
|
|
"""Tests for fee calculator functionality."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from decimal import Decimal
|
||
|
|
from unittest.mock import Mock, patch
|
||
|
|
from src.trading.fee_calculator import FeeCalculator, get_fee_calculator
|
||
|
|
from src.core.database import OrderType
|
||
|
|
|
||
|
|
|
||
|
|
class TestFeeCalculator:
    """Unit tests for the core ``FeeCalculator`` behaviour."""

    @pytest.fixture
    def calculator(self):
        """Provide a freshly constructed ``FeeCalculator`` for each test."""
        instance = FeeCalculator()
        return instance

    def test_get_fee_calculator_singleton(self):
        """Two calls to ``get_fee_calculator`` must yield the very same object."""
        first = get_fee_calculator()
        second = get_fee_calculator()
        assert first is second

    def test_calculate_fee_basic(self, calculator):
        """A 1-unit market order at a price of 100 yields a positive ``Decimal`` fee.

        NOTE(review): no config is patched here, so this relies on the
        calculator's default configuration producing a non-zero fee.
        """
        result = calculator.calculate_fee(
            quantity=Decimal('1.0'),
            price=Decimal('100.0'),
            order_type=OrderType.MARKET,
        )
        assert result > 0
        assert isinstance(result, Decimal)

    def test_calculate_fee_zero_quantity(self, calculator):
        """An order for zero units incurs exactly zero fee."""
        result = calculator.calculate_fee(
            quantity=Decimal('0'),
            price=Decimal('100.0'),
            order_type=OrderType.MARKET,
        )
        assert result == Decimal('0')

    def test_calculate_fee_maker_vs_taker(self, calculator):
        """A maker (limit) fill must not cost more than a taker (market) fill."""
        # Identical quantity/price for both calls; only the order type and
        # the maker flag differ between them.
        sizing = {'quantity': Decimal('1.0'), 'price': Decimal('1000.0')}
        as_maker = calculator.calculate_fee(
            **sizing, order_type=OrderType.LIMIT, is_maker=True
        )
        as_taker = calculator.calculate_fee(
            **sizing, order_type=OrderType.MARKET, is_maker=False
        )
        assert as_maker <= as_taker
|
||
|
|
|
||
|
|
|
||
|
|
class TestFeeCalculatorPaperTrading:
    """Paper-trading fee lookups, driven by a patched ``config`` attribute."""

    @pytest.fixture
    def calculator(self):
        """Provide a freshly constructed ``FeeCalculator`` for each test."""
        return FeeCalculator()

    def test_get_fee_structure_by_exchange_name_coinbase(self, calculator):
        """The Coinbase entry in config is returned instead of the defaults."""

        def fake_get(key, default=None):
            # Table is rebuilt on every call so each lookup hands back fresh dicts.
            table = {
                'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
                'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006},
            }
            return table.get(key, default)

        with patch.object(calculator, 'config') as mock_config:
            mock_config.get.side_effect = fake_get

            structure = calculator.get_fee_structure_by_exchange_name('coinbase')
            assert structure['maker'] == 0.004
            assert structure['taker'] == 0.006

    def test_get_fee_structure_by_exchange_name_unknown(self, calculator):
        """An exchange whose config entry is ``None`` falls back to the default fees."""

        def fake_get(key, default=None):
            # The exchange key is present but maps to None, mimicking missing config.
            table = {
                'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
                'trading.exchanges.unknown.fees': None,
            }
            return table.get(key, default)

        with patch.object(calculator, 'config') as mock_config:
            mock_config.get.side_effect = fake_get

            structure = calculator.get_fee_structure_by_exchange_name('unknown')
            assert structure['maker'] == 0.001
            assert structure['taker'] == 0.001

    def test_calculate_fee_for_paper_trading(self, calculator):
        """Paper trading charges the configured exchange's taker rate on a market order."""

        def fake_get(key, default=None):
            table = {
                'paper_trading.fee_exchange': 'coinbase',
                'trading.exchanges.coinbase.fees': {'maker': 0.004, 'taker': 0.006},
                'trading.default_fees': {'maker': 0.001, 'taker': 0.001},
            }
            return table.get(key, default)

        with patch.object(calculator, 'config') as mock_config:
            mock_config.get.side_effect = fake_get

            charged = calculator.calculate_fee_for_paper_trading(
                quantity=Decimal('1.0'),
                price=Decimal('1000.0'),
                order_type=OrderType.MARKET,
                is_maker=False,
            )
            # 0.6% taker rate applied to a $1000 notional -> $6.
            assert charged == Decimal('1000.0') * Decimal('0.006')

    def test_calculate_fee_for_paper_trading_zero(self, calculator):
        """A zero-quantity paper trade costs nothing."""
        charged = calculator.calculate_fee_for_paper_trading(
            quantity=Decimal('0'),
            price=Decimal('100.0'),
            order_type=OrderType.MARKET,
        )
        assert charged == Decimal('0')
|
||
|
|
|
||
|
|
|
||
|
|
class TestRoundTripFees:
    """Tests for round-trip fee estimates and the derived minimum profit threshold."""

    @pytest.fixture
    def calculator(self):
        """Create fee calculator instance."""
        return FeeCalculator()

    def test_estimate_round_trip_fee(self, calculator):
        """A round-trip estimate is positive and no smaller than one market-order fee."""
        quantity = Decimal('1.0')
        price = Decimal('1000.0')

        fee = calculator.estimate_round_trip_fee(quantity=quantity, price=price)
        # Round-trip should include buy and sell fees.
        assert fee > 0

        single_fee = calculator.calculate_fee(
            quantity=quantity,
            price=price,
            order_type=OrderType.MARKET,
        )
        # A round trip covers both a buy and a sell, so it must cost at least
        # as much as one market-order fee. The exact ratio depends on which
        # maker/taker rates estimate_round_trip_fee assumes for each leg, so
        # only this lower bound is asserted.
        assert fee >= single_fee

    def test_get_minimum_profit_threshold(self, calculator):
        """The profit threshold equals the round-trip fee scaled by ``multiplier``."""
        quantity = Decimal('1.0')
        price = Decimal('1000.0')
        multiplier = 2.0

        threshold = calculator.get_minimum_profit_threshold(
            quantity=quantity,
            price=price,
            multiplier=multiplier,
        )
        assert threshold > 0

        round_trip_fee = calculator.estimate_round_trip_fee(
            quantity=quantity,
            price=price,
        )
        # Derive the expectation from the multiplier actually passed so the two
        # cannot drift apart; str() gives the exact Decimal('2.0'), not the
        # binary-float expansion Decimal(float) would produce.
        assert threshold == round_trip_fee * Decimal(str(multiplier))
|