"""Trading API endpoints."""

from decimal import Decimal
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from ..core.dependencies import get_trading_engine, get_db_session
from ..core.schemas import OrderCreate, OrderResponse, PositionResponse
from src.core.database import Order, OrderSide, OrderType, OrderStatus
from src.core.repositories import OrderRepository, PositionRepository
from src.core.logger import get_logger
from src.trading.paper_trading import get_paper_trading
from src.trading.advanced_orders import get_advanced_order_manager

router = APIRouter()
logger = get_logger(__name__)


@router.post("/orders", response_model=OrderResponse)
async def create_order(
    order_data: OrderCreate,
    trading_engine=Depends(get_trading_engine)
):
    """Create and execute a trading order.

    Supports advanced order types:
    - Trailing stop: Set trail_percent (e.g., 0.02 for 2%)
    - Bracket orders: Set bracket_take_profit and bracket_stop_loss
    - OCO orders: Set oco_price for the second order
    - Iceberg orders: Set visible_quantity

    Request parameters for advanced orders are validated before the base
    order is sent to the trading engine, so a rejected request never leaves
    an order behind.

    Raises:
        HTTPException(400): stop_loss_price missing for a trailing stop, or
            the trading engine returned no order.
        HTTPException(500): any other unexpected error.
    """
    try:
        # Convert the schema enums to the database enums.
        side = OrderSide(order_data.side.value)
        order_type = OrderType(order_data.order_type.value)

        is_trailing_stop = bool(
            order_type == OrderType.TRAILING_STOP and order_data.trail_percent
        )
        # Validate BEFORE executing: previously this check ran after the base
        # order had already been placed, returning 400 for an order that existed.
        if is_trailing_stop and not order_data.stop_loss_price:
            raise HTTPException(status_code=400, detail="stop_loss_price required for trailing stop")

        # Execute the base order
        order = await trading_engine.execute_order(
            exchange_id=order_data.exchange_id,
            strategy_id=order_data.strategy_id,
            symbol=order_data.symbol,
            side=side,
            order_type=order_type,
            quantity=order_data.quantity,
            price=order_data.price,
            paper_trading=order_data.paper_trading
        )

        if not order:
            raise HTTPException(status_code=400, detail="Order execution failed")

        # NOTE(review): from here on the base order already exists; a failure
        # while registering follow-up orders surfaces as a 500 even though the
        # base order stands.
        advanced_manager = get_advanced_order_manager()

        # Trailing stop
        if is_trailing_stop:
            advanced_manager.create_trailing_stop(
                base_order_id=order.id,
                initial_stop_price=order_data.stop_loss_price,
                trail_percent=order_data.trail_percent
            )

        # Take profit
        if order_data.take_profit_price:
            advanced_manager.create_take_profit(
                base_order_id=order.id,
                target_price=order_data.take_profit_price
            )

        # Bracket order (entry + take profit + stop loss); each leg is optional.
        if order_data.bracket_take_profit:
            advanced_manager.create_take_profit(
                base_order_id=order.id,
                target_price=order_data.bracket_take_profit
            )
        if order_data.bracket_stop_loss:
            advanced_manager.create_trailing_stop(
                base_order_id=order.id,
                initial_stop_price=order_data.bracket_stop_loss,
                trail_percent=Decimal("0.0")  # Fixed stop, not trailing
            )

        # OCO order: place the opposite-side limit order and link the two.
        if order_type == OrderType.OCO and order_data.oco_price:
            oco_order = await trading_engine.execute_order(
                exchange_id=order_data.exchange_id,
                strategy_id=order_data.strategy_id,
                symbol=order_data.symbol,
                side=OrderSide.SELL if side == OrderSide.BUY else OrderSide.BUY,
                order_type=OrderType.LIMIT,
                quantity=order_data.quantity,
                price=order_data.oco_price,
                paper_trading=order_data.paper_trading
            )
            if oco_order:
                advanced_manager.create_oco(order.id, oco_order.id)
            else:
                # Previously silent: the base order stays unlinked.
                logger.warning(
                    f"OCO second leg failed for order {order.id}; "
                    f"base order left without OCO link"
                )

        # Iceberg order
        # NOTE(review): the base order above was executed with the full
        # quantity and the iceberg is registered with the same total quantity;
        # confirm execute_order does not itself fill ICEBERG orders in full.
        if order_type == OrderType.ICEBERG and order_data.visible_quantity:
            advanced_manager.create_iceberg(
                total_quantity=order_data.quantity,
                visible_quantity=order_data.visible_quantity,
                symbol=order_data.symbol,
                side=side,
                price=order_data.price
            )

        return OrderResponse.model_validate(order)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating order: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/orders", response_model=List[OrderResponse])
async def get_orders(
    paper_trading: bool = True,
    limit: int = 100,
    db: AsyncSession = Depends(get_db_session)
):
    """Get order history.

    Args:
        paper_trading: Return only paper (True) or only live (False) orders.
        limit: Number of orders fetched from the repository before filtering.

    Raises:
        HTTPException(500): any error while loading or serializing orders.
    """
    try:
        repo = OrderRepository(db)
        orders = await repo.get_all(limit=limit)
        # TODO: OrderRepository.get_all cannot filter by paper_trading, so the
        # filter is applied here, after `limit`. When paper and live orders are
        # mixed, fewer than `limit` orders may be returned. Move the filter
        # into the repository query.
        return [
            OrderResponse.model_validate(order)
            for order in orders
            if order.paper_trading == paper_trading
        ]
    except Exception as e:
        logger.error(f"Error fetching orders: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/orders/{order_id}", response_model=OrderResponse)
async def get_order(
    order_id: int,
    db: AsyncSession = Depends(get_db_session)
):
    """Get order by ID.

    Raises:
        HTTPException(404): no order with this ID.
        HTTPException(500): any other error.
    """
    try:
        repo = OrderRepository(db)
        order = await repo.get_by_id(order_id)
        if not order:
            raise HTTPException(status_code=404, detail="Order not found")
        return OrderResponse.model_validate(order)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching order {order_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/orders/{order_id}/cancel")
async def cancel_order(
    order_id: int,
    trading_engine=Depends(get_trading_engine)
):
    """Cancel a single order via the trading engine.

    Raises:
        HTTPException(400): the engine reported failure (or order not found).
        HTTPException(500): any other error.
    """
    try:
        # The trading engine's cancel handles both the DB record and the exchange.
        success = await trading_engine.cancel_order(order_id)
        if not success:
            raise HTTPException(status_code=400, detail="Failed to cancel order or order not found")
        return {"status": "cancelled", "order_id": order_id}
    except HTTPException:
        # Without this, the 400 above was caught below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Error cancelling order {order_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/orders/cancel-all")
async def cancel_all_orders(
    paper_trading: bool = True,
    trading_engine=Depends(get_trading_engine),
    db: AsyncSession = Depends(get_db_session)
):
    """Cancel all open orders.

    Each order is cancelled independently; a failure on one order is logged
    and counted without stopping the rest.

    Returns:
        A summary dict with cancelled, failed and total counts.
    """
    try:
        repo = OrderRepository(db)
        open_orders = await repo.get_open_orders(paper_trading=paper_trading)

        if not open_orders:
            return {"status": "no_orders", "cancelled_count": 0}

        cancelled_count = 0
        failed_count = 0

        for order in open_orders:
            try:
                if await trading_engine.cancel_order(order.id):
                    cancelled_count += 1
                else:
                    failed_count += 1
            except Exception as e:
                logger.error(f"Failed to cancel order {order.id}: {e}")
                failed_count += 1

        return {
            "status": "completed",
            "cancelled_count": cancelled_count,
            "failed_count": failed_count,
            "total_orders": len(open_orders)
        }
    except Exception as e:
        logger.error(f"Error cancelling all orders: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


def _to_position_response(pos) -> PositionResponse:
    """Convert a Position object to a PositionResponse with consistent defaults.

    A falsy current_price (None or 0) falls back to entry_price; falsy P&L
    values become Decimal(0). Paper and live positions share this helper so
    both branches of get_positions serialize identically.
    """
    return PositionResponse(
        symbol=pos.symbol,
        quantity=pos.quantity,
        entry_price=pos.entry_price,
        current_price=pos.current_price or pos.entry_price,
        unrealized_pnl=pos.unrealized_pnl or Decimal(0),
        realized_pnl=pos.realized_pnl or Decimal(0),
    )


@router.get("/positions", response_model=List[PositionResponse])
async def get_positions(
    paper_trading: bool = True,
    db: AsyncSession = Depends(get_db_session)
):
    """Get current positions.

    Paper positions come from the paper-trading simulator; live positions
    come from the database (PositionRepository).
    """
    try:
        if paper_trading:
            positions = get_paper_trading().get_positions()
        else:
            repo = PositionRepository(db)
            positions = await repo.get_all(paper_trading=False)
        return [_to_position_response(pos) for pos in positions]
    except Exception as e:
        logger.error(f"Error fetching positions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/balance")
async def get_balance(paper_trading: bool = True):
    """Get account balance.

    For paper trading, returns the simulator balance (as float) and its
    performance summary. Live balance is not implemented yet and returns a
    zero placeholder.
    """
    try:
        if not paper_trading:
            # TODO: fetch live trading balance from the exchange.
            return {"balance": 0.0, "performance": {}}

        paper_trading_sim = get_paper_trading()
        balance = paper_trading_sim.get_balance()
        performance = paper_trading_sim.get_performance()
        return {
            "balance": float(balance),
            "performance": performance
        }
    except Exception as e:
        logger.error(f"Error fetching balance: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))