"""Trading API endpoints.""" from decimal import Decimal from typing import List, Optional from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.ext.asyncio import AsyncSession from ..core.dependencies import get_trading_engine, get_db_session from ..core.schemas import OrderCreate, OrderResponse, PositionResponse from src.core.database import Order, OrderSide, OrderType, OrderStatus from src.core.repositories import OrderRepository, PositionRepository from src.core.logger import get_logger from src.trading.paper_trading import get_paper_trading router = APIRouter() logger = get_logger(__name__) @router.post("/orders", response_model=OrderResponse) async def create_order( order_data: OrderCreate, trading_engine=Depends(get_trading_engine) ): """Create and execute a trading order.""" try: # Convert string enums to actual enums side = OrderSide(order_data.side.value) order_type = OrderType(order_data.order_type.value) order = await trading_engine.execute_order( exchange_id=order_data.exchange_id, strategy_id=order_data.strategy_id, symbol=order_data.symbol, side=side, order_type=order_type, quantity=order_data.quantity, price=order_data.price, paper_trading=order_data.paper_trading ) if not order: raise HTTPException(status_code=400, detail="Order execution failed") return OrderResponse.model_validate(order) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/orders", response_model=List[OrderResponse]) async def get_orders( paper_trading: bool = True, limit: int = 100, db: AsyncSession = Depends(get_db_session) ): """Get order history.""" try: repo = OrderRepository(db) orders = await repo.get_all(limit=limit) # Filter by paper_trading in memory or add to repo method (repo returns all by default currently sorted by date) # Let's verify repo method. It has limit/offset but not filtering. # We should filter here or improve repo. # For this refactor, let's filter in python for simplicity or assume get_all needs an update. # Ideally, update repo. But strictly following "get_all" contract. filtered = [o for o in orders if o.paper_trading == paper_trading] return [OrderResponse.model_validate(order) for order in filtered] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/orders/{order_id}", response_model=OrderResponse) async def get_order( order_id: int, db: AsyncSession = Depends(get_db_session) ): """Get order by ID.""" try: repo = OrderRepository(db) order = await repo.get_by_id(order_id) if not order: raise HTTPException(status_code=404, detail="Order not found") return OrderResponse.model_validate(order) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.post("/orders/{order_id}/cancel") async def cancel_order( order_id: int, trading_engine=Depends(get_trading_engine) ): try: # We can use trading_engine's cancel which handles DB and Exchange success = await trading_engine.cancel_order(order_id) if not success: raise HTTPException(status_code=400, detail="Failed to cancel order or order not found") return {"status": "cancelled", "order_id": order_id} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.post("/orders/cancel-all") async def cancel_all_orders( paper_trading: bool = True, trading_engine=Depends(get_trading_engine), db: AsyncSession = Depends(get_db_session) ): """Cancel all open orders.""" try: repo = OrderRepository(db) open_orders = await repo.get_open_orders(paper_trading=paper_trading) if not open_orders: return {"status": "no_orders", "cancelled_count": 0} cancelled_count = 0 failed_count = 0 for order in open_orders: try: if await trading_engine.cancel_order(order.id): cancelled_count += 1 else: failed_count += 1 except Exception as e: logger.error(f"Failed to cancel order {order.id}: {e}") failed_count += 1 return { "status": "completed", "cancelled_count": cancelled_count, "failed_count": failed_count, "total_orders": len(open_orders) } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/positions", response_model=List[PositionResponse]) async def get_positions( paper_trading: bool = True, db: AsyncSession = Depends(get_db_session) ): """Get current positions.""" try: if paper_trading: paper_trading_sim = get_paper_trading() positions = paper_trading_sim.get_positions() # positions is a List[Position], convert to PositionResponse list position_list = [] for pos in positions: # pos is a Position database object current_price = pos.current_price if pos.current_price else pos.entry_price unrealized_pnl = pos.unrealized_pnl if pos.unrealized_pnl else Decimal(0) realized_pnl = pos.realized_pnl if pos.realized_pnl else Decimal(0) position_list.append( PositionResponse( symbol=pos.symbol, quantity=pos.quantity, entry_price=pos.entry_price, current_price=current_price, unrealized_pnl=unrealized_pnl, realized_pnl=realized_pnl ) ) return position_list else: # Live trading positions from database repo = PositionRepository(db) positions = await repo.get_all(paper_trading=False) return [ PositionResponse( symbol=pos.symbol, quantity=pos.quantity, entry_price=pos.entry_price, current_price=pos.current_price or pos.entry_price, unrealized_pnl=pos.unrealized_pnl, realized_pnl=pos.realized_pnl ) for pos in positions ] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/balance") async def get_balance(paper_trading: bool = True): """Get account balance.""" try: paper_trading_sim = get_paper_trading() if paper_trading: balance = paper_trading_sim.get_balance() performance = paper_trading_sim.get_performance() return { "balance": float(balance), "performance": performance } else: # Live trading balance from exchange return {"balance": 0.0, "performance": {}} except Exception as e: raise HTTPException(status_code=500, detail=str(e))