Files

207 lines
7.3 KiB
Python
Raw Permalink Normal View History

"""Trading API endpoints."""
from decimal import Decimal
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.dependencies import get_trading_engine, get_db_session
from ..core.schemas import OrderCreate, OrderResponse, PositionResponse
from src.core.database import Order, OrderSide, OrderType, OrderStatus
from src.core.repositories import OrderRepository, PositionRepository
from src.core.logger import get_logger
from src.trading.paper_trading import get_paper_trading
# Router collecting every trading endpoint declared in this module.
router = APIRouter()
# Module-level logger; cancel_all_orders uses it to record per-order failures.
logger = get_logger(__name__)
@router.post("/orders", response_model=OrderResponse)
async def create_order(
    order_data: OrderCreate,
    trading_engine=Depends(get_trading_engine)
):
    """Create and execute a trading order.

    Maps the request's side/type onto the database-layer ``OrderSide`` /
    ``OrderType`` enums by value and delegates execution to the trading engine.

    Args:
        order_data: Validated order request body.
        trading_engine: Engine injected via ``get_trading_engine``.

    Returns:
        The executed order serialized as ``OrderResponse``.

    Raises:
        HTTPException: 400 if the engine returns no order; 500 for any other
            unexpected error.
    """
    try:
        # The request carries its own enum type; convert by value to the DB enums.
        side = OrderSide(order_data.side.value)
        order_type = OrderType(order_data.order_type.value)
        order = await trading_engine.execute_order(
            exchange_id=order_data.exchange_id,
            strategy_id=order_data.strategy_id,
            symbol=order_data.symbol,
            side=side,
            order_type=order_type,
            quantity=order_data.quantity,
            price=order_data.price,
            paper_trading=order_data.paper_trading
        )
        if not order:
            raise HTTPException(status_code=400, detail="Order execution failed")
        return OrderResponse.model_validate(order)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) through unchanged
        # instead of re-wrapping them as 500s below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/orders", response_model=List[OrderResponse])
async def get_orders(
    paper_trading: bool = True,
    limit: int = 100,
    db: AsyncSession = Depends(get_db_session)
):
    """Get order history for one trading mode.

    Pages through ``OrderRepository.get_all`` in ``limit``-sized pages
    (advanced by ``offset``) and keeps only orders whose ``paper_trading``
    flag matches. Up to ``limit`` matching orders are returned even when the
    first page is dominated by orders of the other mode. A ``limit`` <= 0
    yields an empty list.

    NOTE(review): assumes ``get_all`` accepts ``offset`` and returns rows in a
    stable order across calls — confirm; ideally push the filter into the repo.

    Raises:
        HTTPException: 500 on any repository or serialization error.
    """
    try:
        repo = OrderRepository(db)
        matched = []
        offset = 0
        while len(matched) < limit:
            page = await repo.get_all(limit=limit, offset=offset)
            matched.extend(o for o in page if o.paper_trading == paper_trading)
            if len(page) < limit:
                # A short page means the repository has no more rows.
                break
            offset += len(page)
        return [OrderResponse.model_validate(order) for order in matched[:limit]]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/orders/{order_id}", response_model=OrderResponse)
async def get_order(
    order_id: int,
    db: AsyncSession = Depends(get_db_session)
):
    """Fetch a single order by its primary key.

    Raises:
        HTTPException: 404 when no order has ``order_id``; 500 for any other
            unexpected failure (HTTP errors are passed through untouched).
    """
    try:
        order_row = await OrderRepository(db).get_by_id(order_id)
        if order_row:
            return OrderResponse.model_validate(order_row)
        raise HTTPException(status_code=404, detail="Order not found")
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/orders/{order_id}/cancel")
async def cancel_order(
    order_id: int,
    trading_engine=Depends(get_trading_engine)
):
    """Cancel one order through the trading engine.

    The engine's ``cancel_order`` is used (rather than the repository) so the
    cancellation is handled by the engine end to end.

    Returns:
        ``{"status": "cancelled", "order_id": order_id}`` on success.

    Raises:
        HTTPException: 400 if the engine reports failure (or the order is
            unknown); 500 for any other unexpected error.
    """
    try:
        success = await trading_engine.cancel_order(order_id)
        if not success:
            raise HTTPException(status_code=400, detail="Failed to cancel order or order not found")
        return {"status": "cancelled", "order_id": order_id}
    except HTTPException:
        # Keep the deliberate 400 above from being re-wrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/orders/cancel-all")
async def cancel_all_orders(
    paper_trading: bool = True,
    trading_engine=Depends(get_trading_engine),
    db: AsyncSession = Depends(get_db_session)
):
    """Cancel every open order for the selected trading mode.

    Each order is cancelled independently: an engine failure (falsy result or
    exception) on one order is logged/counted and does not stop the rest.

    Returns:
        ``{"status": "no_orders", "cancelled_count": 0}`` when nothing is open,
        otherwise a summary with cancelled, failed and total counts.

    Raises:
        HTTPException: 500 if loading the open orders fails.
    """
    try:
        pending = await OrderRepository(db).get_open_orders(paper_trading=paper_trading)
        if not pending:
            return {"status": "no_orders", "cancelled_count": 0}

        succeeded = 0
        for open_order in pending:
            try:
                ok = bool(await trading_engine.cancel_order(open_order.id))
            except Exception as exc:
                logger.error(f"Failed to cancel order {open_order.id}: {exc}")
                ok = False
            if ok:
                succeeded += 1

        # Every order ends up either cancelled or failed, never both.
        total = len(pending)
        return {
            "status": "completed",
            "cancelled_count": succeeded,
            "failed_count": total - succeeded,
            "total_orders": total
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def _position_to_response(pos) -> PositionResponse:
    """Map a Position object to ``PositionResponse``, filling missing values.

    ``current_price`` falls back to ``entry_price`` when unset or falsy, and
    missing PnL figures default to ``Decimal(0)``.
    """
    return PositionResponse(
        symbol=pos.symbol,
        quantity=pos.quantity,
        entry_price=pos.entry_price,
        current_price=pos.current_price or pos.entry_price,
        unrealized_pnl=pos.unrealized_pnl or Decimal(0),
        realized_pnl=pos.realized_pnl or Decimal(0),
    )


@router.get("/positions", response_model=List[PositionResponse])
async def get_positions(
    paper_trading: bool = True,
    db: AsyncSession = Depends(get_db_session)
):
    """Get current positions for the selected trading mode.

    Paper positions come from the paper-trading simulator; live positions are
    read from the database. Both sources go through the same mapping, so a
    missing PnL is reported as ``Decimal(0)`` in either mode (previously only
    the paper branch applied that default).

    Raises:
        HTTPException: 500 on any simulator, repository or mapping error.
    """
    try:
        if paper_trading:
            positions = get_paper_trading().get_positions()
        else:
            positions = await PositionRepository(db).get_all(paper_trading=False)
        return [_position_to_response(pos) for pos in positions]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/balance")
async def get_balance(paper_trading: bool = True):
    """Get account balance.

    Args:
        paper_trading: When True, report the paper-trading simulator's balance
            and performance; otherwise return the live-trading placeholder.

    Returns:
        ``{"balance": <float>, "performance": <simulator performance>}``.

    Raises:
        HTTPException: 500 if the simulator raises.
    """
    if not paper_trading:
        # TODO: live balance is not implemented yet; this placeholder always
        # reports zero instead of querying an exchange. It no longer touches
        # the paper simulator, so simulator failures can't break it.
        return {"balance": 0.0, "performance": {}}
    try:
        paper_trading_sim = get_paper_trading()
        balance = paper_trading_sim.get_balance()
        performance = paper_trading_sim.get_performance()
        return {
            "balance": float(balance),
            "performance": performance
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e