Local changes: Updated model training, removed debug instrumentation, and configuration improvements

This commit is contained in:
kfox
2025-12-26 01:15:43 -05:00
commit cc60da49e7
388 changed files with 57127 additions and 0 deletions

0
src/core/__init__.py Normal file
View File

256
src/core/config.py Normal file
View File

@@ -0,0 +1,256 @@
"""Configuration management system with YAML and environment variables."""
import os
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
class Config:
"""Configuration manager with XDG directory support."""
def __init__(self, config_file: Optional[str] = None):
"""Initialize configuration manager.
Args:
config_file: Optional path to config file. If None, uses XDG default.
"""
self._setup_xdg_directories()
# Determine config file priority:
# 1. Explicit argument
# 2. Local project config (dev mode)
# 3. XDG config (user mode)
local_config = Path(__file__).parent.parent.parent / "config" / "config.yaml"
if config_file:
self.config_file = Path(config_file)
elif local_config.exists():
self.config_file = local_config
else:
self.config_file = self.config_dir / "config.yaml"
self._config: Dict[str, Any] = {}
self._load_config()
def _setup_xdg_directories(self):
"""Set up XDG Base Directory Specification directories."""
home = Path.home()
# XDG_CONFIG_HOME or default
xdg_config = os.getenv("XDG_CONFIG_HOME", home / ".config")
self.config_dir = Path(xdg_config) / "crypto_trader"
self.config_dir.mkdir(parents=True, exist_ok=True)
# XDG_DATA_HOME or default
xdg_data = os.getenv("XDG_DATA_HOME", home / ".local" / "share")
self.data_dir = Path(xdg_data) / "crypto_trader"
self.data_dir.mkdir(parents=True, exist_ok=True)
# Create subdirectories
(self.data_dir / "historical").mkdir(exist_ok=True)
(self.data_dir / "backups").mkdir(exist_ok=True)
(self.data_dir / "logs").mkdir(exist_ok=True)
# XDG_CACHE_HOME or default
xdg_cache = os.getenv("XDG_CACHE_HOME", home / ".cache")
self.cache_dir = Path(xdg_cache) / "crypto_trader"
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _load_config(self):
"""Load configuration from YAML file and environment variables."""
# Load defaults
default_config = self._get_default_config()
self._config = default_config.copy()
# Load from file if it exists
if self.config_file.exists():
with open(self.config_file, 'r') as f:
file_config = yaml.safe_load(f) or {}
self._config.update(file_config)
# Override with environment variables
self._load_from_env()
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration."""
return {
"app": {
"name": "Crypto Trader",
"version": "0.1.0",
},
"database": {
"type": "postgresql",
"url": None, # For PostgreSQL
},
"logging": {
"level": os.getenv("LOG_LEVEL", "INFO"),
"dir": str(self.data_dir / "logs"),
"retention_days": 30,
"rotation": "daily",
},
"paper_trading": {
"enabled": True,
"default_capital": float(os.getenv("PAPER_TRADING_CAPITAL", "100.0")),
},
"updates": {
"check_on_startup": os.getenv("UPDATE_CHECK_ON_STARTUP", "true").lower() == "true",
"repository_url": os.getenv("UPDATE_REPOSITORY_URL", ""),
},
"exchanges": {},
"strategies": {
"default_timeframe": "1h",
},
"risk": {
"max_drawdown_percent": 20.0,
"daily_loss_limit_percent": 5.0,
"position_size_percent": 2.0,
},
"trading": {
"default_fees": {
"maker": 0.001, # 0.1%
"taker": 0.001, # 0.1%
"minimum": 0.0,
},
"exchanges": {},
},
"data_providers": {
"primary": [
{"name": "kraken", "enabled": True, "priority": 1},
{"name": "coinbase", "enabled": True, "priority": 2},
{"name": "binance", "enabled": True, "priority": 3},
],
"fallback": {
"name": "coingecko",
"enabled": True,
"api_key": "",
},
"caching": {
"ticker_ttl": 2, # seconds
"ohlcv_ttl": 60, # seconds
"max_cache_size": 1000,
},
"websocket": {
"enabled": True,
"reconnect_interval": 5, # seconds
"ping_interval": 30, # seconds
},
},
"redis": {
"host": os.getenv("REDIS_HOST", "127.0.0.1"),
"port": int(os.getenv("REDIS_PORT", 6379)),
"db": int(os.getenv("REDIS_DB", 0)),
"password": os.getenv("REDIS_PASSWORD", None),
"socket_connect_timeout": 5,
},
"celery": {
"broker_url": os.getenv("CELERY_BROKER_URL", "redis://127.0.0.1:6379/0"),
"result_backend": os.getenv("CELERY_RESULT_BACKEND", "redis://127.0.0.1:6379/0"),
},
}
def _load_from_env(self):
"""Load configuration from environment variables."""
# Database
if db_url := os.getenv("DATABASE_URL"):
self._config["database"]["url"] = db_url
self._config["database"]["type"] = "postgresql"
# Logging
if log_level := os.getenv("LOG_LEVEL"):
self._config["logging"]["level"] = log_level
if log_dir := os.getenv("LOG_DIR"):
self._config["logging"]["dir"] = log_dir
# Paper trading
if capital := os.getenv("PAPER_TRADING_CAPITAL"):
self._config["paper_trading"]["default_capital"] = float(capital)
def get(self, key: str, default: Any = None) -> Any:
"""Get configuration value using dot notation.
Args:
key: Configuration key (e.g., "database.path")
default: Default value if key not found
Returns:
Configuration value or default
"""
keys = key.split(".")
value = self._config
for k in keys:
if isinstance(value, dict):
value = value.get(k)
if value is None:
return default
else:
return default
return value
def set(self, key: str, value: Any):
"""Set configuration value using dot notation.
Args:
key: Configuration key (e.g., "database.path")
value: Value to set
"""
keys = key.split(".")
config = self._config
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
config[keys[-1]] = value
def save(self):
"""Save configuration to file."""
with open(self.config_file, 'w') as f:
yaml.dump(self._config, f, default_flow_style=False, sort_keys=False)
@property
def config_dir(self) -> Path:
"""Get config directory path."""
return self._config_dir
@config_dir.setter
def config_dir(self, value: Path):
"""Set config directory."""
self._config_dir = value
@property
def data_dir(self) -> Path:
"""Get data directory path."""
return self._data_dir
@data_dir.setter
def data_dir(self, value: Path):
"""Set data directory."""
self._data_dir = value
@property
def cache_dir(self) -> Path:
"""Get cache directory path."""
return self._cache_dir
@cache_dir.setter
def cache_dir(self, value: Path):
"""Set cache directory."""
self._cache_dir = value
# Global config instance
_config_instance: Optional[Config] = None
def get_config() -> Config:
"""Get global configuration instance."""
global _config_instance
if _config_instance is None:
_config_instance = Config()
return _config_instance

416
src/core/database.py Normal file
View File

@@ -0,0 +1,416 @@
"""Database connection and models using SQLAlchemy."""
from datetime import datetime
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Optional
from sqlalchemy import (
create_engine, Column, Integer, String, Float, Boolean, DateTime,
Text, ForeignKey, JSON, Enum as SQLEnum, Numeric
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship, Session
from .config import get_config
Base = declarative_base()
class OrderType(str, Enum):
"""Order type enumeration."""
MARKET = "market"
LIMIT = "limit"
STOP_LOSS = "stop_loss"
TAKE_PROFIT = "take_profit"
TRAILING_STOP = "trailing_stop"
OCO = "oco"
ICEBERG = "iceberg"
class OrderSide(str, Enum):
"""Order side enumeration."""
BUY = "buy"
SELL = "sell"
class OrderStatus(str, Enum):
"""Order status enumeration."""
PENDING = "pending"
OPEN = "open"
PARTIALLY_FILLED = "partially_filled"
FILLED = "filled"
CANCELLED = "cancelled"
REJECTED = "rejected"
EXPIRED = "expired"
class TradeType(str, Enum):
"""Trade type enumeration."""
SPOT = "spot"
FUTURES = "futures"
MARGIN = "margin"
class Exchange(Base):
"""Exchange configuration and credentials."""
__tablename__ = "exchanges"
id = Column(Integer, primary_key=True)
name = Column(String(50), nullable=False, unique=True)
api_key_encrypted = Column(Text) # Encrypted API key
api_secret_encrypted = Column(Text) # Encrypted API secret
sandbox = Column(Boolean, default=False)
read_only = Column(Boolean, default=True) # Read-only mode
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
trades = relationship("Trade", back_populates="exchange")
orders = relationship("Order", back_populates="exchange")
positions = relationship("Position", back_populates="exchange")
class Strategy(Base):
"""Strategy definitions and parameters."""
__tablename__ = "strategies"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
description = Column(Text)
strategy_type = Column(String(50)) # technical, momentum, grid, dca, etc.
class_name = Column(String(100)) # Python class name
parameters = Column(JSON) # Strategy parameters
timeframes = Column(JSON) # Multi-timeframe configuration
enabled = Column(Boolean, default=False) # Available to Autopilot
running = Column(Boolean, default=False) # Currently running manually
paper_trading = Column(Boolean, default=True)
version = Column(String(20), default="1.0.0")
schedule = Column(JSON) # Scheduling configuration
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
trades = relationship("Trade", back_populates="strategy")
backtest_results = relationship("BacktestResult", back_populates="strategy")
class Order(Base):
"""Order history with state tracking."""
__tablename__ = "orders"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
exchange_order_id = Column(String(100)) # Exchange's order ID
symbol = Column(String(20), nullable=False)
order_type = Column(SQLEnum(OrderType), nullable=False)
side = Column(SQLEnum(OrderSide), nullable=False)
status = Column(SQLEnum(OrderStatus), default=OrderStatus.PENDING)
quantity = Column(Numeric(20, 8), nullable=False)
price = Column(Numeric(20, 8)) # For limit orders
filled_quantity = Column(Numeric(20, 8), default=0)
average_fill_price = Column(Numeric(20, 8))
fee = Column(Numeric(20, 8), default=0)
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
leverage = Column(Integer, default=1) # For futures/margin
paper_trading = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
filled_at = Column(DateTime)
# Relationships
exchange = relationship("Exchange", back_populates="orders")
strategy = relationship("Strategy")
trades = relationship("Trade", back_populates="order")
class Trade(Base):
"""All executed trades (paper and live)."""
__tablename__ = "trades"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=True)
order_id = Column(Integer, ForeignKey("orders.id"), nullable=True)
symbol = Column(String(20), nullable=False)
side = Column(SQLEnum(OrderSide), nullable=False)
quantity = Column(Numeric(20, 8), nullable=False)
price = Column(Numeric(20, 8), nullable=False)
fee = Column(Numeric(20, 8), default=0)
total = Column(Numeric(20, 8), nullable=False) # quantity * price + fee
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
paper_trading = Column(Boolean, default=True)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
# Relationships
exchange = relationship("Exchange", back_populates="trades")
strategy = relationship("Strategy", back_populates="trades")
order = relationship("Order", back_populates="trades")
class Position(Base):
"""Current open positions (spot and futures)."""
__tablename__ = "positions"
id = Column(Integer, primary_key=True)
exchange_id = Column(Integer, ForeignKey("exchanges.id"), nullable=False)
symbol = Column(String(20), nullable=False)
side = Column(String(10)) # long, short
quantity = Column(Numeric(20, 8), nullable=False)
entry_price = Column(Numeric(20, 8), nullable=False)
current_price = Column(Numeric(20, 8))
unrealized_pnl = Column(Numeric(20, 8), default=0)
realized_pnl = Column(Numeric(20, 8), default=0)
trade_type = Column(SQLEnum(TradeType), default=TradeType.SPOT)
leverage = Column(Integer, default=1)
margin = Column(Numeric(20, 8))
paper_trading = Column(Boolean, default=True)
opened_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
exchange = relationship("Exchange", back_populates="positions")
class PortfolioSnapshot(Base):
"""Historical portfolio values."""
__tablename__ = "portfolio_snapshots"
id = Column(Integer, primary_key=True)
total_value = Column(Numeric(20, 8), nullable=False)
cash = Column(Numeric(20, 8), nullable=False)
positions_value = Column(Numeric(20, 8), nullable=False)
unrealized_pnl = Column(Numeric(20, 8), default=0)
realized_pnl = Column(Numeric(20, 8), default=0)
paper_trading = Column(Boolean, default=True)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MarketData(Base):
"""OHLCV historical data (multiple timeframes)."""
__tablename__ = "market_data"
id = Column(Integer, primary_key=True)
exchange = Column(String(50), nullable=False)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False) # 1m, 5m, 15m, 1h, 1d, etc.
timestamp = Column(DateTime, nullable=False, index=True)
open = Column(Numeric(20, 8), nullable=False)
high = Column(Numeric(20, 8), nullable=False)
low = Column(Numeric(20, 8), nullable=False)
close = Column(Numeric(20, 8), nullable=False)
volume = Column(Numeric(20, 8), nullable=False)
class BacktestResult(Base):
"""Backtesting results and metrics."""
__tablename__ = "backtest_results"
id = Column(Integer, primary_key=True)
strategy_id = Column(Integer, ForeignKey("strategies.id"), nullable=False)
start_date = Column(DateTime, nullable=False)
end_date = Column(DateTime, nullable=False)
initial_capital = Column(Numeric(20, 8), nullable=False)
final_capital = Column(Numeric(20, 8), nullable=False)
total_return = Column(Numeric(10, 4)) # Percentage
sharpe_ratio = Column(Numeric(10, 4))
sortino_ratio = Column(Numeric(10, 4))
max_drawdown = Column(Numeric(10, 4))
win_rate = Column(Numeric(10, 4))
total_trades = Column(Integer, default=0)
parameters = Column(JSON) # Parameters used in backtest
metrics = Column(JSON) # Additional metrics
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
strategy = relationship("Strategy", back_populates="backtest_results")
class RiskLimit(Base):
"""Risk management configuration."""
__tablename__ = "risk_limits"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
limit_type = Column(String(50)) # max_drawdown, daily_loss, position_size, etc.
value = Column(Numeric(10, 4), nullable=False)
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class Alert(Base):
"""Alert definitions and history."""
__tablename__ = "alerts"
id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
alert_type = Column(String(50)) # price, indicator, risk, system
condition = Column(JSON) # Alert condition configuration
enabled = Column(Boolean, default=True)
triggered = Column(Boolean, default=False)
triggered_at = Column(DateTime)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class RebalancingEvent(Base):
"""Portfolio rebalancing history."""
__tablename__ = "rebalancing_events"
id = Column(Integer, primary_key=True)
trigger_type = Column(String(50)) # time, threshold, manual
target_allocations = Column(JSON) # Target portfolio allocations
before_allocations = Column(JSON) # Allocations before rebalancing
after_allocations = Column(JSON) # Allocations after rebalancing
orders_placed = Column(JSON) # Orders placed for rebalancing
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
class AppState(Base):
"""Application state for recovery."""
__tablename__ = "app_state"
id = Column(Integer, primary_key=True)
key = Column(String(100), unique=True, nullable=False)
value = Column(JSON)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class AuditLog(Base):
"""Security and action audit trail."""
__tablename__ = "audit_log"
id = Column(Integer, primary_key=True)
action = Column(String(100), nullable=False)
entity_type = Column(String(50)) # exchange, strategy, order, etc.
entity_id = Column(Integer)
details = Column(JSON)
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MarketConditionsSnapshot(Base):
"""Market conditions snapshot for ML training."""
__tablename__ = "market_conditions_snapshot"
id = Column(Integer, primary_key=True)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False)
regime = Column(String(50)) # Market regime classification
features = Column(JSON) # Market condition features
strategy_name = Column(String(100)) # Strategy used at this time
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class StrategyPerformance(Base):
"""Strategy performance records for ML training."""
__tablename__ = "strategy_performance"
id = Column(Integer, primary_key=True)
strategy_name = Column(String(100), nullable=False, index=True)
symbol = Column(String(20), nullable=False)
timeframe = Column(String(10), nullable=False)
market_regime = Column(String(50), index=True) # Market regime when trade executed
return_pct = Column(Numeric(10, 4)) # Return percentage
sharpe_ratio = Column(Numeric(10, 4)) # Sharpe ratio
win_rate = Column(Numeric(5, 2)) # Win rate (0-100)
max_drawdown = Column(Numeric(10, 4)) # Maximum drawdown
trade_count = Column(Integer, default=1) # Number of trades in this period
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
class MLModelMetadata(Base):
"""ML model metadata and versions."""
__tablename__ = "ml_model_metadata"
id = Column(Integer, primary_key=True)
model_name = Column(String(100), nullable=False)
model_type = Column(String(50)) # classifier, regressor
version = Column(String(20))
file_path = Column(String(500)) # Path to saved model file
training_metrics = Column(JSON) # Training metrics (accuracy, MSE, etc.)
feature_names = Column(JSON) # List of feature names
strategy_names = Column(JSON) # List of strategy names
training_samples = Column(Integer) # Number of training samples
trained_at = Column(DateTime, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
class Database:
"""Database connection manager."""
def __init__(self):
"""Initialize database connection."""
self.config = get_config()
self.engine = self._create_engine()
self.SessionLocal = async_sessionmaker(
bind=self.engine,
class_=AsyncSession,
expire_on_commit=False
)
# self._create_tables() # Tables should be created via alembic or separate init script in async
def _create_engine(self):
"""Create database engine."""
db_type = self.config.get("database.type", "postgresql")
if db_type == "postgresql":
db_url = self.config.get("database.url")
if not db_url:
raise ValueError("PostgreSQL URL not configured")
# Ensure URL uses async driver (e.g. postgresql+asyncpg)
if "postgresql://" in db_url and "postgresql+asyncpg://" not in db_url:
# This is a naive replacement, in production we should handle this better
db_url = db_url.replace("postgresql://", "postgresql+asyncpg://")
# Add connection timeout to prevent hanging
# asyncpg connect timeout is set via connect_timeout in connect_args
return create_async_engine(
db_url,
echo=False,
connect_args={
"server_settings": {"application_name": "crypto_trader"},
"timeout": 5, # 5 second connection timeout
},
pool_pre_ping=True, # Verify connections before using
pool_recycle=3600, # Recycle connections after 1 hour
pool_timeout=5, # Timeout when getting connection from pool
)
else:
raise ValueError(f"Unsupported database type: {db_type}. Only 'postgresql' is supported.")
async def create_tables(self):
"""Create all database tables."""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
def get_session(self) -> AsyncSession:
"""Get a database session."""
return self.SessionLocal()
async def close(self):
"""Close database connection."""
await self.engine.dispose()
# Global database instance
_db_instance: Optional[Database] = None
def get_database() -> Database:
"""Get global database instance."""
global _db_instance
if _db_instance is None:
_db_instance = Database()
return _db_instance

128
src/core/logger.py Normal file
View File

@@ -0,0 +1,128 @@
"""Configurable logging system with XDG directory support."""
import logging
import logging.handlers
import yaml
from pathlib import Path
from typing import Optional
from .config import get_config
class LoggingConfig:
"""Logging configuration manager."""
def __init__(self):
"""Initialize logging configuration."""
self.config = get_config()
self.log_dir = Path(self.config.get("logging.dir", "~/.local/share/crypto_trader/logs")).expanduser()
self.log_dir.mkdir(parents=True, exist_ok=True)
self.retention_days = self.config.get("logging.retention_days", 30)
self._setup_logging()
def _setup_logging(self):
"""Set up logging configuration."""
log_level = self.config.get("logging.level", "INFO")
level = getattr(logging, log_level.upper(), logging.INFO)
# Root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
# Clear existing handlers
root_logger.handlers.clear()
# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
console_formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# File handler with rotation
log_file = self.log_dir / "crypto_trader.log"
file_handler = logging.handlers.TimedRotatingFileHandler(
log_file,
when='midnight',
interval=1,
backupCount=self.retention_days,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s:%(lineno)d: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
root_logger.addHandler(file_handler)
# Compress old logs
self._setup_log_compression()
def _setup_log_compression(self):
"""Set up log compression for old log files."""
import gzip
import glob
# Compress logs older than retention period
log_files = list(self.log_dir.glob("crypto_trader.log.*"))
for log_file in log_files:
if not log_file.name.endswith('.gz'):
try:
with open(log_file, 'rb') as f_in:
with gzip.open(f"{log_file}.gz", 'wb') as f_out:
f_out.writelines(f_in)
log_file.unlink()
except Exception:
pass # Skip if compression fails
def get_logger(self, name: str) -> logging.Logger:
"""Get a logger with the specified name.
Args:
name: Logger name (typically module name)
Returns:
Logger instance
"""
logger = logging.getLogger(name)
return logger
def set_level(self, level: str):
"""Set logging level.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
"""
log_level = getattr(logging, level.upper(), logging.INFO)
logging.getLogger().setLevel(log_level)
for handler in logging.getLogger().handlers:
handler.setLevel(log_level)
# Global logging config instance
_logging_config: Optional[LoggingConfig] = None
def get_logger(name: str) -> logging.Logger:
"""Get a logger instance.
Args:
name: Logger name (typically __name__)
Returns:
Logger instance
"""
global _logging_config
if _logging_config is None:
_logging_config = LoggingConfig()
return _logging_config.get_logger(name)
def setup_logging():
"""Set up logging system."""
global _logging_config
_logging_config = LoggingConfig()

261
src/core/pubsub.py Normal file
View File

@@ -0,0 +1,261 @@
"""Redis Pub/Sub for real-time event broadcasting across workers."""
import asyncio
import json
from typing import Callable, Dict, Any, Optional, List
from src.core.redis import get_redis_client
from src.core.logger import get_logger
logger = get_logger(__name__)
# Channel names
CHANNEL_MARKET_EVENTS = "crypto_trader:market_events"
CHANNEL_TRADE_EVENTS = "crypto_trader:trade_events"
CHANNEL_SYSTEM_EVENTS = "crypto_trader:system_events"
CHANNEL_AUTOPILOT_EVENTS = "crypto_trader:autopilot_events"
class RedisPubSub:
"""Redis Pub/Sub handler for real-time event broadcasting."""
def __init__(self):
"""Initialize Redis Pub/Sub."""
self.redis = get_redis_client()
self._subscribers: Dict[str, List[Callable]] = {}
self._pubsub = None
self._running = False
self._listen_task: Optional[asyncio.Task] = None
async def publish(self, channel: str, event_type: str, data: Dict[str, Any]) -> int:
"""Publish an event to a channel.
Args:
channel: Channel name
event_type: Type of event (e.g., 'price_update', 'trade_executed')
data: Event data
Returns:
Number of subscribers that received the message
"""
message = {
"type": event_type,
"data": data,
"timestamp": asyncio.get_event_loop().time()
}
try:
client = self.redis.get_client()
count = await client.publish(channel, json.dumps(message))
logger.debug(f"Published {event_type} to {channel} ({count} subscribers)")
return count
except Exception as e:
logger.error(f"Failed to publish to {channel}: {e}")
return 0
async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], None]) -> None:
"""Subscribe to a channel.
Args:
channel: Channel name
callback: Async function to call when message received
"""
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(callback)
logger.info(f"Subscribed to channel: {channel}")
# Start listener if not running
if not self._running:
await self._start_listener()
async def unsubscribe(self, channel: str, callback: Callable = None) -> None:
"""Unsubscribe from a channel.
Args:
channel: Channel name
callback: Specific callback to remove, or None to remove all
"""
if channel in self._subscribers:
if callback:
self._subscribers[channel] = [c for c in self._subscribers[channel] if c != callback]
else:
del self._subscribers[channel]
logger.info(f"Unsubscribed from channel: {channel}")
async def _start_listener(self) -> None:
"""Start the Pub/Sub listener."""
if self._running:
return
self._running = True
self._listen_task = asyncio.create_task(self._listen())
logger.info("Started Redis Pub/Sub listener")
async def _listen(self) -> None:
"""Listen for messages on subscribed channels."""
try:
client = self.redis.get_client()
self._pubsub = client.pubsub()
# Subscribe to all registered channels
channels = list(self._subscribers.keys())
if channels:
await self._pubsub.subscribe(*channels)
while self._running:
try:
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message and message['type'] == 'message':
channel = message['channel']
if isinstance(channel, bytes):
channel = channel.decode('utf-8')
data = message['data']
if isinstance(data, bytes):
data = data.decode('utf-8')
try:
parsed = json.loads(data)
except json.JSONDecodeError:
parsed = {"raw": data}
# Call all subscribers for this channel
callbacks = self._subscribers.get(channel, [])
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(parsed)
else:
callback(parsed)
except Exception as e:
logger.error(f"Subscriber callback error: {e}")
await asyncio.sleep(0.01) # Prevent busy loop
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Pub/Sub listener error: {e}")
await asyncio.sleep(1)
finally:
if self._pubsub:
await self._pubsub.close()
self._running = False
async def stop(self) -> None:
"""Stop the Pub/Sub listener."""
self._running = False
if self._listen_task:
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
logger.info("Stopped Redis Pub/Sub listener")
# Convenience methods for common event types
async def publish_price_update(self, symbol: str, price: float, bid: float = None, ask: float = None) -> int:
"""Publish a price update event.
Args:
symbol: Trading symbol
price: Current price
bid: Bid price
ask: Ask price
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_MARKET_EVENTS, "price_update", {
"symbol": symbol,
"price": price,
"bid": bid,
"ask": ask
})
async def publish_trade_executed(
self,
symbol: str,
side: str,
quantity: float,
price: float,
order_id: str = None
) -> int:
"""Publish a trade execution event.
Args:
symbol: Trading symbol
side: 'buy' or 'sell'
quantity: Trade quantity
price: Execution price
order_id: Order ID
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_TRADE_EVENTS, "trade_executed", {
"symbol": symbol,
"side": side,
"quantity": quantity,
"price": price,
"order_id": order_id
})
async def publish_autopilot_status(
self,
symbol: str,
status: str,
action: str = None,
details: Dict[str, Any] = None
) -> int:
"""Publish an autopilot status event.
Args:
symbol: Trading symbol
status: 'started', 'stopped', 'error', 'signal'
action: Optional action taken
details: Additional details
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_AUTOPILOT_EVENTS, "autopilot_status", {
"symbol": symbol,
"status": status,
"action": action,
"details": details or {}
})
async def publish_system_event(self, event_type: str, message: str, severity: str = "info") -> int:
"""Publish a system event.
Args:
event_type: Event type (e.g., 'startup', 'shutdown', 'error')
message: Event message
severity: 'info', 'warning', 'error'
Returns:
Number of subscribers notified
"""
return await self.publish(CHANNEL_SYSTEM_EVENTS, event_type, {
"message": message,
"severity": severity
})
# Global Pub/Sub instance
_redis_pubsub: Optional[RedisPubSub] = None
def get_redis_pubsub() -> RedisPubSub:
"""Get global Redis Pub/Sub instance."""
global _redis_pubsub
if _redis_pubsub is None:
_redis_pubsub = RedisPubSub()
return _redis_pubsub

70
src/core/redis.py Normal file
View File

@@ -0,0 +1,70 @@
"""Redis client wrapper."""
import redis.asyncio as redis
from typing import Optional
from src.core.config import get_config
from src.core.logger import get_logger
logger = get_logger(__name__)
class RedisClient:
"""Redis client wrapper with automatic connection handling."""
def __init__(self):
"""Initialize Redis client."""
self.config = get_config()
self._client: Optional[redis.Redis] = None
self._pool: Optional[redis.ConnectionPool] = None
def get_client(self) -> redis.Redis:
"""Get or create Redis client.
Returns:
Async Redis client
"""
if self._client is None:
self._connect()
return self._client
def _connect(self):
"""Connect to Redis."""
redis_config = self.config.get("redis", {})
host = redis_config.get("host", "localhost")
port = redis_config.get("port", 6379)
db = redis_config.get("db", 0)
password = redis_config.get("password")
logger.info(f"Connecting to Redis at {host}:{port}/{db}")
try:
self._pool = redis.ConnectionPool(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
socket_connect_timeout=redis_config.get("socket_connect_timeout", 5)
)
self._client = redis.Redis(connection_pool=self._pool)
except Exception as e:
logger.error(f"Failed to create Redis client: {e}")
raise
async def close(self):
"""Close Redis connection."""
if self._client:
await self._client.aclose()
logger.info("Redis connection closed")
# Global instance
_redis_client: Optional[RedisClient] = None
def get_redis_client() -> RedisClient:
"""Get global Redis client instance."""
global _redis_client
if _redis_client is None:
_redis_client = RedisClient()
return _redis_client

99
src/core/repositories.py Normal file
View File

@@ -0,0 +1,99 @@
"""Database repositories for data access."""
from typing import Optional, List, Sequence
from datetime import datetime
from decimal import Decimal
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from .database import Order, Position, OrderStatus, OrderSide, OrderType, MarketData
class BaseRepository:
"""Base repository."""
def __init__(self, session: AsyncSession):
"""Initialize repository."""
self.session = session
class OrderRepository(BaseRepository):
"""Order repository."""
async def create(self, order: Order) -> Order:
"""Create new order."""
self.session.add(order)
await self.session.commit()
await self.session.refresh(order)
return order
async def get_by_id(self, order_id: int) -> Optional[Order]:
"""Get order by ID."""
result = await self.session.execute(
select(Order).where(Order.id == order_id)
)
return result.scalar_one_or_none()
async def get_all(self, limit: int = 100, offset: int = 0) -> Sequence[Order]:
"""Get all orders."""
result = await self.session.execute(
select(Order).limit(limit).offset(offset).order_by(Order.created_at.desc())
)
return result.scalars().all()
async def update_status(
self,
order_id: int,
status: OrderStatus,
exchange_order_id: Optional[str] = None,
fee: Optional[Decimal] = None
) -> Optional[Order]:
"""Update order status."""
values = {"status": status, "updated_at": datetime.utcnow()}
if exchange_order_id:
values["exchange_order_id"] = exchange_order_id
if fee is not None:
values["fee"] = fee
await self.session.execute(
update(Order)
.where(Order.id == order_id)
.values(**values)
)
await self.session.commit()
return await self.get_by_id(order_id)
async def get_open_orders(self, paper_trading: bool = True) -> Sequence[Order]:
"""Get open orders."""
result = await self.session.execute(
select(Order).where(
Order.paper_trading == paper_trading,
Order.status.in_([OrderStatus.PENDING, OrderStatus.OPEN, OrderStatus.PARTIALLY_FILLED])
)
)
return result.scalars().all()
async def delete(self, order_id: int) -> bool:
"""Delete order."""
result = await self.session.execute(
delete(Order).where(Order.id == order_id)
)
await self.session.commit()
return result.rowcount > 0
class PositionRepository(BaseRepository):
"""Position repository."""
async def get_all(self, paper_trading: bool = True) -> Sequence[Position]:
"""Get all positions."""
result = await self.session.execute(
select(Position).where(Position.paper_trading == paper_trading)
)
return result.scalars().all()
async def get_by_symbol(self, symbol: str, paper_trading: bool = True) -> Optional[Position]:
"""Get position by symbol."""
result = await self.session.execute(
select(Position).where(
Position.symbol == symbol,
Position.paper_trading == paper_trading
)
)
return result.scalar_one_or_none()