Files
crypto_trader/backend/api/trading.py
kfox 7bd6be64a4
Some checks are pending
Documentation / build-docs (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.13) (push) Waiting to run
Tests / test (macos-latest, 3.14) (push) Waiting to run
Tests / test (ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (ubuntu-latest, 3.13) (push) Waiting to run
Tests / test (ubuntu-latest, 3.14) (push) Waiting to run
feat: Add core trading modules for risk management, backtesting, and execution algorithms, alongside a new ML transparency widget and related frontend dependencies.
2025-12-31 21:25:06 -05:00

279 lines
10 KiB
Python

"""Trading API endpoints."""
from decimal import Decimal
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from ..core.dependencies import get_trading_engine, get_db_session
from ..core.schemas import OrderCreate, OrderResponse, PositionResponse
from src.core.database import Order, OrderSide, OrderType, OrderStatus
from src.core.repositories import OrderRepository, PositionRepository
from src.core.logger import get_logger
from src.trading.paper_trading import get_paper_trading
from src.trading.advanced_orders import get_advanced_order_manager
# Router on which every trading endpoint in this module is registered.
router = APIRouter()
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
@router.post("/orders", response_model=OrderResponse)
async def create_order(
    order_data: OrderCreate,
    trading_engine=Depends(get_trading_engine)
):
    """Create and execute a trading order.

    Supports advanced order types:
    - Trailing stop: Set trail_percent (e.g., 0.02 for 2%)
    - Bracket orders: Set bracket_take_profit and bracket_stop_loss
    - OCO orders: Set oco_price for the second order
    - Iceberg orders: Set visible_quantity

    Args:
        order_data: Request payload describing the order and any advanced
            order parameters.
        trading_engine: Engine used to execute the base order (and the
            second leg of an OCO order).

    Returns:
        The executed base order validated into an ``OrderResponse``.

    Raises:
        HTTPException: 400 when a trailing stop is requested without
            ``stop_loss_price`` or when the engine returns no order;
            500 for any other failure.
    """
    try:
        # Convert string enums to actual enums
        side = OrderSide(order_data.side.value)
        order_type = OrderType(order_data.order_type.value)

        # Validate the trailing-stop parameters BEFORE executing anything:
        # rejecting the request after the base order has been submitted
        # would return a 400 while leaving a live order behind.
        wants_trailing_stop = bool(
            order_type == OrderType.TRAILING_STOP and order_data.trail_percent
        )
        if wants_trailing_stop and not order_data.stop_loss_price:
            raise HTTPException(status_code=400, detail="stop_loss_price required for trailing stop")

        # Execute the base order
        order = await trading_engine.execute_order(
            exchange_id=order_data.exchange_id,
            strategy_id=order_data.strategy_id,
            symbol=order_data.symbol,
            side=side,
            order_type=order_type,
            quantity=order_data.quantity,
            price=order_data.price,
            paper_trading=order_data.paper_trading
        )
        if not order:
            raise HTTPException(status_code=400, detail="Order execution failed")

        # Handle advanced order types
        advanced_manager = get_advanced_order_manager()

        # Trailing stop (parameters already validated above)
        if wants_trailing_stop:
            advanced_manager.create_trailing_stop(
                base_order_id=order.id,
                initial_stop_price=order_data.stop_loss_price,
                trail_percent=order_data.trail_percent
            )

        # Take profit
        if order_data.take_profit_price:
            advanced_manager.create_take_profit(
                base_order_id=order.id,
                target_price=order_data.take_profit_price
            )

        # Bracket order (entry + take profit + stop loss); each leg is
        # optional and attached independently.
        if order_data.bracket_take_profit:
            advanced_manager.create_take_profit(
                base_order_id=order.id,
                target_price=order_data.bracket_take_profit
            )
        if order_data.bracket_stop_loss:
            advanced_manager.create_trailing_stop(
                base_order_id=order.id,
                initial_stop_price=order_data.bracket_stop_loss,
                trail_percent=Decimal("0.0")  # Fixed stop, not trailing
            )

        # OCO order: place the opposite-side limit order and link the pair
        if order_type == OrderType.OCO and order_data.oco_price:
            oco_order = await trading_engine.execute_order(
                exchange_id=order_data.exchange_id,
                strategy_id=order_data.strategy_id,
                symbol=order_data.symbol,
                side=OrderSide.SELL if side == OrderSide.BUY else OrderSide.BUY,
                order_type=OrderType.LIMIT,
                quantity=order_data.quantity,
                price=order_data.oco_price,
                paper_trading=order_data.paper_trading
            )
            if oco_order:
                advanced_manager.create_oco(order.id, oco_order.id)

        # Iceberg order
        if order_type == OrderType.ICEBERG and order_data.visible_quantity:
            advanced_manager.create_iceberg(
                total_quantity=order_data.quantity,
                visible_quantity=order_data.visible_quantity,
                symbol=order_data.symbol,
                side=side,
                price=order_data.price
            )

        return OrderResponse.model_validate(order)
    except HTTPException:
        # Deliberate client-facing errors pass through with their own status.
        raise
    except Exception as e:
        logger.error(f"Error creating order: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/orders", response_model=List[OrderResponse])
async def get_orders(
    paper_trading: bool = True,
    limit: int = 100,
    db: AsyncSession = Depends(get_db_session)
):
    """Return order history for the requested trading mode.

    Fetches up to ``limit`` orders from the repository, then keeps only
    those whose ``paper_trading`` flag equals the ``paper_trading``
    argument.

    NOTE: the mode filter is applied in Python after the repository has
    already applied ``limit``, so the response may contain fewer than
    ``limit`` orders. Pushing the filter into the repository query would
    remove that limitation.

    Raises:
        HTTPException: 500 if fetching or serialising the orders fails.
    """
    try:
        fetched = await OrderRepository(db).get_all(limit=limit)
        return [
            OrderResponse.model_validate(candidate)
            for candidate in fetched
            if candidate.paper_trading == paper_trading
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/orders/{order_id}", response_model=OrderResponse)
async def get_order(
    order_id: int,
    db: AsyncSession = Depends(get_db_session)
):
    """Look up a single order by its ID.

    Raises:
        HTTPException: 404 when the repository returns no order for
            ``order_id``; 500 for any other failure.
    """
    try:
        found = await OrderRepository(db).get_by_id(order_id)
        if found:
            return OrderResponse.model_validate(found)
        raise HTTPException(status_code=404, detail="Order not found")
    except HTTPException:
        # Keep the 404 intact instead of wrapping it as a server error.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/orders/{order_id}/cancel")
async def cancel_order(
    order_id: int,
    trading_engine=Depends(get_trading_engine)
):
    """Cancel a single order through the trading engine.

    Args:
        order_id: ID of the order to cancel.
        trading_engine: Engine whose ``cancel_order`` performs the
            cancellation.

    Returns:
        ``{"status": "cancelled", "order_id": order_id}`` on success.

    Raises:
        HTTPException: 400 when the engine reports the cancellation did not
            succeed; 500 for any unexpected failure.
    """
    try:
        # We can use trading_engine's cancel which handles DB and Exchange
        success = await trading_engine.cancel_order(order_id)
        if not success:
            raise HTTPException(status_code=400, detail="Failed to cancel order or order not found")
        return {"status": "cancelled", "order_id": order_id}
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the
        # 400 above would be swallowed by the handler below and surface
        # to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/orders/cancel-all")
async def cancel_all_orders(
    paper_trading: bool = True,
    trading_engine=Depends(get_trading_engine),
    db: AsyncSession = Depends(get_db_session)
):
    """Cancel every open order for the given trading mode.

    Each open order is cancelled individually through the trading engine;
    a failure on one order is logged and counted without aborting the rest.

    Returns:
        ``{"status": "no_orders", "cancelled_count": 0}`` when nothing is
        open, otherwise a summary with ``cancelled_count``,
        ``failed_count`` and ``total_orders``.

    Raises:
        HTTPException: 500 if the open orders cannot be loaded or another
            unexpected error occurs.
    """
    try:
        pending = await OrderRepository(db).get_open_orders(paper_trading=paper_trading)
        if not pending:
            return {"status": "no_orders", "cancelled_count": 0}

        # One boolean per order: True if the engine confirmed the cancel,
        # False if it declined or raised.
        outcomes = []
        for open_order in pending:
            try:
                outcomes.append(bool(await trading_engine.cancel_order(open_order.id)))
            except Exception as e:
                logger.error(f"Failed to cancel order {open_order.id}: {e}")
                outcomes.append(False)

        return {
            "status": "completed",
            "cancelled_count": outcomes.count(True),
            "failed_count": outcomes.count(False),
            "total_orders": len(pending)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def _position_to_response(pos) -> PositionResponse:
    """Convert a position object into a ``PositionResponse``.

    Missing values are defaulted the same way for paper and live
    positions: ``current_price`` falls back to ``entry_price`` and the
    P&L fields fall back to ``Decimal(0)``.
    """
    return PositionResponse(
        symbol=pos.symbol,
        quantity=pos.quantity,
        entry_price=pos.entry_price,
        current_price=pos.current_price or pos.entry_price,
        unrealized_pnl=pos.unrealized_pnl or Decimal(0),
        realized_pnl=pos.realized_pnl or Decimal(0),
    )


@router.get("/positions", response_model=List[PositionResponse])
async def get_positions(
    paper_trading: bool = True,
    db: AsyncSession = Depends(get_db_session)
):
    """Get current positions.

    Args:
        paper_trading: When true, positions come from the paper-trading
            simulator; otherwise from the position repository.
        db: Database session used for the live (non-paper) lookup.

    Returns:
        A list of ``PositionResponse`` objects, one per position.

    Raises:
        HTTPException: 500 if loading or converting positions fails.
    """
    try:
        if paper_trading:
            positions = get_paper_trading().get_positions()
        else:
            # Live trading positions from database
            positions = await PositionRepository(db).get_all(paper_trading=False)

        # Both sources go through the same conversion so that unset P&L
        # values are reported as zero in either mode (previously only the
        # paper branch applied these defaults).
        return [_position_to_response(pos) for pos in positions]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/balance")
async def get_balance(paper_trading: bool = True):
    """Report the account balance and performance figures.

    For paper trading, the balance (converted to ``float``) and the
    performance data are read from the paper-trading simulator. For live
    trading this endpoint currently returns a placeholder of ``0.0`` with
    empty performance data; no exchange is queried.

    Raises:
        HTTPException: 500 if the simulator lookup or any of its calls fail.
    """
    try:
        simulator = get_paper_trading()

        if not paper_trading:
            # Live trading balance from exchange (placeholder values)
            return {"balance": 0.0, "performance": {}}

        current_balance = simulator.get_balance()
        stats = simulator.get_performance()
        return {
            "balance": float(current_balance),
            "performance": stats
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))