Files

87 lines
3.5 KiB
Python
Raw Permalink Normal View History

"""Pytest configuration and fixtures."""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, PropertyMock
from decimal import Decimal
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from src.core.database import Base, Database, get_database
from sqlalchemy.pool import StaticPool
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single asyncio event loop shared by the whole test session.

    Because this fixture is session-scoped, session-scoped async fixtures
    (such as ``db_engine``) and every test that uses them run on the same
    loop. The loop is closed once, after the last test of the session.

    NOTE(review): overriding the ``event_loop`` fixture is deprecated in
    pytest-asyncio 0.21+ and removed in 1.0 (loop scope is configured via
    ``loop_scope`` there) — confirm against the pinned pytest-asyncio version.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), without touching
    # the policy API that is deprecated as of Python 3.14.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest.fixture(scope="session")
async def db_engine():
    """Yield an async SQLAlchemy engine over an in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created once for the
    session. ``StaticPool`` keeps a single underlying connection so the
    in-memory database survives for the whole session (a fresh connection
    would get a fresh, empty ``:memory:`` database).

    The engine is disposed on teardown and also if schema creation fails,
    so a failed setup does not leak the engine.

    NOTE(review): a plain ``@pytest.fixture`` on an ``async def`` only works
    when pytest-asyncio runs in ``asyncio_mode = auto``; in strict mode this
    needs ``@pytest_asyncio.fixture`` — confirm the project config.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
        echo=False,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield a new ``AsyncSession`` bound to the shared test engine.

    ``expire_on_commit=False`` keeps loaded ORM objects readable after a
    commit without triggering a refresh. The session is closed when the
    ``async with`` block exits at test teardown.

    NOTE(review): ``db_engine`` is session-scoped and uses ``StaticPool``, so
    anything a test *commits* is visible to later tests — confirm tests do
    not rely on a clean database.
    """
    session_factory = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,  # async_sessionmaker's default, stated explicitly
        expire_on_commit=False,
    )
    async with session_factory() as session:
        yield session
@pytest.fixture(autouse=True)
def override_get_database(db_engine, monkeypatch):
    """Route the application's database singleton to the in-memory test engine.

    Runs automatically for every test. A ``Database`` instance is built and
    its ``engine`` and ``SessionLocal`` attributes are overwritten with ones
    bound to ``db_engine``. Then ``src.core.database._db_instance`` is
    monkeypatched to point at it; monkeypatch restores the original after the
    test.

    Returns:
        The patched ``Database`` instance.

    NOTE(review): presumably ``get_database()`` returns ``_db_instance`` —
    verify in ``src/core/database.py``. Also, ``Database()`` runs its normal
    constructor before its attributes are replaced; confirm that does not
    open a real connection.
    """
    patched_db = Database()
    # Replace the instance's engine and session factory with test-bound ones.
    patched_db.engine = db_engine
    patched_db.SessionLocal = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    monkeypatch.setattr("src.core.database._db_instance", patched_db)
    return patched_db
@pytest.fixture
def mock_exchange_adapter():
    """Return an ``AsyncMock`` standing in for an exchange adapter.

    Any attribute not configured below is auto-created by the mock; methods
    created that way are async mocks. Configured behaviour:

    * ``await get_ticker(...)`` -> ``{'last': Decimal("50000")}``
    * ``await place_order(...)`` -> ``{'id': 'test_order_123', 'status': 'open'}``
    * ``await get_balance(...)`` -> ``{'USD': Decimal("10000")}``
    * ``extract_fee_from_order_response(...)`` is a *synchronous* ``Mock``
      returning ``Decimal("1.0")``.
    * ``name`` is ``"coinbase"`` (FeeCalculator reads ``.name``).
    * ``get_fee_structure`` does not exist: accessing it raises
      ``AttributeError`` and ``hasattr(adapter, "get_fee_structure")`` is False.

    NOTE(review): no ``spec=`` is used, so misspelled method names silently
    succeed. Consider ``spec=BaseExchangeAdapter`` once it is confirmed to
    declare every method used here (including the sync fee helper).
    """
    adapter = AsyncMock()

    # Async API methods: awaiting them yields these canned responses.
    adapter.get_ticker.return_value = {'last': Decimal("50000")}
    adapter.place_order.return_value = {'id': 'test_order_123', 'status': 'open'}
    adapter.get_balance.return_value = {'USD': Decimal("10000")}

    # Helper that callers invoke without awaiting, so it must be a sync Mock.
    adapter.extract_fee_from_order_response = Mock(return_value=Decimal("1.0"))
    adapter.name = "coinbase"  # FeeCalculator accesses .name

    # Deleting an attribute on a Mock (even one never accessed) marks it as
    # deleted, so later access raises AttributeError instead of auto-creating
    # a child mock.
    del adapter.get_fee_structure

    return adapter