294 lines
10 KiB
Python
294 lines
10 KiB
Python
"""Tests for trading API endpoints."""
|
|
|
|
import pytest
|
|
from decimal import Decimal
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from fastapi.testclient import TestClient
|
|
from datetime import datetime
|
|
|
|
from backend.main import app
|
|
from backend.core.schemas import OrderCreate, OrderSide, OrderType
|
|
from src.core.database import Order, OrderStatus
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` bound to the application under test."""
    test_client = TestClient(app=app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def mock_trading_engine():
    """Provide a stand-in trading engine whose ``order_manager`` is its own Mock."""
    # Passing the attribute as a keyword configures it on the new Mock, which is
    # equivalent to assigning it after construction.
    return Mock(order_manager=Mock())
|
|
|
|
|
|
@pytest.fixture
def mock_order():
    """Provide a fully populated, filled paper-trading BTC/USD market-buy order."""
    # Keyword arguments to Mock() are applied as plain attributes, so this builds
    # the same object as assigning each field one by one. The three timestamps
    # are still taken by separate datetime.now() calls, in the same order.
    attributes = {
        "id": 1,
        "exchange_id": 1,
        "strategy_id": None,
        "symbol": "BTC/USD",
        "order_type": OrderType.MARKET,
        "side": OrderSide.BUY,
        "status": OrderStatus.FILLED,
        "quantity": Decimal("0.1"),
        "price": Decimal("50000"),
        "filled_quantity": Decimal("0.1"),
        "average_fill_price": Decimal("50000"),
        "fee": Decimal("5"),
        "paper_trading": True,
        "created_at": datetime.now(),
        "updated_at": datetime.now(),
        "filled_at": datetime.now(),
    }
    return Mock(**attributes)
|
|
|
|
|
|
class TestCreateOrder:
    """Tests for POST /api/trading/orders."""

    @staticmethod
    def _valid_order_payload():
        """Return a fresh, valid paper-trading market-buy request body."""
        return {
            "exchange_id": 1,
            "symbol": "BTC/USD",
            "side": "buy",
            "order_type": "market",
            "quantity": "0.1",
            "paper_trading": True,
        }

    @patch('backend.api.trading.get_trading_engine')
    def test_create_order_success(self, mock_get_engine, client, mock_trading_engine, mock_order):
        """Test successful order creation."""
        mock_get_engine.return_value = mock_trading_engine
        mock_trading_engine.execute_order.return_value = mock_order

        response = client.post("/api/trading/orders", json=self._valid_order_payload())

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1
        assert data["symbol"] == "BTC/USD"
        assert data["side"] == "buy"
        mock_trading_engine.execute_order.assert_called_once()

    @patch('backend.api.trading.get_trading_engine')
    def test_create_order_execution_failed(self, mock_get_engine, client, mock_trading_engine):
        """Test order creation when execution fails."""
        mock_get_engine.return_value = mock_trading_engine
        # A None result from the engine is the failure signal this test exercises.
        mock_trading_engine.execute_order.return_value = None

        response = client.post("/api/trading/orders", json=self._valid_order_payload())

        assert response.status_code == 400
        assert "Order execution failed" in response.json()["detail"]

    @patch('backend.api.trading.get_trading_engine')
    def test_create_order_invalid_data(self, mock_get_engine, client):
        """Test order creation with invalid data."""
        # @patch passes its mock as the first positional argument after self.
        # Without the mock_get_engine parameter, that mock was bound to `client`
        # and the real TestClient fixture was never injected.
        order_data = {
            "exchange_id": "invalid",
            "symbol": "BTC/USD",
            "side": "buy",
            "order_type": "market",
            "quantity": "0.1",
        }

        response = client.post("/api/trading/orders", json=order_data)

        assert response.status_code == 422  # Validation error
        # A payload rejected by validation must never reach order execution.
        mock_get_engine.return_value.execute_order.assert_not_called()
|
|
|
|
|
|
class TestGetOrders:
    """Tests for GET /api/trading/orders."""

    @patch('backend.api.trading.get_db')
    def test_get_orders_success(self, mock_get_db, client, mock_database):
        """Test getting orders."""
        # NOTE(review): `mock_database` is not defined in this module and is unused
        # here; presumably it comes from conftest.py. Confirm before removing.
        mock_db = Mock()
        mock_session = Mock()
        mock_db.get_session.return_value = mock_session
        mock_get_db.return_value = mock_db

        # Create mock orders
        mock_order1 = Mock(spec=Order)
        mock_order1.id = 1
        mock_order1.symbol = "BTC/USD"
        mock_order1.paper_trading = True

        mock_order2 = Mock(spec=Order)
        mock_order2.id = 2
        mock_order2.symbol = "ETH/USD"
        mock_order2.paper_trading = True

        # Mirrors the query chain the endpoint is expected to build:
        # query(...).filter_by(...).order_by(...).limit(...).all()
        mock_session.query.return_value.filter_by.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_order1, mock_order2]

        response = client.get("/api/trading/orders?paper_trading=true&limit=10")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 2
        # Check the content, not only the count, so an unrelated 2-item list fails.
        assert [item["id"] for item in data] == [1, 2]
        assert [item["symbol"] for item in data] == ["BTC/USD", "ETH/USD"]
|
|
|
|
|
|
class TestGetOrder:
    """Tests for GET /api/trading/orders/{order_id}."""

    @staticmethod
    def _wire_session(patched_get_db):
        """Make the patched ``get_db`` hand out a db whose session is returned."""
        fake_db = Mock()
        fake_session = Mock()
        fake_db.get_session.return_value = fake_session
        patched_get_db.return_value = fake_db
        return fake_session

    @patch('backend.api.trading.get_db')
    def test_get_order_success(self, mock_get_db, client, mock_database):
        """An existing order id is returned with its fields."""
        session = self._wire_session(mock_get_db)

        found = Mock(spec=Order)
        found.id = 1
        found.symbol = "BTC/USD"
        lookup = session.query.return_value.filter_by.return_value
        lookup.first.return_value = found

        response = client.get("/api/trading/orders/1")

        assert response.status_code == 200
        payload = response.json()
        assert payload["id"] == 1

    @patch('backend.api.trading.get_db')
    def test_get_order_not_found(self, mock_get_db, client):
        """An unknown order id yields a 404 with a descriptive message."""
        session = self._wire_session(mock_get_db)
        lookup = session.query.return_value.filter_by.return_value
        lookup.first.return_value = None

        response = client.get("/api/trading/orders/999")

        assert response.status_code == 404
        detail = response.json()["detail"]
        assert "Order not found" in detail
|
|
|
|
|
|
class TestCancelOrder:
    """Tests for POST /api/trading/orders/{order_id}/cancel."""

    @patch('backend.api.trading.get_trading_engine')
    @patch('backend.api.trading.get_db')
    def test_cancel_order_success(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
        """Test successful order cancellation."""
        # Stacked @patch decorators inject their mocks bottom-up:
        # get_db first, then get_trading_engine.
        mock_get_engine.return_value = mock_trading_engine
        mock_trading_engine.order_manager.cancel_order.return_value = True

        mock_db = Mock()
        mock_session = Mock()
        mock_db.get_session.return_value = mock_session
        mock_get_db.return_value = mock_db

        mock_order = Mock(spec=Order)
        mock_order.id = 1
        mock_order.status = OrderStatus.OPEN
        mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order

        response = client.post("/api/trading/orders/1/cancel")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "cancelled"
        assert data["order_id"] == 1
        # A success response must come from actually asking the order manager to
        # cancel. This matches the execute_order check in the create-order tests.
        mock_trading_engine.order_manager.cancel_order.assert_called_once()

    @patch('backend.api.trading.get_db')
    def test_cancel_order_not_found(self, mock_get_db, client):
        """Test cancelling non-existent order."""
        mock_db = Mock()
        mock_session = Mock()
        mock_db.get_session.return_value = mock_session
        mock_get_db.return_value = mock_db

        mock_session.query.return_value.filter_by.return_value.first.return_value = None

        response = client.post("/api/trading/orders/999/cancel")

        assert response.status_code == 404
        assert "Order not found" in response.json()["detail"]

    @patch('backend.api.trading.get_trading_engine')
    @patch('backend.api.trading.get_db')
    def test_cancel_order_already_filled(self, mock_get_db, mock_get_engine, client, mock_trading_engine):
        """Test cancelling already filled order."""
        mock_get_engine.return_value = mock_trading_engine

        mock_db = Mock()
        mock_session = Mock()
        mock_db.get_session.return_value = mock_session
        mock_get_db.return_value = mock_db

        mock_order = Mock(spec=Order)
        mock_order.id = 1
        mock_order.status = OrderStatus.FILLED
        mock_session.query.return_value.filter_by.return_value.first.return_value = mock_order

        response = client.post("/api/trading/orders/1/cancel")

        assert response.status_code == 400
        assert "cannot be cancelled" in response.json()["detail"]
        # The status guard must reject the request before any cancel is attempted.
        mock_trading_engine.order_manager.cancel_order.assert_not_called()
|
|
|
|
|
|
class TestGetPositions:
    """Tests for GET /api/trading/positions."""

    @patch('backend.api.trading.get_paper_trading')
    def test_get_positions_paper_trading(self, mock_get_paper, client):
        """Test getting positions for paper trading."""
        mock_paper = Mock()
        mock_position = Mock()
        mock_position.symbol = "BTC/USD"
        mock_position.quantity = Decimal("0.1")
        mock_position.entry_price = Decimal("50000")
        mock_position.current_price = Decimal("51000")
        mock_position.unrealized_pnl = Decimal("100")
        mock_position.realized_pnl = Decimal("0")
        mock_paper.get_positions.return_value = [mock_position]
        mock_get_paper.return_value = mock_paper

        response = client.get("/api/trading/positions?paper_trading=true")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # The stubbed paper-trading service returns exactly one position, so the
        # response must contain it. The previous `if len(data) > 0:` guard let
        # this test pass vacuously on an empty list.
        assert len(data) == 1
        assert data[0]["symbol"] == "BTC/USD"
|
|
|
|
|
|
class TestGetBalance:
    """Tests for GET /api/trading/balance."""

    @patch('backend.api.trading.get_paper_trading')
    def test_get_balance_paper_trading(self, mock_get_paper, client):
        """The paper-trading balance and performance summary are both reported."""
        # Dotted keyword names configure return values on child mocks; this is
        # equivalent to setting get_balance/get_performance.return_value afterwards.
        paper = Mock(**{
            "get_balance.return_value": Decimal("1000"),
            "get_performance.return_value": {"total_return": 0.1},
        })
        mock_get_paper.return_value = paper

        response = client.get("/api/trading/balance?paper_trading=true")

        assert response.status_code == 200
        body = response.json()
        for required_key in ("balance", "performance"):
            assert required_key in body
        assert body["balance"] == 1000.0
|
|
|