87 lines
3.5 KiB
Python
87 lines
3.5 KiB
Python
"""Pytest configuration and fixtures."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import Mock, AsyncMock, PropertyMock
|
|
from decimal import Decimal
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from src.core.database import Base, Database, get_database
|
|
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide one asyncio event loop shared by the whole test session.

    Because the fixture is session-scoped, the same loop is reused by every
    test (it is NOT recreated per test case) and is closed once, at session
    teardown. Presumably this exists so the session-scoped async fixtures in
    this file (e.g. ``db_engine``) run on the same loop as the tests using
    them — confirm against the pytest-asyncio version in use, since newer
    releases deprecate overriding ``event_loop``.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), without going
    # through the event-loop policy API (deprecated as of Python 3.14).
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
|
|
|
|
@pytest.fixture(scope="session")
async def db_engine():
    """Yield a session-wide async SQLAlchemy engine on in-memory SQLite.

    All tables registered on ``Base.metadata`` are created once, before the
    engine is handed out. ``StaticPool`` makes the engine reuse a single
    connection, which keeps the ``:memory:`` database (and the schema created
    on it) alive for every session bound to this engine. The engine is
    disposed after the last test that uses it.
    """
    database_url = "sqlite+aiosqlite:///:memory:"
    engine = create_async_engine(
        database_url,
        echo=False,
        poolclass=StaticPool,
        # sqlite3 otherwise refuses to use a connection from another thread.
        connect_args={"check_same_thread": False},
    )

    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """Yield a fresh async ORM session bound to the shared test engine.

    ``expire_on_commit=False`` keeps loaded attribute values available on ORM
    objects after a commit, so tests can inspect them without re-querying.
    The session is closed when the ``async with`` block exits at teardown.

    NOTE(review): nothing is rolled back or cleaned up here, so rows committed
    by one test appear to stay visible to later tests that share the
    session-scoped ``db_engine``.
    """
    session_factory = async_sessionmaker(bind=db_engine, expire_on_commit=False)
    async with session_factory() as session:
        yield session
|
|
|
|
@pytest.fixture(autouse=True)
def override_get_database(db_engine, monkeypatch):
    """Swap the module-level database singleton for one using the test engine.

    Runs automatically for every test. A ``Database`` is constructed, and its
    ``engine`` and ``SessionLocal`` attributes are overwritten with the
    in-memory test engine and a matching ``async_sessionmaker``. That object
    is then installed as ``src.core.database._db_instance`` via
    ``monkeypatch``, which restores the previous value after the test.

    Presumably ``get_database()`` returns ``_db_instance`` — confirm in
    ``src.core.database``. NOTE(review): ``Database()`` is still constructed
    normally; if its ``__init__`` builds its own engine, that engine is simply
    replaced (not disposed) here.

    Returns:
        The ``Database`` instance that was patched in.
    """
    patched_db = Database()

    # Re-point the instance at the shared in-memory engine.
    patched_db.engine = db_engine
    patched_db.SessionLocal = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    monkeypatch.setattr("src.core.database._db_instance", patched_db)
    return patched_db
|
|
|
|
@pytest.fixture
def mock_exchange_adapter():
    """Return an ``AsyncMock`` that stands in for an exchange adapter.

    Canned results of the async methods (await them to get the value):
      * ``get_ticker(...)``  -> ``{'last': Decimal("50000")}``
      * ``place_order(...)`` -> ``{'id': 'test_order_123', 'status': 'open'}``
      * ``get_balance(...)`` -> ``{'USD': Decimal("10000")}``

    Other configuration:
      * ``extract_fee_from_order_response`` is a plain synchronous ``Mock``
        returning ``Decimal("1.0")``; every other auto-created attribute of an
        ``AsyncMock`` would be awaitable.
      * ``name`` is ``"coinbase"``.
      * ``get_fee_structure`` is deleted, so accessing it raises
        ``AttributeError`` and ``hasattr(adapter, "get_fee_structure")`` is
        False. Presumably this drives code under test onto its fallback path
        when an adapter lacks that method — confirm against the callers.

    No ``spec`` is used, so any other attribute access silently succeeds.
    """
    adapter = AsyncMock()

    # Async endpoints: awaiting these yields the canned payloads.
    adapter.get_ticker.return_value = {'last': Decimal("50000")}
    adapter.place_order.return_value = {'id': 'test_order_123', 'status': 'open'}
    adapter.get_balance.return_value = {'USD': Decimal("10000")}

    # Helper methods should be sync mocks, not AsyncMock children.
    adapter.extract_fee_from_order_response = Mock(return_value=Decimal("1.0"))

    # FeeCalculator accesses .name. Assigning after construction sets a real
    # attribute; `name=` in the constructor would only name the mock.
    adapter.name = "coinbase"

    # Deleting a mock attribute makes later access raise AttributeError
    # instead of auto-creating a child mock.
    del adapter.get_fee_structure

    return adapter
|