Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
256
src/core/config.py
Normal file
256
src/core/config.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Configuration management system with YAML and environment variables."""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Config:
    """Configuration manager with XDG directory support.

    Values are resolved with the following precedence (lowest to highest):
    built-in defaults, YAML config file, environment variables.
    """

    def __init__(self, config_file: Optional[str] = None):
        """Initialize configuration manager.

        Args:
            config_file: Optional path to config file. If None, a local
                project config is used when present (dev mode), otherwise
                the XDG default location.
        """
        self._setup_xdg_directories()

        # Determine config file priority:
        # 1. Explicit argument
        # 2. Local project config (dev mode)
        # 3. XDG config (user mode)
        local_config = Path(__file__).parent.parent.parent / "config" / "config.yaml"

        if config_file:
            self.config_file = Path(config_file)
        elif local_config.exists():
            self.config_file = local_config
        else:
            self.config_file = self.config_dir / "config.yaml"

        self._config: Dict[str, Any] = {}
        self._load_config()

    def _setup_xdg_directories(self):
        """Set up XDG Base Directory Specification directories.

        Creates (if missing) the config, data (plus historical/backups/logs
        subdirectories), and cache directories for the application.
        """
        home = Path.home()

        # XDG_CONFIG_HOME or default (~/.config)
        xdg_config = os.getenv("XDG_CONFIG_HOME", home / ".config")
        self.config_dir = Path(xdg_config) / "crypto_trader"
        self.config_dir.mkdir(parents=True, exist_ok=True)

        # XDG_DATA_HOME or default (~/.local/share)
        xdg_data = os.getenv("XDG_DATA_HOME", home / ".local" / "share")
        self.data_dir = Path(xdg_data) / "crypto_trader"
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Create data subdirectories
        (self.data_dir / "historical").mkdir(exist_ok=True)
        (self.data_dir / "backups").mkdir(exist_ok=True)
        (self.data_dir / "logs").mkdir(exist_ok=True)

        # XDG_CACHE_HOME or default (~/.cache)
        xdg_cache = os.getenv("XDG_CACHE_HOME", home / ".cache")
        self.cache_dir = Path(xdg_cache) / "crypto_trader"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load_config(self):
        """Load configuration from YAML file and environment variables."""
        # Start from built-in defaults (_get_default_config returns a fresh
        # dict on every call, so no defensive copy is needed).
        self._config = self._get_default_config()

        # Merge the file config on top of the defaults.  A recursive merge is
        # used instead of a shallow dict.update() so that a partial section in
        # the file (e.g. only "logging.level") does not discard the remaining
        # defaults of that section.
        if self.config_file.exists():
            with open(self.config_file, 'r') as f:
                file_config = yaml.safe_load(f) or {}
            self._deep_merge(self._config, file_config)

        # Environment variables take the highest precedence.
        self._load_from_env()

    @staticmethod
    def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> None:
        """Recursively merge ``override`` into ``base`` in place.

        Nested dicts are merged key by key; any other value (including lists)
        replaces the base value wholesale.
        """
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                Config._deep_merge(base[key], value)
            else:
                base[key] = value

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration.

        Returns:
            A freshly-built dict of all default settings.  Some defaults are
            seeded from environment variables at call time.
        """
        return {
            "app": {
                "name": "Crypto Trader",
                "version": "0.1.0",
            },
            "database": {
                "type": "postgresql",
                "url": None,  # For PostgreSQL
            },
            "logging": {
                "level": os.getenv("LOG_LEVEL", "INFO"),
                "dir": str(self.data_dir / "logs"),
                "retention_days": 30,
                "rotation": "daily",
            },
            "paper_trading": {
                "enabled": True,
                "default_capital": float(os.getenv("PAPER_TRADING_CAPITAL", "100.0")),
            },
            "updates": {
                "check_on_startup": os.getenv("UPDATE_CHECK_ON_STARTUP", "true").lower() == "true",
                "repository_url": os.getenv("UPDATE_REPOSITORY_URL", ""),
            },
            "exchanges": {},
            "strategies": {
                "default_timeframe": "1h",
            },
            "risk": {
                "max_drawdown_percent": 20.0,
                "daily_loss_limit_percent": 5.0,
                "position_size_percent": 2.0,
            },
            "trading": {
                "default_fees": {
                    "maker": 0.001,  # 0.1%
                    "taker": 0.001,  # 0.1%
                    "minimum": 0.0,
                },
                "exchanges": {},
            },
            "data_providers": {
                "primary": [
                    {"name": "kraken", "enabled": True, "priority": 1},
                    {"name": "coinbase", "enabled": True, "priority": 2},
                    {"name": "binance", "enabled": True, "priority": 3},
                ],
                "fallback": {
                    "name": "coingecko",
                    "enabled": True,
                    "api_key": "",
                },
                "caching": {
                    "ticker_ttl": 2,  # seconds
                    "ohlcv_ttl": 60,  # seconds
                    "max_cache_size": 1000,
                },
                "websocket": {
                    "enabled": True,
                    "reconnect_interval": 5,  # seconds
                    "ping_interval": 30,  # seconds
                },
            },
            "redis": {
                "host": os.getenv("REDIS_HOST", "127.0.0.1"),
                "port": int(os.getenv("REDIS_PORT", 6379)),
                "db": int(os.getenv("REDIS_DB", 0)),
                "password": os.getenv("REDIS_PASSWORD", None),
                "socket_connect_timeout": 5,
            },
            "celery": {
                "broker_url": os.getenv("CELERY_BROKER_URL", "redis://127.0.0.1:6379/0"),
                "result_backend": os.getenv("CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0"),
            },
        }

    def _load_from_env(self):
        """Override loaded configuration with environment variables."""
        # Database
        if db_url := os.getenv("DATABASE_URL"):
            self._config["database"]["url"] = db_url
            self._config["database"]["type"] = "postgresql"

        # Logging
        if log_level := os.getenv("LOG_LEVEL"):
            self._config["logging"]["level"] = log_level
        if log_dir := os.getenv("LOG_DIR"):
            self._config["logging"]["dir"] = log_dir

        # Paper trading
        if capital := os.getenv("PAPER_TRADING_CAPITAL"):
            self._config["paper_trading"]["default_capital"] = float(capital)

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value using dot notation.

        Args:
            key: Configuration key (e.g., "database.url")
            default: Default value if key not found

        Returns:
            Configuration value, or ``default`` when any path segment is
            missing or resolves to None.
        """
        keys = key.split(".")
        value = self._config
        for k in keys:
            if isinstance(value, dict):
                value = value.get(k)
                if value is None:
                    return default
            else:
                # A non-dict intermediate value means the path cannot descend
                # further.
                return default
        return value

    def set(self, key: str, value: Any):
        """Set configuration value using dot notation.

        Intermediate dicts are created as needed.

        Args:
            key: Configuration key (e.g., "database.url")
            value: Value to set
        """
        keys = key.split(".")
        config = self._config
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]
        config[keys[-1]] = value

    def save(self):
        """Save the current configuration to the config file as YAML."""
        with open(self.config_file, 'w') as f:
            yaml.dump(self._config, f, default_flow_style=False, sort_keys=False)

    @property
    def config_dir(self) -> Path:
        """Get config directory path."""
        return self._config_dir

    @config_dir.setter
    def config_dir(self, value: Path):
        """Set config directory."""
        self._config_dir = value

    @property
    def data_dir(self) -> Path:
        """Get data directory path."""
        return self._data_dir

    @data_dir.setter
    def data_dir(self, value: Path):
        """Set data directory."""
        self._data_dir = value

    @property
    def cache_dir(self) -> Path:
        """Get cache directory path."""
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, value: Path):
        """Set cache directory."""
        self._cache_dir = value
|
||||
|
||||
|
||||
# Global config instance
|
||||
# Module-level singleton holding the lazily-created Config.
_config_instance: Optional[Config] = None


def get_config() -> Config:
    """Get the global configuration instance, creating it on first use.

    NOTE(review): the check-then-create is not locked, so two threads
    racing on the first call could each build a Config — confirm callers
    initialize from a single thread.
    """
    global _config_instance
    if _config_instance is None:
        _config_instance = Config()
    return _config_instance
|
||||
|
||||
416
src/core/database.py
Normal file
416
src/core/database.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""Database connection and models using SQLAlchemy."""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from sqlalchemy import (
|
||||
create_engine, Column, Integer, String, Float, Boolean, DateTime,
|
||||
Text, ForeignKey, JSON, Enum as SQLEnum, Numeric
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, relationship, Session
|
||||
from .config import get_config
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class OrderType(str, Enum):
    """Order type enumeration.

    The ``str`` mixin makes each member compare equal to its string value,
    so members serialize cleanly to JSON and the database.
    """
    MARKET = "market"
    LIMIT = "limit"
    STOP_LOSS = "stop_loss"
    TAKE_PROFIT = "take_profit"
    TRAILING_STOP = "trailing_stop"
    OCO = "oco"  # One-cancels-the-other
    ICEBERG = "iceberg"
|
||||
|
||||
|
||||
class OrderSide(str, Enum):
    """Order side enumeration (str mixin keeps values DB/JSON friendly)."""
    BUY = "buy"
    SELL = "sell"
|
||||
|
||||
|
||||
class OrderStatus(str, Enum):
    """Order status enumeration.

    Typical lifecycle (not enforced here): PENDING -> OPEN ->
    PARTIALLY_FILLED -> FILLED, or a terminal CANCELLED / REJECTED /
    EXPIRED state.
    """
    PENDING = "pending"
    OPEN = "open"
    PARTIALLY_FILLED = "partially_filled"
    FILLED = "filled"
    CANCELLED = "cancelled"
    REJECTED = "rejected"
    EXPIRED = "expired"
|
||||
|
||||
|
||||
class TradeType(str, Enum):
    """Trade type enumeration (market segment of the trade)."""
    SPOT = "spot"
    FUTURES = "futures"
    MARGIN = "margin"
|
||||
|
||||
|
||||
class Exchange(Base):
    """Exchange configuration and credentials."""
    __tablename__ = "exchanges"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)
    api_key_encrypted = Column(Text)  # Encrypted API key
    api_secret_encrypted = Column(Text)  # Encrypted API secret
    sandbox = Column(Boolean, default=False)  # presumably exchange sandbox/testnet mode — confirm in client code
    read_only = Column(Boolean, default=True)  # Read-only mode
    enabled = Column(Boolean, default=True)
    # NOTE(review): datetime.utcnow is naive and deprecated since Python 3.12;
    # consider datetime.now(timezone.utc) across all models.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    trades = relationship("Trade", back_populates="exchange")
    orders = relationship("Order", back_populates="exchange")
    positions = relationship("Position", back_populates="exchange")
|
||||
|
||||
|
||||
class Strategy(Base):
    """Strategy definitions and parameters."""
    __tablename__ = "strategies"

    id = Column(Integer, primary_key=True)
    # NOTE(review): "name" is not unique — confirm duplicate strategy names
    # are intended.
    name = Column(String(100), nullable=False)
    description = Column(Text)
    strategy_type = Column(String(50))  # technical, momentum, grid, dca, etc.
    class_name = Column(String(100))  # Python class name
    parameters = Column(JSON)  # Strategy parameters
    timeframes = Column(JSON)  # Multi-timeframe configuration
    enabled = Column(Boolean, default=False)  # Available to Autopilot
    running = Column(Boolean, default=False)  # Currently running manually
    paper_trading = Column(Boolean, default=True)
    version = Column(String(20), default="1.0.0")
    schedule = Column(JSON)  # Scheduling configuration
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    trades = relationship("Trade", back_populates="strategy")
    backtest_results = relationship("BacktestResult", back_populates="strategy")
|
||||
|
||||
|
||||
class Order(Base):
    """Order history with state tracking."""
    __tablename__ = "orders"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
    exchange_order_id = Column(String(100))  # Exchange's order ID
    symbol = Column(String(20), nullable=False)
    order_type = Column(SQLEnum(OrderType), nullable=False)
    side = Column(SQLEnum(OrderSide), nullable=False)
    status = Column(SQLEnum(OrderStatus), default=OrderStatus.PENDING)
    # Numeric(20, 8): 8 decimal places of precision for crypto quantities.
    quantity = Column(Numeric(20, 8), nullable=False)
    price = Column(Numeric(20, 8))  # For limit orders
    filled_quantity = Column(Numeric(20, 8), default=0)  # Amount filled so far
    average_fill_price = Column(Numeric(20, 8))  # Average price across fills
    fee = Column(Numeric(20, 8), default=0)
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    leverage = Column(Integer, default=1)  # For futures/margin
    paper_trading = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    filled_at = Column(DateTime)  # presumably set when fully filled — confirm in execution logic

    # Relationships
    exchange = relationship("Exchange", back_populates="orders")
    strategy = relationship("Strategy")
    trades = relationship("Trade", back_populates="order")
|
||||
|
||||
|
||||
class Trade(Base):
    """All executed trades (paper and live)."""
    __tablename__ = "trades"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
    order_id = Column(Integer, ForeignKey("orders.id"), nullable=True)
    symbol = Column(String(20), nullable=False)
    side = Column(SQLEnum(OrderSide), nullable=False)
    quantity = Column(Numeric(20, 8), nullable=False)
    price = Column(Numeric(20, 8), nullable=False)
    fee = Column(Numeric(20, 8), default=0)
    total = Column(Numeric(20, 8), nullable=False)  # quantity * price + fee
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    paper_trading = Column(Boolean, default=True)
    # NOTE(review): timestamp is not indexed here, unlike other time-series
    # tables in this module — confirm whether trade history queries need one.
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)

    # Relationships
    exchange = relationship("Exchange", back_populates="trades")
    strategy = relationship("Strategy", back_populates="trades")
    order = relationship("Order", back_populates="trades")
|
||||
|
||||
|
||||
class Position(Base):
    """Current open positions (spot and futures)."""
    __tablename__ = "positions"

    id = Column(Integer, primary_key=True)
    exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
    symbol = Column(String(20), nullable=False)
    side = Column(String(10))  # long, short
    quantity = Column(Numeric(20, 8), nullable=False)
    entry_price = Column(Numeric(20, 8), nullable=False)
    current_price = Column(Numeric(20, 8))  # Last known market price
    unrealized_pnl = Column(Numeric(20, 8), default=0)
    realized_pnl = Column(Numeric(20, 8), default=0)
    trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
    leverage = Column(Integer, default=1)
    margin = Column(Numeric(20, 8))  # Margin allocated (futures/margin trades)
    paper_trading = Column(Boolean, default=True)
    opened_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    exchange = relationship("Exchange", back_populates="positions")
|
||||
|
||||
|
||||
class PortfolioSnapshot(Base):
    """Historical portfolio values."""
    __tablename__ = "portfolio_snapshots"

    id = Column(Integer, primary_key=True)
    total_value = Column(Numeric(20, 8), nullable=False)  # cash + positions_value — presumably; confirm in snapshot writer
    cash = Column(Numeric(20, 8), nullable=False)
    positions_value = Column(Numeric(20, 8), nullable=False)
    unrealized_pnl = Column(Numeric(20, 8), default=0)
    realized_pnl = Column(Numeric(20, 8), default=0)
    paper_trading = Column(Boolean, default=True)
    # Indexed: snapshots are queried as a time series.
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
|
||||
class MarketData(Base):
    """OHLCV historical data (multiple timeframes)."""
    __tablename__ = "market_data"

    id = Column(Integer, primary_key=True)
    exchange = Column(String(50), nullable=False)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)  # 1m, 5m, 15m, 1h, 1d, etc.
    # NOTE(review): no uniqueness constraint on
    # (exchange, symbol, timeframe, timestamp) — duplicate candles are
    # possible; confirm whether ingestion deduplicates.
    timestamp = Column(DateTime, nullable=False, index=True)
    open = Column(Numeric(20, 8), nullable=False)
    high = Column(Numeric(20, 8), nullable=False)
    low = Column(Numeric(20, 8), nullable=False)
    close = Column(Numeric(20, 8), nullable=False)
    volume = Column(Numeric(20, 8), nullable=False)
|
||||
|
||||
|
||||
|
||||
|
||||
class BacktestResult(Base):
    """Backtesting results and metrics."""
    __tablename__ = "backtest_results"

    id = Column(Integer, primary_key=True)
    strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=False)
    start_date = Column(DateTime, nullable=False)
    end_date = Column(DateTime, nullable=False)
    initial_capital = Column(Numeric(20, 8), nullable=False)
    final_capital = Column(Numeric(20, 8), nullable=False)
    total_return = Column(Numeric(10, 4))  # Percentage
    sharpe_ratio = Column(Numeric(10, 4))
    sortino_ratio = Column(Numeric(10, 4))
    max_drawdown = Column(Numeric(10, 4))
    win_rate = Column(Numeric(10, 4))
    total_trades = Column(Integer, default=0)
    parameters = Column(JSON)  # Parameters used in backtest
    metrics = Column(JSON)  # Additional metrics
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    strategy = relationship("Strategy", back_populates="backtest_results")
|
||||
|
||||
|
||||
class RiskLimit(Base):
    """Risk management configuration."""
    __tablename__ = "risk_limits"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    limit_type = Column(String(50))  # max_drawdown, daily_loss, position_size, etc.
    value = Column(Numeric(10, 4), nullable=False)  # Limit threshold; units depend on limit_type
    enabled = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class Alert(Base):
    """Alert definitions and history."""
    __tablename__ = "alerts"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    alert_type = Column(String(50))  # price, indicator, risk, system
    condition = Column(JSON)  # Alert condition configuration
    enabled = Column(Boolean, default=True)
    triggered = Column(Boolean, default=False)  # Latched when the condition fires
    triggered_at = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class RebalancingEvent(Base):
    """Portfolio rebalancing history."""
    __tablename__ = "rebalancing_events"

    id = Column(Integer, primary_key=True)
    trigger_type = Column(String(50))  # time, threshold, manual
    target_allocations = Column(JSON)  # Target portfolio allocations
    before_allocations = Column(JSON)  # Allocations before rebalancing
    after_allocations = Column(JSON)  # Allocations after rebalancing
    orders_placed = Column(JSON)  # Orders placed for rebalancing
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class AppState(Base):
    """Application state for recovery (key/value store with JSON values)."""
    __tablename__ = "app_state"

    id = Column(Integer, primary_key=True)
    key = Column(String(100), unique=True, nullable=False)  # Unique state key
    value = Column(JSON)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class AuditLog(Base):
    """Security and action audit trail."""
    __tablename__ = "audit_log"

    id = Column(Integer, primary_key=True)
    action = Column(String(100), nullable=False)
    entity_type = Column(String(50))  # exchange, strategy, order, etc.
    entity_id = Column(Integer)  # ID of the affected entity (no FK: entity_type varies)
    details = Column(JSON)
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
|
||||
class MarketConditionsSnapshot(Base):
    """Market conditions snapshot for ML training."""
    __tablename__ = "market_conditions_snapshot"

    id = Column(Integer, primary_key=True)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)
    regime = Column(String(50))  # Market regime classification
    features = Column(JSON)  # Market condition features
    strategy_name = Column(String(100))  # Strategy used at this time
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
|
||||
|
||||
|
||||
class StrategyPerformance(Base):
    """Strategy performance records for ML training."""
    __tablename__ = "strategy_performance"

    id = Column(Integer, primary_key=True)
    strategy_name = Column(String(100), nullable=False, index=True)
    symbol = Column(String(20), nullable=False)
    timeframe = Column(String(10), nullable=False)
    market_regime = Column(String(50), index=True)  # Market regime when trade executed
    return_pct = Column(Numeric(10, 4))  # Return percentage
    sharpe_ratio = Column(Numeric(10, 4))  # Sharpe ratio
    win_rate = Column(Numeric(5, 2))  # Win rate (0-100)
    max_drawdown = Column(Numeric(10, 4))  # Maximum drawdown
    trade_count = Column(Integer, default=1)  # Number of trades in this period
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
|
||||
|
||||
|
||||
class MLModelMetadata(Base):
    """ML model metadata and versions."""
    __tablename__ = "ml_model_metadata"

    id = Column(Integer, primary_key=True)
    model_name = Column(String(100), nullable=False)
    model_type = Column(String(50))  # classifier, regressor
    version = Column(String(20))
    file_path = Column(String(500))  # Path to saved model file
    training_metrics = Column(JSON)  # Training metrics (accuracy, MSE, etc.)
    feature_names = Column(JSON)  # List of feature names
    strategy_names = Column(JSON)  # List of strategy names
    training_samples = Column(Integer)  # Number of training samples
    trained_at = Column(DateTime, nullable=False)  # When training completed (caller-supplied)
    created_at = Column(DateTime, default=datetime.utcnow)  # When this row was recorded
|
||||
|
||||
|
||||
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
|
||||
class Database:
    """Async database connection manager (PostgreSQL via asyncpg)."""

    def __init__(self):
        """Initialize the async engine and session factory.

        Tables are NOT created here: engine creation is lazy and __init__
        must stay synchronous — use create_tables() (or alembic) instead.
        """
        self.config = get_config()
        self.engine = self._create_engine()
        self.SessionLocal = async_sessionmaker(
            bind=self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    def _create_engine(self):
        """Create the async database engine.

        Returns:
            An asyncpg-backed SQLAlchemy async engine.

        Raises:
            ValueError: If the configured database type is not PostgreSQL,
                or no database URL is configured.
        """
        db_type = self.config.get("database.type", "postgresql")

        if db_type != "postgresql":
            raise ValueError(f"Unsupported database type: {db_type}. Only 'postgresql' is supported.")

        db_url = self.config.get("database.url")
        if not db_url:
            raise ValueError("PostgreSQL URL not configured")

        # Rewrite only the URL *scheme* to the async driver.  A blanket
        # str.replace() would also mangle any later occurrence of the
        # substring (e.g. inside a query parameter), so the prefix is
        # swapped explicitly.
        sync_scheme = "postgresql://"
        if db_url.startswith(sync_scheme):
            db_url = "postgresql+asyncpg://" + db_url[len(sync_scheme):]

        return create_async_engine(
            db_url,
            echo=False,
            connect_args={
                "server_settings": {"application_name": "crypto_trader"},
                "timeout": 5,  # asyncpg connection timeout (seconds)
            },
            pool_pre_ping=True,   # Verify connections before using
            pool_recycle=3600,    # Recycle connections after 1 hour
            pool_timeout=5,       # Timeout when getting connection from pool
        )

    async def create_tables(self):
        """Create all database tables declared on Base."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    def get_session(self) -> AsyncSession:
        """Get a new database session from the session factory."""
        return self.SessionLocal()

    async def close(self):
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()
|
||||
|
||||
|
||||
# Global database instance
|
||||
# Module-level singleton holding the lazily-created Database.
_db_instance: Optional[Database] = None


def get_database() -> Database:
    """Get the global database instance, creating it on first use.

    NOTE(review): the check-then-create is not locked — confirm first use
    happens from a single thread/task.
    """
    global _db_instance
    if _db_instance is None:
        _db_instance = Database()
    return _db_instance
|
||||
|
||||
128
src/core/logger.py
Normal file
128
src/core/logger.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Configurable logging system with XDG directory support."""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class LoggingConfig:
    """Logging configuration manager.

    Configures the root logger with a console handler and a daily-rotating
    file handler, and gzip-compresses previously rotated log files.
    """

    def __init__(self):
        """Initialize logging configuration from the application config."""
        self.config = get_config()
        self.log_dir = Path(self.config.get("logging.dir", "~/.local/share/crypto_trader/logs")).expanduser()
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # Number of rotated daily files the file handler keeps.
        self.retention_days = self.config.get("logging.retention_days", 30)
        self._setup_logging()

    def _setup_logging(self):
        """Set up root logger handlers (console + rotating file)."""
        log_level = self.config.get("logging.level", "INFO")
        level = getattr(logging, log_level.upper(), logging.INFO)

        root_logger = logging.getLogger()
        root_logger.setLevel(level)

        # Clear existing handlers so repeated setup does not duplicate output.
        root_logger.handlers.clear()

        # Console handler honours the configured level.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        console_formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        console_handler.setFormatter(console_formatter)
        root_logger.addHandler(console_handler)

        # File handler rotates at midnight and keeps retention_days backups.
        # It is pinned to DEBUG so the file always retains full detail
        # regardless of the console level.
        log_file = self.log_dir / "crypto_trader.log"
        file_handler = logging.handlers.TimedRotatingFileHandler(
            log_file,
            when='midnight',
            interval=1,
            backupCount=self.retention_days,
            encoding='utf-8'
        )
        file_handler.setLevel(logging.DEBUG)
        file_formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(name)s:%(lineno)d: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_formatter)
        root_logger.addHandler(file_handler)

        # Compress previously rotated logs.
        self._setup_log_compression()

    def _setup_log_compression(self):
        """Gzip-compress rotated log files that are not compressed yet.

        Best-effort: files that cannot be read, written, or removed are
        skipped.  Note this compresses every rotated file it finds, not
        only those older than the retention period.
        """
        import gzip  # Local import: only needed during setup.

        for log_file in self.log_dir.glob("crypto_trader.log.*"):
            if log_file.name.endswith('.gz'):
                continue
            try:
                with open(log_file, 'rb') as f_in:
                    with gzip.open(f"{log_file}.gz", 'wb') as f_out:
                        f_out.writelines(f_in)
                log_file.unlink()
            except OSError:
                # Skip files we cannot compress or delete (permissions,
                # concurrent rotation, etc.).
                pass

    def get_logger(self, name: str) -> logging.Logger:
        """Get a logger with the specified name.

        Args:
            name: Logger name (typically module name)

        Returns:
            Logger instance
        """
        return logging.getLogger(name)

    def set_level(self, level: str):
        """Set logging level on the root logger and all its handlers.

        Args:
            level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL);
                unknown names fall back to INFO.
        """
        log_level = getattr(logging, level.upper(), logging.INFO)
        root = logging.getLogger()
        root.setLevel(log_level)
        for handler in root.handlers:
            handler.setLevel(log_level)
|
||||
|
||||
|
||||
# Global logging config instance
|
||||
# Global logging config instance, created lazily on first get_logger() call.
_logging_config: Optional[LoggingConfig] = None


def get_logger(name: str) -> logging.Logger:
    """Get a logger instance.

    Lazily initializes the global LoggingConfig on first use, so merely
    importing this module configures nothing.

    Args:
        name: Logger name (typically __name__)

    Returns:
        Logger instance
    """
    global _logging_config
    if _logging_config is None:
        _logging_config = LoggingConfig()
    return _logging_config.get_logger(name)
|
||||
|
||||
|
||||
def setup_logging():
    """Set up (or re-initialize) the global logging system.

    Unconditionally replaces the module-level LoggingConfig, re-running
    handler setup.
    """
    global _logging_config
    _logging_config = LoggingConfig()
|
||||
|
||||
261
src/core/pubsub.py
Normal file
261
src/core/pubsub.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Redis Pub/Sub for real-time event broadcasting across workers."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable, Dict, Any, Optional, List
|
||||
from src.core.redis import get_redis_client
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)


# Channel names.  All channels share the "crypto_trader:" prefix so several
# applications can use the same Redis instance without channel collisions.
CHANNEL_MARKET_EVENTS = "crypto_trader:market_events"
CHANNEL_TRADE_EVENTS = "crypto_trader:trade_events"
CHANNEL_SYSTEM_EVENTS = "crypto_trader:system_events"
CHANNEL_AUTOPILOT_EVENTS = "crypto_trader:autopilot_events"
|
||||
|
||||
|
||||
class RedisPubSub:
    """Redis Pub/Sub handler for real-time event broadcasting.

    Keeps a local registry of per-channel callbacks and runs a single
    background task that relays messages from Redis to those callbacks.
    """

    def __init__(self):
        """Initialize Redis Pub/Sub."""
        self.redis = get_redis_client()
        # Channel name -> callbacks (sync or async) invoked per message.
        self._subscribers: Dict[str, List[Callable]] = {}
        self._pubsub = None
        self._running = False
        self._listen_task: Optional[asyncio.Task] = None

    async def publish(self, channel: str, event_type: str, data: Dict[str, Any]) -> int:
        """Publish an event to a channel.

        Args:
            channel: Channel name
            event_type: Type of event (e.g., 'price_update', 'trade_executed')
            data: Event data

        Returns:
            Number of subscribers that received the message, or 0 on failure.
        """
        message = {
            "type": event_type,
            "data": data,
            # NOTE(review): loop.time() is a monotonic clock, not wall-clock;
            # the value is not meaningful across processes. Kept for
            # compatibility — confirm whether time.time() was intended.
            "timestamp": asyncio.get_running_loop().time()
        }

        try:
            client = self.redis.get_client()
            count = await client.publish(channel, json.dumps(message))
            logger.debug(f"Published {event_type} to {channel} ({count} subscribers)")
            return count
        except Exception as e:
            logger.error(f"Failed to publish to {channel}: {e}")
            return 0

    async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """Subscribe to a channel.

        Args:
            channel: Channel name
            callback: Sync or async function to call when message received
        """
        self._subscribers.setdefault(channel, []).append(callback)

        logger.info(f"Subscribed to channel: {channel}")

        if not self._running:
            # First subscription: spin up the background listener.
            await self._start_listener()
        elif self._pubsub is not None:
            # Bug fix: the listener only subscribes to channels known at
            # startup; channels registered later must be added explicitly,
            # otherwise their messages are never delivered.
            await self._pubsub.subscribe(channel)

    async def unsubscribe(self, channel: str, callback: Optional[Callable] = None) -> None:
        """Unsubscribe from a channel.

        Args:
            channel: Channel name
            callback: Specific callback to remove, or None to remove all
        """
        if channel in self._subscribers:
            if callback:
                self._subscribers[channel] = [c for c in self._subscribers[channel] if c != callback]
            else:
                del self._subscribers[channel]

            logger.info(f"Unsubscribed from channel: {channel}")

            # Bug fix: drop the Redis-level subscription once no callbacks
            # remain, so the listener stops receiving this channel.
            if not self._subscribers.get(channel):
                self._subscribers.pop(channel, None)
                if self._pubsub is not None:
                    await self._pubsub.unsubscribe(channel)

    async def _start_listener(self) -> None:
        """Start the Pub/Sub listener task (idempotent)."""
        if self._running:
            return

        self._running = True
        self._listen_task = asyncio.create_task(self._listen())
        logger.info("Started Redis Pub/Sub listener")

    async def _listen(self) -> None:
        """Listen for messages on subscribed channels and dispatch them."""
        try:
            client = self.redis.get_client()
            self._pubsub = client.pubsub()

            # Subscribe to channels registered before startup; later
            # additions are handled in subscribe().
            channels = list(self._subscribers.keys())
            if channels:
                await self._pubsub.subscribe(*channels)

            while self._running:
                try:
                    message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

                    if message and message['type'] == 'message':
                        channel = message['channel']
                        if isinstance(channel, bytes):
                            channel = channel.decode('utf-8')

                        data = message['data']
                        if isinstance(data, bytes):
                            data = data.decode('utf-8')

                        try:
                            parsed = json.loads(data)
                        except json.JSONDecodeError:
                            # Non-JSON payloads are passed through raw.
                            parsed = {"raw": data}

                        # Call all subscribers for this channel.
                        for callback in self._subscribers.get(channel, []):
                            try:
                                if asyncio.iscoroutinefunction(callback):
                                    await callback(parsed)
                                else:
                                    callback(parsed)
                            except Exception as e:
                                # One failing subscriber must not break the rest.
                                logger.error(f"Subscriber callback error: {e}")

                    await asyncio.sleep(0.01)  # Prevent busy loop

                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Pub/Sub listener error: {e}")
                    await asyncio.sleep(1)

        finally:
            if self._pubsub:
                await self._pubsub.close()
            self._running = False

    async def stop(self) -> None:
        """Stop the Pub/Sub listener and wait for the task to finish."""
        self._running = False
        if self._listen_task:
            self._listen_task.cancel()
            try:
                await self._listen_task
            except asyncio.CancelledError:
                pass
        logger.info("Stopped Redis Pub/Sub listener")

    # Convenience methods for common event types

    async def publish_price_update(
        self,
        symbol: str,
        price: float,
        bid: Optional[float] = None,
        ask: Optional[float] = None
    ) -> int:
        """Publish a price update event.

        Args:
            symbol: Trading symbol
            price: Current price
            bid: Bid price
            ask: Ask price

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_MARKET_EVENTS, "price_update", {
            "symbol": symbol,
            "price": price,
            "bid": bid,
            "ask": ask
        })

    async def publish_trade_executed(
        self,
        symbol: str,
        side: str,
        quantity: float,
        price: float,
        order_id: Optional[str] = None
    ) -> int:
        """Publish a trade execution event.

        Args:
            symbol: Trading symbol
            side: 'buy' or 'sell'
            quantity: Trade quantity
            price: Execution price
            order_id: Order ID

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_TRADE_EVENTS, "trade_executed", {
            "symbol": symbol,
            "side": side,
            "quantity": quantity,
            "price": price,
            "order_id": order_id
        })

    async def publish_autopilot_status(
        self,
        symbol: str,
        status: str,
        action: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None
    ) -> int:
        """Publish an autopilot status event.

        Args:
            symbol: Trading symbol
            status: 'started', 'stopped', 'error', 'signal'
            action: Optional action taken
            details: Additional details

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_AUTOPILOT_EVENTS, "autopilot_status", {
            "symbol": symbol,
            "status": status,
            "action": action,
            "details": details or {}
        })

    async def publish_system_event(self, event_type: str, message: str, severity: str = "info") -> int:
        """Publish a system event.

        Args:
            event_type: Event type (e.g., 'startup', 'shutdown', 'error')
            message: Event message
            severity: 'info', 'warning', 'error'

        Returns:
            Number of subscribers notified
        """
        return await self.publish(CHANNEL_SYSTEM_EVENTS, event_type, {
            "message": message,
            "severity": severity
        })
|
||||
|
||||
|
||||
# Global Pub/Sub instance (lazily created singleton)
_redis_pubsub: Optional[RedisPubSub] = None


def get_redis_pubsub() -> RedisPubSub:
    """Return the process-wide Redis Pub/Sub instance, creating it on first use."""
    global _redis_pubsub
    instance = _redis_pubsub
    if instance is None:
        instance = _redis_pubsub = RedisPubSub()
    return instance
|
||||
70
src/core/redis.py
Normal file
70
src/core/redis.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Redis client wrapper."""
|
||||
|
||||
import redis.asyncio as redis
|
||||
from typing import Optional
|
||||
from src.core.config import get_config
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisClient:
    """Redis client wrapper with automatic connection handling.

    The underlying connection pool and client are created lazily on the
    first call to :meth:`get_client`.
    """

    def __init__(self):
        """Initialize the wrapper without connecting yet."""
        self.config = get_config()
        self._client: Optional[redis.Redis] = None
        self._pool: Optional[redis.ConnectionPool] = None

    def get_client(self) -> redis.Redis:
        """Get or create Redis client.

        Returns:
            Async Redis client backed by a shared connection pool.
        """
        if self._client is None:
            self._connect()
        return self._client

    def _connect(self):
        """Build the connection pool and client from the 'redis' config section."""
        cfg = self.config.get("redis", {})
        host = cfg.get("host", "localhost")
        port = cfg.get("port", 6379)
        db = cfg.get("db", 0)

        logger.info(f"Connecting to Redis at {host}:{port}/{db}")

        try:
            self._pool = redis.ConnectionPool(
                host=host,
                port=port,
                db=db,
                password=cfg.get("password"),
                decode_responses=True,
                socket_connect_timeout=cfg.get("socket_connect_timeout", 5),
            )
            self._client = redis.Redis(connection_pool=self._pool)
        except Exception as e:
            logger.error(f"Failed to create Redis client: {e}")
            raise

    async def close(self):
        """Close Redis connection."""
        if self._client:
            await self._client.aclose()
            logger.info("Redis connection closed")
|
||||
|
||||
|
||||
# Global instance (lazily created singleton)
_redis_client: Optional[RedisClient] = None


def get_redis_client() -> RedisClient:
    """Return the module-level Redis client, instantiating it on first call."""
    global _redis_client
    client = _redis_client
    if client is None:
        client = _redis_client = RedisClient()
    return client
|
||||
99
src/core/repositories.py
Normal file
99
src/core/repositories.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Database repositories for data access."""
|
||||
|
||||
from typing import Optional, List, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from .database import Order, Position, OrderStatus, OrderSide, OrderType, MarketData
|
||||
|
||||
class BaseRepository:
    """Base repository holding the shared async database session."""

    def __init__(self, session: AsyncSession):
        """Initialize repository.

        Args:
            session: Async SQLAlchemy session used for all queries.
        """
        self.session = session
|
||||
|
||||
class OrderRepository(BaseRepository):
    """Repository for Order rows: creation, lookup, paging, status updates."""

    async def create(self, order: Order) -> Order:
        """Persist a new order and return it with generated fields populated."""
        self.session.add(order)
        await self.session.commit()
        await self.session.refresh(order)
        return order

    async def get_by_id(self, order_id: int) -> Optional[Order]:
        """Fetch a single order by primary key, or None if absent."""
        stmt = select(Order).where(Order.id == order_id)
        res = await self.session.execute(stmt)
        return res.scalar_one_or_none()

    async def get_all(self, limit: int = 100, offset: int = 0) -> Sequence[Order]:
        """Return a page of orders, newest first."""
        stmt = (
            select(Order)
            .limit(limit)
            .offset(offset)
            .order_by(Order.created_at.desc())
        )
        res = await self.session.execute(stmt)
        return res.scalars().all()

    async def update_status(
        self,
        order_id: int,
        status: OrderStatus,
        exchange_order_id: Optional[str] = None,
        fee: Optional[Decimal] = None
    ) -> Optional[Order]:
        """Update an order's status (and optionally exchange id / fee).

        Returns:
            The refreshed order, or None if no order has that id.
        """
        changes = {"status": status, "updated_at": datetime.utcnow()}
        if exchange_order_id:
            changes["exchange_order_id"] = exchange_order_id
        if fee is not None:
            changes["fee"] = fee

        stmt = update(Order).where(Order.id == order_id).values(**changes)
        await self.session.execute(stmt)
        await self.session.commit()
        return await self.get_by_id(order_id)

    async def get_open_orders(self, paper_trading: bool = True) -> Sequence[Order]:
        """Return orders still awaiting (full or partial) execution."""
        open_states = [OrderStatus.PENDING, OrderStatus.OPEN, OrderStatus.PARTIALLY_FILLED]
        stmt = select(Order).where(
            Order.paper_trading == paper_trading,
            Order.status.in_(open_states),
        )
        res = await self.session.execute(stmt)
        return res.scalars().all()

    async def delete(self, order_id: int) -> bool:
        """Delete an order; True if a row was actually removed."""
        res = await self.session.execute(delete(Order).where(Order.id == order_id))
        await self.session.commit()
        return res.rowcount > 0
|
||||
|
||||
class PositionRepository(BaseRepository):
    """Repository for Position rows."""

    async def get_all(self, paper_trading: bool = True) -> Sequence[Position]:
        """Return every position for the given trading mode."""
        stmt = select(Position).where(Position.paper_trading == paper_trading)
        res = await self.session.execute(stmt)
        return res.scalars().all()

    async def get_by_symbol(self, symbol: str, paper_trading: bool = True) -> Optional[Position]:
        """Return the position for a symbol in the given mode, or None."""
        stmt = select(Position).where(
            Position.symbol == symbol,
            Position.paper_trading == paper_trading,
        )
        res = await self.session.execute(stmt)
        return res.scalar_one_or_none()
|
||||
Reference in New Issue
Block a user