feat: Add core trading modules for risk management, backtesting, and execution algorithms, alongside a new ML transparency widget and related frontend dependencies.
src/autopilot/confidence_calibration.py (new file, 225 lines)
@@ -0,0 +1,225 @@
|
||||
"""Confidence calibration using Platt scaling and isotonic regression."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.isotonic import IsotonicRegression
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from src.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
HAS_CALIBRATION = True
|
||||
except ImportError:
|
||||
HAS_CALIBRATION = False
|
||||
logger.warning("sklearn.calibration not available, using basic calibration")
|
||||
|
||||
|
||||
class ConfidenceCalibrator:
|
||||
"""Calibrates model confidence scores using Platt scaling or isotonic regression."""
|
||||
|
||||
def __init__(self, method: str = "isotonic"):
|
||||
"""Initialize confidence calibrator.
|
||||
|
||||
Args:
|
||||
method: Calibration method ('platt' or 'isotonic')
|
||||
"""
|
||||
self.method = method
|
||||
self.calibrator = None
|
||||
self.is_fitted = False
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
probabilities: np.ndarray,
|
||||
true_labels: np.ndarray
|
||||
):
|
||||
"""Fit calibrator on validation data.
|
||||
|
||||
Args:
|
||||
probabilities: Predicted probabilities (n_samples, n_classes)
|
||||
true_labels: True class labels (n_samples,)
|
||||
"""
|
||||
try:
|
||||
# Get maximum probability (confidence) for each sample
|
||||
confidences = np.max(probabilities, axis=1)
|
||||
|
||||
# Binary indicator: 1 if prediction was correct, 0 otherwise
|
||||
predictions = np.argmax(probabilities, axis=1)
|
||||
correctness = (predictions == true_labels).astype(float)
|
||||
|
||||
if self.method == "isotonic":
|
||||
self.calibrator = IsotonicRegression(out_of_bounds='clip')
|
||||
elif self.method == "platt":
|
||||
self.calibrator = LogisticRegression()
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration method: {self.method}")
|
||||
|
||||
# Fit calibrator
|
||||
self.calibrator.fit(confidences.reshape(-1, 1), correctness)
|
||||
self.is_fitted = True
|
||||
|
||||
self.logger.info(f"Confidence calibrator fitted using {self.method} method")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fit confidence calibrator: {e}")
|
||||
self.is_fitted = False
|
||||
|
||||
def calibrate(self, confidence: float) -> float:
|
||||
"""Calibrate a confidence score.
|
||||
|
||||
Args:
|
||||
confidence: Uncalibrated confidence (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Calibrated confidence (0.0 to 1.0)
|
||||
"""
|
||||
if not self.is_fitted or self.calibrator is None:
|
||||
return confidence
|
||||
|
||||
try:
|
||||
if self.method == "platt":
    # LogisticRegression.predict returns a class label (0/1); use the positive-class probability instead
    calibrated = self.calibrator.predict_proba(np.array([[confidence]]))[0, 1]
else:
    calibrated = self.calibrator.predict(np.array([[confidence]]))[0]
|
||||
|
||||
# Clip to valid range
|
||||
calibrated = max(0.0, min(1.0, calibrated))
|
||||
|
||||
return float(calibrated)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Calibration failed, returning original: {e}")
|
||||
return confidence
|
||||
|
||||
def calibrate_probabilities(self, probabilities: np.ndarray) -> np.ndarray:
|
||||
"""Calibrate a probability array.
|
||||
|
||||
Args:
|
||||
probabilities: Predicted probabilities (n_samples, n_classes) or (n_classes,)
|
||||
|
||||
Returns:
|
||||
Calibrated probabilities with same shape
|
||||
"""
|
||||
if not self.is_fitted or self.calibrator is None:
|
||||
return probabilities
|
||||
|
||||
try:
|
||||
# Handle 1D and 2D arrays
|
||||
is_1d = probabilities.ndim == 1
|
||||
if is_1d:
|
||||
probabilities = probabilities.reshape(1, -1)
|
||||
|
||||
# Calibrate each row
|
||||
calibrated = probabilities.copy()
|
||||
for i in range(len(probabilities)):
|
||||
max_idx = np.argmax(probabilities[i])
|
||||
max_prob = probabilities[i, max_idx]
|
||||
|
||||
calibrated_max = self.calibrate(max_prob)
|
||||
|
||||
# Rescale probabilities to maintain relative ratios
|
||||
if max_prob > 0:
|
||||
scale_factor = calibrated_max / max_prob
|
||||
calibrated[i] = probabilities[i] * scale_factor
|
||||
|
||||
# Normalize to sum to 1
|
||||
calibrated[i] = calibrated[i] / np.sum(calibrated[i])
|
||||
|
||||
if is_1d:
|
||||
calibrated = calibrated[0]
|
||||
|
||||
return calibrated
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Probability calibration failed: {e}")
|
||||
return probabilities
|
||||
|
||||
|
||||
class ConfidenceCalibrationManager:
|
||||
"""Manages confidence calibration for strategy selector models."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize calibration manager."""
|
||||
self.calibrator = ConfidenceCalibrator(method="isotonic")
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def calibrate_prediction(
|
||||
self,
|
||||
strategy_name: str,
|
||||
confidence: float,
|
||||
all_predictions: Dict[str, float]
|
||||
) -> Tuple[str, float, Dict[str, float]]:
|
||||
"""Calibrate prediction confidence.
|
||||
|
||||
Args:
|
||||
strategy_name: Predicted strategy name
|
||||
confidence: Uncalibrated confidence
|
||||
all_predictions: All prediction probabilities
|
||||
|
||||
Returns:
|
||||
Tuple of (strategy_name, calibrated_confidence, calibrated_predictions)
|
||||
"""
|
||||
if not self.calibrator.is_fitted:
|
||||
# Return uncalibrated if not fitted
|
||||
return strategy_name, confidence, all_predictions
|
||||
|
||||
# Calibrate main confidence
|
||||
calibrated_confidence = self.calibrator.calibrate(confidence)
|
||||
|
||||
# Calibrate all predictions
|
||||
calibrated_predictions = {}
|
||||
for name, prob in all_predictions.items():
|
||||
calibrated_predictions[name] = self.calibrator.calibrate(prob)
|
||||
|
||||
return strategy_name, calibrated_confidence, calibrated_predictions
|
||||
|
||||
def fit_from_validation_data(
|
||||
self,
|
||||
predicted_probs: List[Dict[str, float]],
|
||||
true_labels: List[str]
|
||||
):
|
||||
"""Fit calibrator from validation data.
|
||||
|
||||
Args:
|
||||
predicted_probs: List of prediction probability dictionaries
|
||||
true_labels: List of true strategy labels
|
||||
"""
|
||||
try:
|
||||
# Convert to numpy arrays
|
||||
all_strategies = list(set(true_labels))
|
||||
n_samples = len(predicted_probs)
|
||||
n_classes = len(all_strategies)
|
||||
|
||||
prob_matrix = np.zeros((n_samples, n_classes))
|
||||
label_indices = {name: i for i, name in enumerate(all_strategies)}
|
||||
|
||||
for i, probs in enumerate(predicted_probs):
|
||||
for strategy, prob in probs.items():
|
||||
if strategy in label_indices:
|
||||
prob_matrix[i, label_indices[strategy]] = prob
|
||||
|
||||
true_indices = np.array([label_indices[label] for label in true_labels])
|
||||
|
||||
# Fit calibrator
|
||||
self.calibrator.fit(prob_matrix, true_indices)
|
||||
|
||||
self.logger.info("Confidence calibrator fitted from validation data")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fit calibrator from validation data: {e}")
|
||||
|
||||
|
||||
# Global calibration manager
|
||||
_calibration_manager: Optional[ConfidenceCalibrationManager] = None
|
||||
|
||||
|
||||
def get_confidence_calibration_manager() -> ConfidenceCalibrationManager:
|
||||
"""Get global confidence calibration manager instance.
|
||||
|
||||
Returns:
|
||||
ConfidenceCalibrationManager instance
|
||||
"""
|
||||
global _calibration_manager
|
||||
if _calibration_manager is None:
|
||||
_calibration_manager = ConfidenceCalibrationManager()
|
||||
return _calibration_manager
|
||||
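Below is a minimal usage sketch for the calibrator defined above. It assumes the project root is on PYTHONPATH so the src.autopilot.confidence_calibration import resolves; the validation probabilities and labels are synthetic and purely illustrative.

import numpy as np
from src.autopilot.confidence_calibration import ConfidenceCalibrator

# Synthetic validation set: 200 samples over 3 strategy classes (illustrative only)
rng = np.random.default_rng(42)
val_probs = rng.dirichlet(alpha=[2.0, 1.0, 1.0], size=200)   # shape (200, 3), rows sum to 1
val_labels = rng.integers(0, 3, size=200)                    # true class indices

calibrator = ConfidenceCalibrator(method="isotonic")
calibrator.fit(val_probs, val_labels)

# Map a raw max-probability confidence to an empirically calibrated one
print(calibrator.calibrate(0.82))

# Calibrate a full probability vector; the input shape is preserved
print(calibrator.calibrate_probabilities(np.array([0.7, 0.2, 0.1])))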
src/autopilot/explainability.py (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
"""Model explainability using SHAP values and feature importance."""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from src.core.logger import get_logger
|
||||
from .models import StrategySelectorModel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
import shap
|
||||
HAS_SHAP = True
|
||||
except ImportError:
|
||||
HAS_SHAP = False
|
||||
logger.warning("SHAP not installed. Install with: pip install shap")
|
||||
|
||||
|
||||
class ModelExplainer:
|
||||
"""Provides model explainability using SHAP values."""
|
||||
|
||||
def __init__(self, model: StrategySelectorModel):
|
||||
"""Initialize model explainer.
|
||||
|
||||
Args:
|
||||
model: Strategy selector model
|
||||
"""
|
||||
self.model = model
|
||||
self.logger = get_logger(__name__)
|
||||
self.explainer = None
|
||||
self.shap_values_cache: Optional[np.ndarray] = None
|
||||
|
||||
def initialize_explainer(self, background_data: pd.DataFrame):
|
||||
"""Initialize SHAP explainer with background data.
|
||||
|
||||
Args:
|
||||
background_data: Background dataset for SHAP (sample of training data)
|
||||
"""
|
||||
if not HAS_SHAP:
|
||||
self.logger.warning("SHAP not available, explainability disabled")
|
||||
return
|
||||
|
||||
if not self.model.is_trained or self.model.best_model is None:
|
||||
self.logger.warning("Model not trained, cannot initialize explainer")
|
||||
return
|
||||
|
||||
try:
|
||||
# Use TreeExplainer for tree-based models, else KernelExplainer
|
||||
model_type = self.model.best_model_name.lower()
|
||||
|
||||
if 'forest' in model_type or 'xgboost' in model_type or 'lightgbm' in model_type:
|
||||
# Tree-based models
|
||||
self.explainer = shap.TreeExplainer(self.model.best_model)
|
||||
else:
|
||||
# Other models - use KernelExplainer with background data
|
||||
# Limit background data size for performance
|
||||
if len(background_data) > 100:
|
||||
background_data = background_data.sample(100, random_state=42)
|
||||
|
||||
def model_predict(X):
|
||||
X_scaled = self.model.scaler.transform(X)
|
||||
return self.model.best_model.predict_proba(X_scaled)
|
||||
|
||||
self.explainer = shap.KernelExplainer(
|
||||
model_predict,
|
||||
background_data
|
||||
)
|
||||
|
||||
self.logger.info("SHAP explainer initialized")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize SHAP explainer: {e}")
|
||||
self.explainer = None
|
||||
|
||||
def explain_prediction(
|
||||
self,
|
||||
features: Dict[str, float]
|
||||
) -> Dict[str, Any]:
|
||||
"""Explain a single prediction using SHAP values.
|
||||
|
||||
Args:
|
||||
features: Market condition features
|
||||
|
||||
Returns:
|
||||
Dictionary with SHAP values and explanations
|
||||
"""
|
||||
if self.explainer is None:
|
||||
return {"error": "Explainer not initialized"}
|
||||
|
||||
try:
|
||||
# Convert features to DataFrame
|
||||
feature_df = pd.DataFrame([features])
|
||||
|
||||
# Ensure all required features are present
|
||||
for feat in self.model.feature_names:
|
||||
if feat not in feature_df.columns:
|
||||
feature_df[feat] = 0.0
|
||||
|
||||
X = feature_df[self.model.feature_names]
|
||||
|
||||
# Get SHAP values
|
||||
if isinstance(self.explainer, shap.TreeExplainer):
|
||||
shap_values = self.explainer.shap_values(X)[0]
|
||||
else:
|
||||
shap_values = self.explainer.shap_values(X.iloc[0].values)[0]
|
||||
|
||||
# Create feature importance dictionary
|
||||
feature_importance = {
|
||||
feat: float(shap_val)
|
||||
for feat, shap_val in zip(self.model.feature_names, shap_values)
|
||||
}
|
||||
|
||||
# Sort by absolute importance
|
||||
sorted_importance = sorted(
|
||||
feature_importance.items(),
|
||||
key=lambda x: abs(x[1]),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Get top contributing features (positive and negative), ranked by SHAP value
contributions = [
    {"feature": name, "value": float(val), "shap": float(shap_val)}
    for name, shap_val, val in zip(
        self.model.feature_names, shap_values, X.iloc[0].values
    )
]
top_positive = sorted(
    [c for c in contributions if c["shap"] > 0],
    key=lambda c: c["shap"],
    reverse=True,
)[:5]
top_negative = sorted(
    [c for c in contributions if c["shap"] < 0],
    key=lambda c: c["shap"],
)[:5]
|
||||
|
||||
return {
|
||||
"feature_importance": feature_importance,
|
||||
"sorted_importance": [
|
||||
{"feature": name, "shap_value": float(val)}
|
||||
for name, val in sorted_importance
|
||||
],
|
||||
"top_positive_features": top_positive,
|
||||
"top_negative_features": top_negative,
|
||||
"shap_values": [float(v) for v in shap_values]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to explain prediction: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def get_global_feature_importance(self) -> Dict[str, float]:
|
||||
"""Get global feature importance from model.
|
||||
|
||||
Returns:
|
||||
Dictionary of feature names to importance scores
|
||||
"""
|
||||
# First try SHAP if available
|
||||
if HAS_SHAP and self.explainer is not None:
|
||||
try:
|
||||
# Placeholder: aggregating global SHAP importance (e.g., mean |SHAP| over a sample) is not implemented yet
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to model's built-in feature importance
|
||||
return self.model.get_feature_importance()
|
||||
|
||||
def explain_feature_contributions(
|
||||
self,
|
||||
features: Dict[str, float],
|
||||
target_strategy: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Explain how features contribute to a specific strategy prediction.
|
||||
|
||||
Args:
|
||||
features: Market condition features
|
||||
target_strategy: Strategy name to explain
|
||||
|
||||
Returns:
|
||||
Feature contribution explanation
|
||||
"""
|
||||
explanation = self.explain_prediction(features)
|
||||
|
||||
if "error" in explanation:
|
||||
return explanation
|
||||
|
||||
# Filter to show contribution to target strategy
|
||||
strategy_contributions = {
|
||||
feat: contrib
|
||||
for feat, contrib in explanation["feature_importance"].items()
|
||||
}
|
||||
|
||||
return {
|
||||
"target_strategy": target_strategy,
|
||||
"feature_contributions": strategy_contributions,
|
||||
"explanation": explanation
|
||||
}
|
||||
|
||||
|
||||
# Global explainer cache
|
||||
_explainers: Dict[str, ModelExplainer] = {}
|
||||
|
||||
|
||||
def get_model_explainer(model: StrategySelectorModel, model_id: str = "default") -> ModelExplainer:
|
||||
"""Get model explainer instance.
|
||||
|
||||
Args:
|
||||
model: Strategy selector model
|
||||
model_id: Unique identifier for the model
|
||||
|
||||
Returns:
|
||||
ModelExplainer instance
|
||||
"""
|
||||
if model_id not in _explainers:
|
||||
_explainers[model_id] = ModelExplainer(model)
|
||||
return _explainers[model_id]
|
||||
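A hedged usage sketch for the explainer above. The trained StrategySelectorModel, the sample of its training features, and the feature values are hypothetical stand-ins supplied by the surrounding system.

from src.autopilot.explainability import get_model_explainer

def explain_latest_signal(trained_model, training_sample_df, features):
    # trained_model: an already-trained StrategySelectorModel; training_sample_df: a sample
    # of its training features (both assumed to come from the surrounding system)
    explainer = get_model_explainer(trained_model, model_id="default")
    explainer.initialize_explainer(training_sample_df)

    explanation = explainer.explain_prediction(features)
    if "error" not in explanation:
        # Features ranked by absolute SHAP contribution
        for item in explanation["sorted_importance"][:5]:
            print(item["feature"], item["shap_value"])
    return explanation

# explain_latest_signal(model, sample_df, {"rsi": 62.0, "adx": 28.0, "volatility": 0.015})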
src/autopilot/feature_engineering.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
"""Advanced feature engineering for ML models."""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Any
|
||||
from src.core.logger import get_logger
|
||||
from src.data.indicators import get_indicators
|
||||
from .market_analyzer import MarketAnalyzer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FeatureEngineering:
|
||||
"""Advanced feature engineering for autopilot ML models."""
|
||||
|
||||
def __init__(self, market_analyzer: MarketAnalyzer):
|
||||
"""Initialize feature engineering.
|
||||
|
||||
Args:
|
||||
market_analyzer: MarketAnalyzer instance
|
||||
"""
|
||||
self.market_analyzer = market_analyzer
|
||||
self.indicators = get_indicators()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def add_multi_timeframe_features(
|
||||
self,
|
||||
features: Dict[str, float],
|
||||
ohlcv_dataframes: Dict[str, pd.DataFrame]
|
||||
) -> Dict[str, float]:
|
||||
"""Add features aggregated across multiple timeframes.
|
||||
|
||||
Args:
|
||||
features: Base features dictionary
|
||||
ohlcv_dataframes: Dictionary of {timeframe: DataFrame} (e.g., {'15m': df, '1h': df, '4h': df})
|
||||
|
||||
Returns:
|
||||
Enhanced features dictionary with multi-timeframe features
|
||||
"""
|
||||
enhanced_features = features.copy()
|
||||
|
||||
# Aggregate features from higher timeframes (trend confirmation)
|
||||
for timeframe, df in ohlcv_dataframes.items():
|
||||
if len(df) < 20:
|
||||
continue
|
||||
|
||||
tf_features = self.market_analyzer.extract_features(df)
|
||||
|
||||
# Prefix features with timeframe
|
||||
for key, value in tf_features.items():
|
||||
enhanced_features[f'{timeframe}_{key}'] = value
|
||||
|
||||
# Cross-timeframe features
|
||||
if '1h' in ohlcv_dataframes and timeframe != '1h':
|
||||
# Compare current timeframe with 1h
|
||||
for key in ['sma_20', 'rsi', 'adx']:
|
||||
if f'{timeframe}_{key}' in enhanced_features and f'1h_{key}' in enhanced_features:
|
||||
enhanced_features[f'{timeframe}_vs_1h_{key}'] = (
|
||||
enhanced_features[f'{timeframe}_{key}'] - enhanced_features[f'1h_{key}']
|
||||
)
|
||||
|
||||
return enhanced_features
|
||||
|
||||
def add_feature_interactions(
|
||||
self,
|
||||
features: Dict[str, float]
|
||||
) -> Dict[str, float]:
|
||||
"""Add feature interaction terms.
|
||||
|
||||
Args:
|
||||
features: Base features dictionary
|
||||
|
||||
Returns:
|
||||
Enhanced features dictionary with interaction terms
|
||||
"""
|
||||
enhanced_features = features.copy()
|
||||
|
||||
# RSI × Volume interaction
|
||||
if 'rsi' in features and 'volume_ratio' in features:
|
||||
enhanced_features['rsi_x_volume'] = features['rsi'] * features['volume_ratio']
|
||||
|
||||
# ADX × Volatility interaction
|
||||
if 'adx' in features and 'volatility' in features:
|
||||
enhanced_features['adx_x_volatility'] = features['adx'] * features['volatility']
|
||||
|
||||
# MACD × RSI interaction
|
||||
if 'macd' in features and 'rsi' in features:
|
||||
enhanced_features['macd_x_rsi'] = features['macd'] * features['rsi']
|
||||
|
||||
# Price vs MA ratios
|
||||
if 'price_vs_sma20' in features and 'price_vs_sma50' in features:
|
||||
enhanced_features['price_vs_ma_ratio'] = (
|
||||
features['price_vs_sma20'] / features['price_vs_sma50']
|
||||
if features['price_vs_sma50'] != 0 else 0.0
|
||||
)
|
||||
|
||||
# RSI/MACD ratio
|
||||
if 'rsi' in features and 'macd_signal' in features and features['macd_signal'] != 0:
|
||||
enhanced_features['rsi_div_macd'] = features['rsi'] / abs(features['macd_signal'])
|
||||
|
||||
# Volume/Price ratio
|
||||
if 'volume_ratio' in features and 'price_vs_sma20' in features:
|
||||
enhanced_features['volume_price_interaction'] = (
|
||||
features['volume_ratio'] * abs(features['price_vs_sma20'])
|
||||
)
|
||||
|
||||
return enhanced_features
|
||||
|
||||
def add_regime_specific_features(
|
||||
self,
|
||||
features: Dict[str, float],
|
||||
regime: str
|
||||
) -> Dict[str, float]:
|
||||
"""Add regime-specific features.
|
||||
|
||||
Args:
|
||||
features: Base features dictionary
|
||||
regime: Market regime string
|
||||
|
||||
Returns:
|
||||
Enhanced features dictionary with regime-specific features
|
||||
"""
|
||||
enhanced_features = features.copy()
|
||||
|
||||
# Regime indicator features (one-hot encoding style)
|
||||
regime_types = ['trending_up', 'trending_down', 'ranging', 'high_volatility', 'low_volatility', 'breakout', 'reversal']
|
||||
for reg in regime_types:
|
||||
enhanced_features[f'regime_{reg}'] = 1.0 if regime == reg else 0.0
|
||||
|
||||
# Regime-normalized features
|
||||
if regime in ['trending_up', 'trending_down']:
|
||||
# In trending markets, emphasize trend strength
|
||||
if 'adx' in features:
|
||||
enhanced_features['regime_trend_adx'] = features['adx']
|
||||
elif regime == 'ranging':
|
||||
# In ranging markets, emphasize mean reversion signals
|
||||
if 'rsi' in features:
|
||||
enhanced_features['regime_range_rsi'] = features['rsi']
|
||||
elif regime in ['high_volatility', 'low_volatility']:
|
||||
# In volatility regimes, emphasize volatility features
|
||||
if 'volatility' in features:
|
||||
enhanced_features['regime_vol'] = features['volatility']
|
||||
|
||||
return enhanced_features
|
||||
|
||||
def add_lag_features(
|
||||
self,
|
||||
features: Dict[str, float],
|
||||
historical_features: List[Dict[str, float]],
|
||||
lags: List[int] = [1, 2, 3, 5]
|
||||
) -> Dict[str, float]:
|
||||
"""Add lagged features for time-series context.
|
||||
|
||||
Args:
|
||||
features: Current features dictionary
|
||||
historical_features: List of previous feature dictionaries (most recent last)
|
||||
lags: List of lag periods to include
|
||||
|
||||
Returns:
|
||||
Enhanced features dictionary with lag features
|
||||
"""
|
||||
enhanced_features = features.copy()
|
||||
|
||||
if not historical_features:
|
||||
return enhanced_features
|
||||
|
||||
# Key features to lag
|
||||
key_features = ['rsi', 'macd', 'adx', 'volatility', 'price_vs_sma20']
|
||||
|
||||
for lag in lags:
|
||||
if len(historical_features) >= lag:
|
||||
lagged_features = historical_features[-lag]
|
||||
for key in key_features:
|
||||
if key in lagged_features:
|
||||
enhanced_features[f'{key}_lag_{lag}'] = lagged_features[key]
|
||||
else:
|
||||
enhanced_features[f'{key}_lag_{lag}'] = 0.0
|
||||
else:
|
||||
# Fill with 0 if not enough history
|
||||
for key in key_features:
|
||||
enhanced_features[f'{key}_lag_{lag}'] = 0.0
|
||||
|
||||
# Feature changes (deltas)
|
||||
if len(historical_features) >= 1:
|
||||
prev_features = historical_features[-1]
|
||||
for key in key_features:
|
||||
if key in features and key in prev_features:
|
||||
enhanced_features[f'{key}_delta'] = features[key] - prev_features[key]
|
||||
else:
|
||||
enhanced_features[f'{key}_delta'] = 0.0
|
||||
|
||||
return enhanced_features
|
||||
|
||||
def select_features(
|
||||
self,
|
||||
features: Dict[str, float],
|
||||
feature_importance: Optional[Dict[str, float]] = None,
|
||||
top_n: Optional[int] = None
|
||||
) -> Dict[str, float]:
|
||||
"""Select most important features.
|
||||
|
||||
Args:
|
||||
features: Full features dictionary
|
||||
feature_importance: Dictionary of feature importance scores
|
||||
top_n: Number of top features to keep (None = keep all)
|
||||
|
||||
Returns:
|
||||
Selected features dictionary
|
||||
"""
|
||||
if feature_importance is None or top_n is None:
|
||||
return features
|
||||
|
||||
# Sort features by importance
|
||||
sorted_features = sorted(
|
||||
feature_importance.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Keep top N features
|
||||
top_features = {name: features.get(name, 0.0) for name, _ in sorted_features[:top_n]}
|
||||
|
||||
# Always keep base features even if not in top N
|
||||
base_features = ['rsi', 'macd', 'adx', 'volatility', 'price_vs_sma20']
|
||||
for base in base_features:
|
||||
if base in features:
|
||||
top_features[base] = features[base]
|
||||
|
||||
return top_features
|
||||
|
||||
def engineer_features(
|
||||
self,
|
||||
ohlcv_data: pd.DataFrame,
|
||||
multi_timeframe_data: Optional[Dict[str, pd.DataFrame]] = None,
|
||||
regime: Optional[str] = None,
|
||||
historical_features: Optional[List[Dict[str, float]]] = None,
|
||||
feature_importance: Optional[Dict[str, float]] = None
|
||||
) -> Dict[str, float]:
|
||||
"""Comprehensive feature engineering pipeline.
|
||||
|
||||
Args:
|
||||
ohlcv_data: Primary timeframe OHLCV data
|
||||
multi_timeframe_data: Multi-timeframe data dictionary
|
||||
regime: Market regime
|
||||
historical_features: Historical feature snapshots
|
||||
feature_importance: Feature importance scores for selection
|
||||
|
||||
Returns:
|
||||
Engineered features dictionary
|
||||
"""
|
||||
# Start with base features
|
||||
features = self.market_analyzer.extract_features(ohlcv_data)
|
||||
|
||||
# Add multi-timeframe features
|
||||
if multi_timeframe_data:
|
||||
features = self.add_multi_timeframe_features(features, multi_timeframe_data)
|
||||
|
||||
# Add feature interactions
|
||||
features = self.add_feature_interactions(features)
|
||||
|
||||
# Add regime-specific features
|
||||
if regime:
|
||||
features = self.add_regime_specific_features(features, regime)
|
||||
|
||||
# Add lag features
|
||||
if historical_features:
|
||||
features = self.add_lag_features(features, historical_features)
|
||||
|
||||
# Feature selection (if importance provided)
|
||||
if feature_importance:
|
||||
features = self.select_features(features, feature_importance, top_n=50)
|
||||
|
||||
return features
|
||||
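A sketch of the feature-engineering entry point above. The argument-free MarketAnalyzer() construction, the OHLCV DataFrames, and the list of past feature snapshots are assumptions for illustration only.

from src.autopilot.market_analyzer import MarketAnalyzer
from src.autopilot.feature_engineering import FeatureEngineering

def build_features(ohlcv_1h, ohlcv_4h, past_snapshots):
    # ohlcv_1h / ohlcv_4h: OHLCV DataFrames; past_snapshots: list of earlier feature dicts
    fe = FeatureEngineering(MarketAnalyzer())  # argument-free construction assumed; see lead-in
    return fe.engineer_features(
        ohlcv_data=ohlcv_1h,                                    # primary timeframe
        multi_timeframe_data={"1h": ohlcv_1h, "4h": ohlcv_4h},  # optional higher-timeframe context
        regime="trending_up",                                   # optional regime label
        historical_features=past_snapshots,                     # optional lag/delta context
    )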
src/autopilot/online_learning.py (new file, 253 lines)
@@ -0,0 +1,253 @@
|
||||
"""Online learning pipeline with incremental model updates and concept drift detection."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from src.core.logger import get_logger
|
||||
from src.core.config import get_config
|
||||
from .models import StrategySelectorModel
|
||||
from .performance_tracker import get_performance_tracker
|
||||
from .market_analyzer import MarketConditions
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ConceptDriftDetector:
|
||||
"""Detects concept drift in model performance."""
|
||||
|
||||
def __init__(self, window_size: int = 100, drift_threshold: float = 0.1):
|
||||
"""Initialize drift detector.
|
||||
|
||||
Args:
|
||||
window_size: Size of performance window
|
||||
drift_threshold: Threshold for detecting drift (accuracy drop)
|
||||
"""
|
||||
self.window_size = window_size
|
||||
self.drift_threshold = drift_threshold
|
||||
self.performance_history: List[float] = []
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def add_performance(self, accuracy: float):
|
||||
"""Add performance metric.
|
||||
|
||||
Args:
|
||||
accuracy: Prediction accuracy or performance metric
|
||||
"""
|
||||
self.performance_history.append(accuracy)
|
||||
if len(self.performance_history) > self.window_size:
|
||||
self.performance_history.pop(0)
|
||||
|
||||
def detect_drift(self) -> Tuple[bool, Optional[float]]:
|
||||
"""Detect if concept drift has occurred.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_drift, drift_score)
|
||||
"""
|
||||
if len(self.performance_history) < self.window_size:
|
||||
return False, None
|
||||
|
||||
# Compare recent performance to baseline
|
||||
baseline = np.mean(self.performance_history[:self.window_size // 2])
|
||||
recent = np.mean(self.performance_history[-self.window_size // 2:])
|
||||
|
||||
drift_score = baseline - recent
|
||||
|
||||
if drift_score > self.drift_threshold:
|
||||
self.logger.warning(f"Concept drift detected: {drift_score:.3f} accuracy drop")
|
||||
return True, drift_score
|
||||
|
||||
return False, drift_score
|
||||
|
||||
|
||||
class OnlineLearningPipeline:
|
||||
"""Online learning pipeline for incremental model updates."""
|
||||
|
||||
def __init__(self, model: StrategySelectorModel):
|
||||
"""Initialize online learning pipeline.
|
||||
|
||||
Args:
|
||||
model: Strategy selector model
|
||||
"""
|
||||
self.model = model
|
||||
self.config = get_config()
|
||||
self.performance_tracker = get_performance_tracker()
|
||||
self.drift_detector = ConceptDriftDetector(
|
||||
window_size=self.config.get("autopilot.online_learning.drift_window", 100),
|
||||
drift_threshold=self.config.get("autopilot.online_learning.drift_threshold", 0.1)
|
||||
)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Buffers for incremental updates
|
||||
self.feature_buffer: List[Dict[str, float]] = []
|
||||
self.label_buffer: List[str] = []
|
||||
self.buffer_size = self.config.get("autopilot.online_learning.buffer_size", 50)
|
||||
self.update_frequency = self.config.get("autopilot.online_learning.update_frequency", 100)
|
||||
|
||||
async def add_training_sample(
|
||||
self,
|
||||
market_conditions: MarketConditions,
|
||||
strategy_name: str,
|
||||
performance: float
|
||||
):
|
||||
"""Add a training sample to the buffer.
|
||||
|
||||
Args:
|
||||
market_conditions: Market conditions at prediction time
|
||||
strategy_name: Strategy that was selected
|
||||
performance: Performance metric (e.g., return, sharpe ratio)
|
||||
"""
|
||||
features = market_conditions.features
|
||||
self.feature_buffer.append(features)
|
||||
self.label_buffer.append(strategy_name)
|
||||
|
||||
# Track performance for drift detection
|
||||
# Use binary indicator: 1 if performance > threshold, 0 otherwise
|
||||
performance_indicator = 1.0 if performance > 0 else 0.0
|
||||
self.drift_detector.add_performance(performance_indicator)
|
||||
|
||||
# Check if buffer is full and trigger update
|
||||
if len(self.feature_buffer) >= self.update_frequency:
|
||||
await self.incremental_update()
|
||||
|
||||
async def incremental_update(self) -> Dict[str, Any]:
|
||||
"""Perform incremental model update.
|
||||
|
||||
Returns:
|
||||
Update metrics dictionary
|
||||
"""
|
||||
if len(self.feature_buffer) < self.buffer_size:
|
||||
return {"error": "Insufficient data in buffer"}
|
||||
|
||||
try:
|
||||
# Prepare new data
|
||||
X_new = pd.DataFrame(self.feature_buffer)
|
||||
y_new = np.array(self.label_buffer)
|
||||
|
||||
# Get existing training data
|
||||
training_data = await self.performance_tracker.prepare_training_data(
|
||||
min_samples_per_strategy=1
|
||||
)
|
||||
|
||||
if training_data is None:
|
||||
# Not enough historical data, just buffer for now
|
||||
return {"status": "buffering", "buffer_size": len(self.feature_buffer)}
|
||||
|
||||
# Combine with existing data (last N samples for memory efficiency)
|
||||
X_existing = training_data['X']
|
||||
y_existing = np.array(training_data['y'])
|
||||
|
||||
# Take last 1000 samples from existing data to keep memory manageable
|
||||
if len(X_existing) > 1000:
|
||||
X_existing = X_existing.tail(1000)
|
||||
y_existing = y_existing[-1000:]
|
||||
|
||||
# Combine
|
||||
X_combined = pd.concat([X_existing, X_new], ignore_index=True)
|
||||
y_combined = np.concatenate([y_existing, y_new])
|
||||
|
||||
# Retrain model
|
||||
strategy_names = list(set(y_combined))
|
||||
metrics = self.model.train(
|
||||
X_combined,
|
||||
y_combined,
|
||||
strategy_names,
|
||||
use_ensemble=True
|
||||
)
|
||||
|
||||
# Save updated model
|
||||
self.model.save()
|
||||
|
||||
# Clear buffers
|
||||
self.feature_buffer = []
|
||||
self.label_buffer = []
|
||||
|
||||
self.logger.info(f"Incremental update complete: {metrics.get('train_accuracy', 0):.3f}")
|
||||
|
||||
return {
|
||||
"status": "updated",
|
||||
"metrics": metrics,
|
||||
"samples_added": len(X_new)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Incremental update failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def check_drift(self) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""Check for concept drift.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_drift, drift_info)
|
||||
"""
|
||||
has_drift, drift_score = self.drift_detector.detect_drift()
|
||||
|
||||
if has_drift:
|
||||
return True, {
|
||||
"drift_score": float(drift_score),
|
||||
"recommendation": "retrain"
|
||||
}
|
||||
|
||||
return False, None
|
||||
|
||||
async def trigger_full_retrain_if_needed(self) -> Optional[Dict[str, Any]]:
|
||||
"""Trigger full model retrain if drift detected.
|
||||
|
||||
Returns:
|
||||
Retrain metrics or None
|
||||
"""
|
||||
has_drift, drift_info = self.check_drift()
|
||||
|
||||
if has_drift:
|
||||
self.logger.info("Concept drift detected, triggering full retrain")
|
||||
|
||||
# Clear buffers to force fresh retrain
|
||||
self.feature_buffer = []
|
||||
self.label_buffer = []
|
||||
self.drift_detector.performance_history = []
|
||||
|
||||
# Full retrain
|
||||
training_data = await self.performance_tracker.prepare_training_data(
|
||||
min_samples_per_strategy=10
|
||||
)
|
||||
|
||||
if training_data is None:
|
||||
return {"error": "Insufficient training data"}
|
||||
|
||||
metrics = self.model.train(
|
||||
training_data['X'],
|
||||
np.array(training_data['y']),
|
||||
training_data['strategy_names'],
|
||||
use_ensemble=True
|
||||
)
|
||||
|
||||
self.model.save()
|
||||
|
||||
return {
|
||||
"status": "retrained",
|
||||
"metrics": metrics,
|
||||
"drift_info": drift_info
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global online learning pipeline
|
||||
_online_learning: Optional[OnlineLearningPipeline] = None
|
||||
|
||||
|
||||
def get_online_learning_pipeline(model: StrategySelectorModel) -> OnlineLearningPipeline:
|
||||
"""Get global online learning pipeline instance.
|
||||
|
||||
Args:
|
||||
model: Strategy selector model
|
||||
|
||||
Returns:
|
||||
OnlineLearningPipeline instance
|
||||
"""
|
||||
global _online_learning
|
||||
if _online_learning is None:
|
||||
_online_learning = OnlineLearningPipeline(model)
|
||||
return _online_learning
|
||||
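A hedged sketch of how the online-learning pipeline above might be driven from trade-outcome handling code; the trained StrategySelectorModel and the MarketConditions snapshot are assumed to be provided by the caller.

from src.autopilot.online_learning import get_online_learning_pipeline

async def record_trade_outcome(model, conditions, strategy_name: str, trade_return: float):
    # model: trained StrategySelectorModel; conditions: MarketConditions snapshot at prediction time
    pipeline = get_online_learning_pipeline(model)

    # Buffer the sample; the pipeline retrains incrementally once enough samples accumulate
    await pipeline.add_training_sample(conditions, strategy_name, trade_return)

    # Periodically check for concept drift and fall back to a full retrain if needed
    has_drift, drift_info = pipeline.check_drift()
    if has_drift:
        await pipeline.trigger_full_retrain_if_needed()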
src/autopilot/regime_detection.py (new file, 286 lines)
@@ -0,0 +1,286 @@
|
||||
"""Advanced regime detection using HMM and GMM models."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from src.core.logger import get_logger
|
||||
from .market_analyzer import MarketRegime
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from hmmlearn import hmm
|
||||
HAS_HMM = True
|
||||
except ImportError:
|
||||
HAS_HMM = False
|
||||
logger.warning("hmmlearn not installed. Install with: pip install hmmlearn")
|
||||
|
||||
try:
|
||||
from sklearn.mixture import GaussianMixture
|
||||
HAS_GMM = True
|
||||
except ImportError:
|
||||
HAS_GMM = False
|
||||
logger.warning("sklearn not available for GMM")
|
||||
|
||||
|
||||
class HMMRegimeDetector:
|
||||
"""Hidden Markov Model-based regime detection."""
|
||||
|
||||
def __init__(self, n_regimes: int = 4):
|
||||
"""Initialize HMM regime detector.
|
||||
|
||||
Args:
|
||||
n_regimes: Number of hidden regimes to model
|
||||
"""
|
||||
self.n_regimes = n_regimes
|
||||
self.model = None
|
||||
self.is_fitted = False
|
||||
self.regime_labels = ['trending_up', 'trending_down', 'ranging', 'high_volatility']
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def fit(self, returns: np.ndarray, volatility: np.ndarray):
|
||||
"""Fit HMM model to data.
|
||||
|
||||
Args:
|
||||
returns: Price returns array
|
||||
volatility: Volatility array
|
||||
"""
|
||||
if not HAS_HMM:
|
||||
self.logger.warning("HMM not available, using rule-based detection")
|
||||
self.is_fitted = False
|
||||
return
|
||||
|
||||
try:
|
||||
# Prepare features: returns and volatility
|
||||
X = np.column_stack([returns, volatility])
|
||||
|
||||
# Create and fit HMM
|
||||
self.model = hmm.GaussianHMM(
|
||||
n_components=self.n_regimes,
|
||||
covariance_type="full",
|
||||
n_iter=100
|
||||
)
|
||||
self.model.fit(X)
|
||||
self.is_fitted = True
|
||||
|
||||
self.logger.info(f"HMM regime detector fitted with {self.n_regimes} regimes")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fit HMM: {e}")
|
||||
self.is_fitted = False
|
||||
|
||||
def predict(self, returns: float, volatility: float) -> Tuple[int, MarketRegime]:
|
||||
"""Predict regime for given features.
|
||||
|
||||
Args:
|
||||
returns: Price return
|
||||
volatility: Volatility
|
||||
|
||||
Returns:
|
||||
Tuple of (regime_index, MarketRegime enum)
|
||||
"""
|
||||
if not self.is_fitted or self.model is None:
|
||||
return 0, MarketRegime.UNKNOWN
|
||||
|
||||
try:
|
||||
X = np.array([[returns, volatility]])
|
||||
regime_idx = self.model.predict(X)[0]
|
||||
|
||||
# Map to MarketRegime enum
|
||||
if regime_idx < len(self.regime_labels):
|
||||
regime_str = self.regime_labels[regime_idx]
|
||||
regime_map = {
|
||||
'trending_up': MarketRegime.TRENDING_UP,
|
||||
'trending_down': MarketRegime.TRENDING_DOWN,
|
||||
'ranging': MarketRegime.RANGING,
|
||||
'high_volatility': MarketRegime.HIGH_VOLATILITY
|
||||
}
|
||||
return regime_idx, regime_map.get(regime_str, MarketRegime.UNKNOWN)
|
||||
|
||||
return regime_idx, MarketRegime.UNKNOWN
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"HMM prediction failed: {e}")
|
||||
return 0, MarketRegime.UNKNOWN
|
||||
|
||||
|
||||
class GMMRegimeDetector:
|
||||
"""Gaussian Mixture Model-based regime detection."""
|
||||
|
||||
def __init__(self, n_regimes: int = 4):
|
||||
"""Initialize GMM regime detector.
|
||||
|
||||
Args:
|
||||
n_regimes: Number of mixture components
|
||||
"""
|
||||
self.n_regimes = n_regimes
|
||||
self.model = None
|
||||
self.is_fitted = False
|
||||
self.regime_labels = ['trending_up', 'trending_down', 'ranging', 'high_volatility']
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def fit(self, returns: np.ndarray, volatility: np.ndarray):
|
||||
"""Fit GMM model to data.
|
||||
|
||||
Args:
|
||||
returns: Price returns array
|
||||
volatility: Volatility array
|
||||
"""
|
||||
if not HAS_GMM:
|
||||
self.logger.warning("GMM not available")
|
||||
self.is_fitted = False
|
||||
return
|
||||
|
||||
try:
|
||||
# Prepare features
|
||||
X = np.column_stack([returns, volatility])
|
||||
|
||||
# Create and fit GMM
|
||||
self.model = GaussianMixture(
|
||||
n_components=self.n_regimes,
|
||||
covariance_type='full',
|
||||
max_iter=100
|
||||
)
|
||||
self.model.fit(X)
|
||||
self.is_fitted = True
|
||||
|
||||
self.logger.info(f"GMM regime detector fitted with {self.n_regimes} components")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fit GMM: {e}")
|
||||
self.is_fitted = False
|
||||
|
||||
def predict(self, returns: float, volatility: float) -> Tuple[int, MarketRegime]:
|
||||
"""Predict regime for given features.
|
||||
|
||||
Args:
|
||||
returns: Price return
|
||||
volatility: Volatility
|
||||
|
||||
Returns:
|
||||
Tuple of (regime_index, MarketRegime enum)
|
||||
"""
|
||||
if not self.is_fitted or self.model is None:
|
||||
return 0, MarketRegime.UNKNOWN
|
||||
|
||||
try:
|
||||
X = np.array([[returns, volatility]])
|
||||
regime_idx = self.model.predict(X)[0]
|
||||
|
||||
# Map to MarketRegime enum
|
||||
if regime_idx < len(self.regime_labels):
|
||||
regime_str = self.regime_labels[regime_idx]
|
||||
regime_map = {
|
||||
'trending_up': MarketRegime.TRENDING_UP,
|
||||
'trending_down': MarketRegime.TRENDING_DOWN,
|
||||
'ranging': MarketRegime.RANGING,
|
||||
'high_volatility': MarketRegime.HIGH_VOLATILITY
|
||||
}
|
||||
return regime_idx, regime_map.get(regime_str, MarketRegime.UNKNOWN)
|
||||
|
||||
return regime_idx, MarketRegime.UNKNOWN
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"GMM prediction failed: {e}")
|
||||
return 0, MarketRegime.UNKNOWN
|
||||
|
||||
def predict_proba(self, returns: float, volatility: float) -> np.ndarray:
|
||||
"""Get probability distribution over regimes.
|
||||
|
||||
Args:
|
||||
returns: Price return
|
||||
volatility: Volatility
|
||||
|
||||
Returns:
|
||||
Probability array over regimes
|
||||
"""
|
||||
if not self.is_fitted or self.model is None:
|
||||
return np.array([1.0 / self.n_regimes] * self.n_regimes)
|
||||
|
||||
try:
|
||||
X = np.array([[returns, volatility]])
|
||||
return self.model.predict_proba(X)[0]
|
||||
except Exception as e:
|
||||
self.logger.error(f"GMM probability prediction failed: {e}")
|
||||
return np.array([1.0 / self.n_regimes] * self.n_regimes)
|
||||
|
||||
|
||||
class AdvancedRegimeDetector:
|
||||
"""Advanced regime detection combining multiple methods."""
|
||||
|
||||
def __init__(self, method: str = "hmm"):
|
||||
"""Initialize advanced regime detector.
|
||||
|
||||
Args:
|
||||
method: Detection method ('hmm', 'gmm', or 'hybrid')
|
||||
"""
|
||||
self.method = method
|
||||
self.hmm_detector = HMMRegimeDetector() if HAS_HMM else None
|
||||
self.gmm_detector = GMMRegimeDetector() if HAS_GMM else None
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def fit_from_dataframe(self, df: pd.DataFrame):
|
||||
"""Fit detector from OHLCV DataFrame.
|
||||
|
||||
Args:
|
||||
df: OHLCV DataFrame with close and volume columns
|
||||
"""
|
||||
try:
|
||||
# Calculate returns and volatility
|
||||
returns = df['close'].pct_change().dropna().values
|
||||
volatility = df['close'].rolling(20).std().dropna().values
|
||||
|
||||
# Align arrays
|
||||
min_len = min(len(returns), len(volatility))
|
||||
returns = returns[-min_len:]
|
||||
volatility = volatility[-min_len:]
|
||||
|
||||
if len(returns) < 50:
|
||||
self.logger.warning("Insufficient data for regime detection")
|
||||
return
|
||||
|
||||
# Fit models based on method
|
||||
if self.method == "hmm" and self.hmm_detector:
|
||||
self.hmm_detector.fit(returns, volatility)
|
||||
elif self.method == "gmm" and self.gmm_detector:
|
||||
self.gmm_detector.fit(returns, volatility)
|
||||
elif self.method == "hybrid":
|
||||
if self.hmm_detector:
|
||||
self.hmm_detector.fit(returns, volatility)
|
||||
if self.gmm_detector:
|
||||
self.gmm_detector.fit(returns, volatility)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fit regime detector: {e}")
|
||||
|
||||
def detect_regime(
|
||||
self,
|
||||
returns: float,
|
||||
volatility: float
|
||||
) -> MarketRegime:
|
||||
"""Detect regime for given features.
|
||||
|
||||
Args:
|
||||
returns: Price return
|
||||
volatility: Volatility
|
||||
|
||||
Returns:
|
||||
MarketRegime classification
|
||||
"""
|
||||
if self.method == "hmm" and self.hmm_detector and self.hmm_detector.is_fitted:
|
||||
_, regime = self.hmm_detector.predict(returns, volatility)
|
||||
return regime
|
||||
elif self.method == "gmm" and self.gmm_detector and self.gmm_detector.is_fitted:
|
||||
_, regime = self.gmm_detector.predict(returns, volatility)
|
||||
return regime
|
||||
elif self.method == "hybrid":
|
||||
# Use both and combine
|
||||
if self.gmm_detector and self.gmm_detector.is_fitted:
|
||||
_, regime = self.gmm_detector.predict(returns, volatility)
|
||||
return regime
|
||||
elif self.hmm_detector and self.hmm_detector.is_fitted:
|
||||
_, regime = self.hmm_detector.predict(returns, volatility)
|
||||
return regime
|
||||
|
||||
# Fallback to UNKNOWN
|
||||
return MarketRegime.UNKNOWN
|
||||
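A self-contained sketch of the regime detector above on synthetic prices (a geometric random walk, illustrative only); the GMM path is used so only scikit-learn is required.

import numpy as np
import pandas as pd
from src.autopilot.regime_detection import AdvancedRegimeDetector

# Synthetic close prices: a 500-bar geometric random walk (illustrative only)
rng = np.random.default_rng(0)
closes = pd.Series(100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.01, 500))))
df = pd.DataFrame({"close": closes, "volume": 1.0})

detector = AdvancedRegimeDetector(method="gmm")
detector.fit_from_dataframe(df)

latest_return = float(df["close"].pct_change().iloc[-1])
latest_volatility = float(df["close"].rolling(20).std().iloc[-1])
print(detector.detect_regime(latest_return, latest_volatility))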
src/backtesting/monte_carlo.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
"""Monte Carlo simulation for backtesting risk analysis."""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any
|
||||
from src.core.logger import get_logger
|
||||
from .engine import BacktestingEngine
|
||||
from .metrics import BacktestMetrics
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MonteCarloSimulator:
|
||||
"""Monte Carlo simulation for backtesting with random parameter variations."""
|
||||
|
||||
def __init__(self, backtest_engine: BacktestingEngine):
|
||||
"""Initialize Monte Carlo simulator.
|
||||
|
||||
Args:
|
||||
backtest_engine: BacktestingEngine instance
|
||||
"""
|
||||
self.backtest_engine = backtest_engine
|
||||
self.metrics = BacktestMetrics()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def run_monte_carlo(
|
||||
self,
|
||||
strategy_class,
|
||||
symbol: str,
|
||||
exchange: str,
|
||||
timeframe: str,
|
||||
start_date,
|
||||
end_date,
|
||||
initial_capital: Decimal = Decimal("10000.0"),
|
||||
num_simulations: int = 1000,
|
||||
parameter_ranges: Optional[Dict[str, tuple]] = None,
|
||||
random_seed: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run Monte Carlo simulation with random parameter variations.
|
||||
|
||||
Args:
|
||||
strategy_class: Strategy class to test
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
timeframe: Timeframe
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
initial_capital: Initial capital
|
||||
num_simulations: Number of simulations to run
|
||||
parameter_ranges: Dictionary of parameter_name -> (min, max) for random sampling
|
||||
random_seed: Random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
Monte Carlo results dictionary with distributions
|
||||
"""
|
||||
if random_seed is not None:
|
||||
np.random.seed(random_seed)
|
||||
|
||||
results = []
|
||||
|
||||
for i in range(num_simulations):
|
||||
try:
|
||||
# Sample random parameters if ranges provided
|
||||
if parameter_ranges:
|
||||
params = {}
|
||||
for param_name, (min_val, max_val) in parameter_ranges.items():
|
||||
if isinstance(min_val, int) and isinstance(max_val, int):
|
||||
params[param_name] = np.random.randint(min_val, max_val + 1)
|
||||
else:
|
||||
params[param_name] = np.random.uniform(min_val, max_val)
|
||||
|
||||
strategy_instance = strategy_class(**params)
|
||||
else:
|
||||
strategy_instance = strategy_class()
|
||||
|
||||
# Run backtest
|
||||
backtest_result = await self.backtest_engine.run_backtest(
|
||||
strategy=strategy_instance,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timeframe=timeframe,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
initial_capital=initial_capital
|
||||
)
|
||||
|
||||
if "error" not in backtest_result:
|
||||
results.append(backtest_result)
|
||||
|
||||
if (i + 1) % 100 == 0:
|
||||
self.logger.info(f"Monte Carlo: {i + 1}/{num_simulations} simulations completed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in Monte Carlo simulation {i + 1}: {e}")
|
||||
continue
|
||||
|
||||
if not results:
|
||||
return {"error": "No valid Monte Carlo simulations"}
|
||||
|
||||
# Calculate distributions
|
||||
returns = [r.get('total_return', 0) for r in results]
|
||||
sharpe_ratios = [r.get('sharpe_ratio', 0) for r in results]
|
||||
max_drawdowns = [r.get('max_drawdown', 0) for r in results]
|
||||
win_rates = [r.get('win_rate', 0) for r in results]
|
||||
|
||||
# Calculate statistics
|
||||
return {
|
||||
'num_simulations': num_simulations,
|
||||
'valid_simulations': len(results),
|
||||
'return_distribution': {
|
||||
'mean': float(np.mean(returns)),
|
||||
'std': float(np.std(returns)),
|
||||
'min': float(np.min(returns)),
|
||||
'max': float(np.max(returns)),
|
||||
'percentile_5': float(np.percentile(returns, 5)),
|
||||
'percentile_25': float(np.percentile(returns, 25)),
|
||||
'percentile_50': float(np.percentile(returns, 50)),
|
||||
'percentile_75': float(np.percentile(returns, 75)),
|
||||
'percentile_95': float(np.percentile(returns, 95)),
|
||||
},
|
||||
'sharpe_distribution': {
|
||||
'mean': float(np.mean(sharpe_ratios)),
|
||||
'std': float(np.std(sharpe_ratios)),
|
||||
'min': float(np.min(sharpe_ratios)),
|
||||
'max': float(np.max(sharpe_ratios)),
|
||||
},
|
||||
'drawdown_distribution': {
|
||||
'mean': float(np.mean(max_drawdowns)),
|
||||
'std': float(np.std(max_drawdowns)),
|
||||
'worst': float(np.min(max_drawdowns)),
|
||||
},
|
||||
'win_rate_distribution': {
|
||||
'mean': float(np.mean(win_rates)),
|
||||
'std': float(np.std(win_rates)),
|
||||
},
|
||||
'confidence_intervals': {
|
||||
'return_95_lower': float(np.percentile(returns, 2.5)),
|
||||
'return_95_upper': float(np.percentile(returns, 97.5)),
|
||||
'return_99_lower': float(np.percentile(returns, 0.5)),
|
||||
'return_99_upper': float(np.percentile(returns, 99.5)),
|
||||
},
|
||||
'all_results': results[:100] # Return first 100 for detailed analysis
|
||||
}
|
||||
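A hedged sketch of a Monte Carlo run. The placeholder strategy class, the argument-free BacktestingEngine construction, and the parameter ranges are assumptions for illustration.

import asyncio
from datetime import datetime
from decimal import Decimal
from src.backtesting.engine import BacktestingEngine
from src.backtesting.monte_carlo import MonteCarloSimulator

async def run_mc(strategy_class):
    simulator = MonteCarloSimulator(BacktestingEngine())  # argument-free engine construction assumed
    results = await simulator.run_monte_carlo(
        strategy_class=strategy_class,
        symbol="BTC/USDT",
        exchange="binance",
        timeframe="1h",
        start_date=datetime(2024, 1, 1),
        end_date=datetime(2024, 6, 30),
        initial_capital=Decimal("10000.0"),
        num_simulations=500,
        parameter_ranges={"rsi_period": (10, 21), "stop_loss_pct": (0.01, 0.05)},  # illustrative ranges
        random_seed=7,
    )
    print(results.get("return_distribution", results))
    return results

# asyncio.run(run_mc(MyStrategy))  # MyStrategy is a placeholder strategy class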
src/backtesting/walk_forward.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
"""Walk-forward analysis for robust backtesting."""
|
||||
|
||||
import pandas as pd
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from src.core.logger import get_logger
|
||||
from src.strategies.base import BaseStrategy
|
||||
from .engine import BacktestingEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class WalkForwardAnalyzer:
|
||||
"""Walk-forward analysis for robust parameter optimization."""
|
||||
|
||||
def __init__(self, backtest_engine: BacktestingEngine):
|
||||
"""Initialize walk-forward analyzer.
|
||||
|
||||
Args:
|
||||
backtest_engine: BacktestingEngine instance
|
||||
"""
|
||||
self.backtest_engine = backtest_engine
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def run_walk_forward(
|
||||
self,
|
||||
strategy_class,
|
||||
symbol: str,
|
||||
exchange: str,
|
||||
timeframe: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
train_period_days: int = 90,
|
||||
test_period_days: int = 30,
|
||||
step_days: int = 30,
|
||||
initial_capital: Decimal = Decimal("10000.0"),
|
||||
parameter_grid: Optional[Dict[str, List[Any]]] = None,
|
||||
optimization_metric: str = "sharpe_ratio"
|
||||
) -> Dict[str, Any]:
|
||||
"""Run walk-forward optimization.
|
||||
|
||||
Args:
|
||||
strategy_class: Strategy class to test
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
timeframe: Timeframe
|
||||
start_date: Start date for analysis
|
||||
end_date: End date for analysis
|
||||
train_period_days: Training period in days
|
||||
test_period_days: Testing period in days
|
||||
step_days: Step size between windows in days
|
||||
initial_capital: Initial capital per window
|
||||
parameter_grid: Parameter grid for optimization {param_name: [values]}
|
||||
optimization_metric: Metric to optimize (sharpe_ratio, total_return, etc.)
|
||||
|
||||
Returns:
|
||||
Walk-forward results dictionary
|
||||
"""
|
||||
windows = self._generate_windows(
|
||||
start_date, end_date, train_period_days, test_period_days, step_days
|
||||
)
|
||||
|
||||
results = []
|
||||
best_params_history = []
|
||||
|
||||
for i, window in enumerate(windows):
|
||||
train_start, train_end, test_start, test_end = window
|
||||
|
||||
self.logger.info(
|
||||
f"Walk-forward window {i+1}/{len(windows)}: "
|
||||
f"Train {train_start.date()} to {train_end.date()}, "
|
||||
f"Test {test_start.date()} to {test_end.date()}"
|
||||
)
|
||||
|
||||
# Optimize parameters on training period
|
||||
if parameter_grid:
|
||||
best_params, train_results = await self._optimize_parameters(
|
||||
strategy_class=strategy_class,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timeframe=timeframe,
|
||||
start_date=train_start,
|
||||
end_date=train_end,
|
||||
parameter_grid=parameter_grid,
|
||||
initial_capital=initial_capital,
|
||||
optimization_metric=optimization_metric
|
||||
)
|
||||
best_params_history.append(best_params)
|
||||
else:
|
||||
# Use default parameters
|
||||
best_params = {}
|
||||
|
||||
# Test on out-of-sample period
|
||||
strategy_instance = strategy_class(**best_params)
|
||||
test_results = await self.backtest_engine.run_backtest(
|
||||
strategy=strategy_instance,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timeframe=timeframe,
|
||||
start_date=test_start,
|
||||
end_date=test_end,
|
||||
initial_capital=initial_capital
|
||||
)
|
||||
|
||||
if "error" not in test_results:
|
||||
test_results['window'] = i + 1
|
||||
test_results['train_start'] = train_start.isoformat()
|
||||
test_results['train_end'] = train_end.isoformat()
|
||||
test_results['test_start'] = test_start.isoformat()
|
||||
test_results['test_end'] = test_end.isoformat()
|
||||
test_results['parameters'] = best_params
|
||||
results.append(test_results)
|
||||
|
||||
# Aggregate results
|
||||
if not results:
|
||||
return {"error": "No valid walk-forward windows"}
|
||||
|
||||
return self._aggregate_results(results, best_params_history)
|
||||
|
||||
def _generate_windows(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
train_period_days: int,
|
||||
test_period_days: int,
|
||||
step_days: int
|
||||
) -> List[Tuple[datetime, datetime, datetime, datetime]]:
|
||||
"""Generate walk-forward windows.
|
||||
|
||||
Args:
|
||||
start_date: Overall start date
|
||||
end_date: Overall end date
|
||||
train_period_days: Training period length
|
||||
test_period_days: Testing period length
|
||||
step_days: Step between windows
|
||||
|
||||
Returns:
|
||||
List of (train_start, train_end, test_start, test_end) tuples
|
||||
"""
|
||||
windows = []
|
||||
current_start = start_date
|
||||
|
||||
while current_start < end_date:
|
||||
train_start = current_start
|
||||
train_end = train_start + timedelta(days=train_period_days)
|
||||
test_start = train_end
|
||||
test_end = test_start + timedelta(days=test_period_days)
|
||||
|
||||
if test_end > end_date:
|
||||
break
|
||||
|
||||
windows.append((train_start, train_end, test_start, test_end))
|
||||
current_start += timedelta(days=step_days)
|
||||
|
||||
return windows
|
||||
|
||||
async def _optimize_parameters(
|
||||
self,
|
||||
strategy_class,
|
||||
symbol: str,
|
||||
exchange: str,
|
||||
timeframe: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
parameter_grid: Dict[str, List[Any]],
|
||||
initial_capital: Decimal,
|
||||
optimization_metric: str = "sharpe_ratio"
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Optimize strategy parameters using grid search.
|
||||
|
||||
Args:
|
||||
strategy_class: Strategy class
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
timeframe: Timeframe
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
parameter_grid: Parameter grid
|
||||
initial_capital: Initial capital
|
||||
optimization_metric: Metric to optimize
|
||||
|
||||
Returns:
|
||||
(best_parameters, best_results) tuple
|
||||
"""
|
||||
from itertools import product
|
||||
|
||||
# Generate all parameter combinations
|
||||
param_names = list(parameter_grid.keys())
|
||||
param_values = list(parameter_grid.values())
|
||||
combinations = list(product(*param_values))
|
||||
|
||||
best_metric = float('-inf')
|
||||
best_params = {}
|
||||
best_results = {}
|
||||
|
||||
for combo in combinations:
|
||||
params = dict(zip(param_names, combo))
|
||||
strategy_instance = strategy_class(**params)
|
||||
|
||||
results = await self.backtest_engine.run_backtest(
|
||||
strategy=strategy_instance,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timeframe=timeframe,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
initial_capital=initial_capital
|
||||
)
|
||||
|
||||
if "error" in results:
|
||||
continue
|
||||
|
||||
# Get optimization metric
|
||||
metric_value = results.get(optimization_metric, float('-inf'))
|
||||
if metric_value is None:
|
||||
metric_value = float('-inf')
|
||||
|
||||
if metric_value > best_metric:
|
||||
best_metric = metric_value
|
||||
best_params = params
|
||||
best_results = results
|
||||
|
||||
return best_params, best_results
|
||||
|
||||
    def _aggregate_results(
        self,
        results: List[Dict[str, Any]],
        best_params_history: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate walk-forward results.

        Args:
            results: List of window results
            best_params_history: List of best parameters per window

        Returns:
            Aggregated results dictionary
        """
        # Calculate aggregate metrics
        total_returns = [r.get('total_return', 0) for r in results if r.get('total_return') is not None]
        sharpe_ratios = [r.get('sharpe_ratio', 0) for r in results if r.get('sharpe_ratio') is not None]
        max_drawdowns = [r.get('max_drawdown', 0) for r in results if r.get('max_drawdown') is not None]
        win_rates = [r.get('win_rate', 0) for r in results if r.get('win_rate') is not None]

        # Parameter stability analysis
        param_stability = self._analyze_parameter_stability(best_params_history)

        return {
            'num_windows': len(results),
            'window_results': results,
            'aggregate_metrics': {
                'avg_return': sum(total_returns) / len(total_returns) if total_returns else 0,
                'avg_sharpe': sum(sharpe_ratios) / len(sharpe_ratios) if sharpe_ratios else 0,
                'avg_max_drawdown': sum(max_drawdowns) / len(max_drawdowns) if max_drawdowns else 0,
                'avg_win_rate': sum(win_rates) / len(win_rates) if win_rates else 0,
                'min_return': min(total_returns) if total_returns else 0,
                'max_return': max(total_returns) if total_returns else 0,
                'std_return': pd.Series(total_returns).std() if total_returns else 0,
            },
            'parameter_stability': param_stability,
            'best_parameters': self._get_most_common_params(best_params_history)
        }
    def _analyze_parameter_stability(
        self,
        params_history: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Analyze parameter stability across windows.

        Args:
            params_history: List of parameter dictionaries

        Returns:
            Stability analysis dictionary
        """
        if not params_history:
            return {}

        # Count parameter occurrences
        param_counts = {}
        for params in params_history:
            for key, value in params.items():
                if key not in param_counts:
                    param_counts[key] = {}
                value_str = str(value)
                param_counts[key][value_str] = param_counts[key].get(value_str, 0) + 1

        # Calculate stability (most common value frequency)
        stability = {}
        for param_name, counts in param_counts.items():
            if counts:
                max_count = max(counts.values())
                total_windows = sum(counts.values())
                stability[param_name] = {
                    'most_common': max(counts.items(), key=lambda x: x[1])[0],
                    'frequency': max_count / total_windows if total_windows > 0 else 0,
                    'counts': counts
                }

        return stability
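    # Illustrative sketch (not part of this commit): with a hypothetical
    # params_history of three windows,
    #   [{"fast_ma": 10}, {"fast_ma": 10}, {"fast_ma": 20}]
    # the method above reports
    #   {"fast_ma": {"most_common": "10", "frequency": 2 / 3,
    #                "counts": {"10": 2, "20": 1}}}
    # i.e. the chosen value was stable in two of the three windows.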
    def _get_most_common_params(
        self,
        params_history: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Get most common parameters across all windows.

        Args:
            params_history: List of parameter dictionaries

        Returns:
            Most common parameter values
        """
        if not params_history:
            return {}

        param_counts = {}
        for params in params_history:
            for key, value in params.items():
                if key not in param_counts:
                    param_counts[key] = {}
                value_str = str(value)
                param_counts[key][value_str] = param_counts[key].get(value_str, 0) + 1

        most_common = {}
        for param_name, counts in param_counts.items():
            if counts:
                most_common_value_str = max(counts.items(), key=lambda x: x[1])[0]
                # Try to convert back to original type
                try:
                    # Try float first
                    most_common[param_name] = float(most_common_value_str)
                except ValueError:
                    try:
                        # Try int
                        most_common[param_name] = int(most_common_value_str)
                    except ValueError:
                        # Keep as string
                        most_common[param_name] = most_common_value_str

        return most_common
275
src/portfolio/correlation_analyzer.py
Normal file
@@ -0,0 +1,275 @@
"""Portfolio correlation analysis for risk management."""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import select
|
||||
from src.core.database import get_database, Position, PortfolioSnapshot, MarketData
|
||||
from src.core.logger import get_logger
|
||||
from src.portfolio.tracker import get_portfolio_tracker
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CorrelationAnalyzer:
|
||||
"""Analyzes portfolio correlation for risk management."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize correlation analyzer."""
|
||||
self.db = get_database()
|
||||
self.tracker = get_portfolio_tracker()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def calculate_correlation_matrix(
|
||||
self,
|
||||
symbols: List[str],
|
||||
timeframe: str = "1h",
|
||||
lookback_days: int = 90,
|
||||
exchange: str = "coinbase"
|
||||
) -> Dict[str, Any]:
|
||||
"""Calculate correlation matrix for given symbols.
|
||||
|
||||
Args:
|
||||
symbols: List of symbols to analyze
|
||||
timeframe: Timeframe for data
|
||||
lookback_days: Number of days to look back
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Dictionary with correlation matrix and metrics
|
||||
"""
|
||||
# Get price data for all symbols
|
||||
price_data = {}
|
||||
|
||||
for symbol in symbols:
|
||||
prices = await self._get_price_series(symbol, timeframe, lookback_days, exchange)
|
||||
if len(prices) > 0:
|
||||
price_data[symbol] = prices
|
||||
|
||||
if len(price_data) < 2:
|
||||
return {"error": "Insufficient data for correlation analysis"}
|
||||
|
||||
# Create DataFrame
|
||||
df = pd.DataFrame(price_data)
|
||||
|
||||
# Calculate returns
|
||||
returns_df = df.pct_change().dropna()
|
||||
|
||||
# Calculate correlation matrix
|
||||
correlation_matrix = returns_df.corr()
|
||||
|
||||
# Calculate average correlation
|
||||
# Get upper triangle (excluding diagonal)
|
||||
upper_triangle = correlation_matrix.where(
|
||||
np.triu(np.ones(correlation_matrix.shape), k=1).astype(bool)
|
||||
)
|
||||
avg_correlation = upper_triangle.stack().mean()
|
||||
|
||||
# Find highest correlations
|
||||
high_correlations = []
|
||||
for i, symbol1 in enumerate(correlation_matrix.columns):
|
||||
for j, symbol2 in enumerate(correlation_matrix.columns):
|
||||
if i < j: # Upper triangle only
|
||||
corr = correlation_matrix.loc[symbol1, symbol2]
|
||||
if corr > 0.7: # High correlation threshold
|
||||
high_correlations.append({
|
||||
'symbol1': symbol1,
|
||||
'symbol2': symbol2,
|
||||
'correlation': float(corr)
|
||||
})
|
||||
|
||||
# Sort by correlation
|
||||
high_correlations.sort(key=lambda x: x['correlation'], reverse=True)
|
||||
|
||||
return {
|
||||
'correlation_matrix': correlation_matrix.to_dict(),
|
||||
'symbols': symbols,
|
||||
'avg_correlation': float(avg_correlation),
|
||||
'high_correlations': high_correlations,
|
||||
'lookback_days': lookback_days,
|
||||
'timeframe': timeframe
|
||||
}
|
||||
|
||||
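    # Illustrative sketch (not part of this commit) of the upper-triangle average
    # used above, on a toy returns frame with made-up symbols and numbers:
    #
    #   import numpy as np, pandas as pd
    #   returns_df = pd.DataFrame({"BTC/USD": [0.01, -0.02, 0.03],
    #                              "ETH/USD": [0.02, -0.01, 0.02]})
    #   corr = returns_df.corr()
    #   mask = np.triu(np.ones(corr.shape), k=1).astype(bool)
    #   avg_correlation = corr.where(mask).stack().mean()
    #   # with only two symbols this is just their single pairwise correlation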
    async def analyze_portfolio_correlation(
        self,
        paper_trading: bool = True,
        lookback_days: int = 90
    ) -> Dict[str, Any]:
        """Analyze correlation of current portfolio holdings.

        Args:
            paper_trading: Paper trading flag
            lookback_days: Lookback period

        Returns:
            Portfolio correlation analysis
        """
        # Get current portfolio
        portfolio = await self.tracker.get_current_portfolio(paper_trading)

        # Extract symbols with positions
        symbols = [pos['symbol'] for pos in portfolio['positions'] if pos['quantity'] > 0]

        if len(symbols) < 2:
            return {"error": "Portfolio has fewer than 2 positions"}

        # Calculate correlation matrix
        correlation_data = await self.calculate_correlation_matrix(symbols, lookback_days=lookback_days)

        if "error" in correlation_data:
            return correlation_data

        # Calculate diversification score (lower correlation = better diversification)
        diversification_score = 1.0 - abs(correlation_data['avg_correlation'])
        diversification_score = max(0.0, min(1.0, diversification_score))  # Clamp to [0, 1]

        # Risk concentration analysis
        position_values = {
            pos['symbol']: pos['quantity'] * pos['current_price']
            for pos in portfolio['positions']
            if pos['quantity'] > 0
        }
        total_value = sum(position_values.values())

        concentration_risk = []
        for symbol, value in position_values.items():
            weight = value / total_value if total_value > 0 else 0
            # Check correlation with other positions
            symbol_correlations = []
            for other_symbol in symbols:
                if other_symbol != symbol:
                    corr_matrix = correlation_data['correlation_matrix']
                    if symbol in corr_matrix and other_symbol in corr_matrix[symbol]:
                        corr = corr_matrix[symbol][other_symbol]
                        symbol_correlations.append(abs(corr))

            avg_corr = np.mean(symbol_correlations) if symbol_correlations else 0
            concentration_risk.append({
                'symbol': symbol,
                'weight': float(weight),
                'value': float(value),
                'avg_correlation': float(avg_corr),
                'risk_score': float(weight * (1 + avg_corr))  # Higher weight + correlation = higher risk
            })

        concentration_risk.sort(key=lambda x: x['risk_score'], reverse=True)

        return {
            **correlation_data,
            'diversification_score': float(diversification_score),
            'concentration_risk': concentration_risk,
            'portfolio_value': float(total_value),
            'num_positions': len(symbols)
        }
    async def check_correlation_limits(
        self,
        symbol: str,
        new_position_value: Decimal,
        max_correlation: float = 0.8,
        paper_trading: bool = True
    ) -> tuple[bool, Optional[str]]:
        """Check if adding a position would exceed correlation limits.

        Args:
            symbol: Symbol to add
            new_position_value: Value of new position
            max_correlation: Maximum allowed correlation
            paper_trading: Paper trading flag

        Returns:
            (allowed, reason) tuple
        """
        portfolio = await self.tracker.get_current_portfolio(paper_trading)
        existing_symbols = [pos['symbol'] for pos in portfolio['positions'] if pos['quantity'] > 0]

        if symbol in existing_symbols:
            return True, None  # Already in portfolio

        if len(existing_symbols) == 0:
            return True, None  # First position

        # Check correlation with existing positions
        all_symbols = existing_symbols + [symbol]
        correlation_data = await self.calculate_correlation_matrix(all_symbols)

        if "error" in correlation_data:
            return True, None  # Can't calculate correlation, allow it

        # Check max correlation with any existing position
        corr_matrix = correlation_data['correlation_matrix']
        max_corr = 0.0
        max_corr_symbol = None

        if symbol in corr_matrix:
            for existing_symbol in existing_symbols:
                if existing_symbol in corr_matrix[symbol]:
                    corr = abs(corr_matrix[symbol][existing_symbol])
                    if corr > max_corr:
                        max_corr = corr
                        max_corr_symbol = existing_symbol

        if max_corr > max_correlation:
            return False, f"Correlation {max_corr:.2f} with {max_corr_symbol} exceeds limit {max_correlation}"

        return True, None
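    # Illustrative usage sketch (not part of this commit), assuming an asyncio
    # caller, a hypothetical symbol, and the global accessor defined at the bottom
    # of this module:
    #
    #   analyzer = get_correlation_analyzer()
    #   allowed, reason = await analyzer.check_correlation_limits(
    #       symbol="SOL/USD",
    #       new_position_value=Decimal("1000"),
    #       max_correlation=0.8,
    #   )
    #   if not allowed:
    #       logger.warning(f"Skipping entry: {reason}")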
    async def _get_price_series(
        self,
        symbol: str,
        timeframe: str,
        lookback_days: int,
        exchange: str
    ) -> pd.Series:
        """Get price series for a symbol.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            lookback_days: Lookback period
            exchange: Exchange name

        Returns:
            Series of closing prices
        """
        try:
            async with self.db.get_session() as session:
                end_date = datetime.utcnow()
                start_date = end_date - timedelta(days=lookback_days)

                stmt = select(MarketData).filter(
                    MarketData.symbol == symbol,
                    MarketData.timeframe == timeframe,
                    MarketData.exchange == exchange,
                    MarketData.timestamp >= start_date,
                    MarketData.timestamp <= end_date
                ).order_by(MarketData.timestamp)

                result = await session.execute(stmt)
                data = result.scalars().all()

                if len(data) == 0:
                    return pd.Series(dtype=float)

                prices = [float(d.close) for d in data]
                timestamps = [d.timestamp for d in data]

                return pd.Series(prices, index=timestamps)

        except Exception as e:
            self.logger.error(f"Error getting price series for {symbol}: {e}")
            return pd.Series(dtype=float)
# Global correlation analyzer
_correlation_analyzer: Optional[CorrelationAnalyzer] = None


def get_correlation_analyzer() -> CorrelationAnalyzer:
    """Get global correlation analyzer instance."""
    global _correlation_analyzer
    if _correlation_analyzer is None:
        _correlation_analyzer = CorrelationAnalyzer()
    return _correlation_analyzer
@@ -2,8 +2,8 @@

from decimal import Decimal
from typing import Dict, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from sqlalchemy import select
from src.core.database import get_database, RebalancingEvent
from src.core.logger import get_logger
from src.portfolio.tracker import get_portfolio_tracker
@@ -22,11 +22,12 @@ class RebalancingEngine:
        self.trading_engine = get_trading_engine()
        self.logger = get_logger(__name__)

    def rebalance(
    async def rebalance(
        self,
        target_allocations: Dict[str, float],
        exchange_id: int,
        paper_trading: bool = True
        paper_trading: bool = True,
        threshold: float = 0.05  # 5% drift threshold
    ) -> bool:
        """Rebalance portfolio to target allocations.
@@ -40,37 +41,28 @@
        """
        try:
            # Get current portfolio
            portfolio = self.tracker.get_current_portfolio(paper_trading)
            total_value = portfolio['performance']['current_value']
            portfolio = await self.tracker.get_current_portfolio(paper_trading)
            total_value = Decimal(str(portfolio['performance']['current_value']))

            # Calculate current allocations
            current_allocations = {}
            for pos in portfolio['positions']:
                pos_value = pos['quantity'] * pos['current_price']
                # Convert to Decimal before calculation (same pattern as check_and_rebalance_threshold)
                pos_value = Decimal(str(pos['quantity'])) * Decimal(str(pos['current_price']))
                current_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0

            # Get exchange adapter for fee calculations
            adapter = await self.trading_engine.get_exchange_adapter(exchange_id)

            # Calculate required trades, factoring in fees
            # Calculate required trades
            orders = []
            from src.trading.fee_calculator import get_fee_calculator
            fee_calculator = get_fee_calculator()

            # Get fee threshold from config (default 0.5% to account for round-trip fees)
            fee_threshold = Decimal(str(self.tracker.db.get_session().query(
                # Get from config
            ))) if False else Decimal("0.005")  # 0.5% default threshold

            for symbol, target_pct in target_allocations.items():
                current_pct = current_allocations.get(symbol, 0.0)
                deviation = target_pct - current_pct

                # Only rebalance if deviation exceeds fee threshold
                # Default threshold is 1%, but we'll use a configurable fee-aware threshold
                min_deviation = max(Decimal("0.01"), fee_threshold)  # At least 1% or fee threshold

                if abs(deviation) > min_deviation:
                # Only rebalance if deviation exceeds threshold
                if abs(deviation) > threshold:
                    target_value = total_value * Decimal(str(target_pct))
                    current_value = Decimal(str(current_allocations.get(symbol, 0.0))) * total_value
                    trade_value = target_value - current_value
@@ -81,24 +73,7 @@
                    price = ticker.get('last', Decimal(0))

                    if price > 0:
                        # Estimate fee for this trade
                        estimated_quantity = abs(trade_value / price)
                        estimated_fee = fee_calculator.estimate_round_trip_fee(
                            quantity=estimated_quantity,
                            price=price,
                            exchange_adapter=adapter
                        )

                        # Adjust trade value to account for fees
                        # For buy: reduce quantity to account for fee
                        # For sell: fee comes from proceeds
                        if trade_value > 0:  # Buy
                            # Reduce trade value by estimated fee
                            adjusted_trade_value = trade_value - estimated_fee
                            quantity = adjusted_trade_value / price if price > 0 else Decimal(0)
                        else:  # Sell
                            # Fee comes from proceeds, so quantity stays the same
                            quantity = abs(trade_value / price)
                        quantity = abs(trade_value / price)

                        if quantity > 0:
                            side = 'buy' if trade_value > 0 else 'sell'
@@ -129,11 +104,12 @@
                        executed_orders.append(result.id)

            # Record rebalancing event
            self._record_rebalancing_event(
            await self._record_rebalancing_event(
                'manual',
                target_allocations,
                current_allocations,
                executed_orders
                executed_orders,
                paper_trading
            )

            return True
@@ -141,46 +117,134 @@
            logger.error(f"Failed to rebalance portfolio: {e}")
            return False

    def _record_rebalancing_event(
    async def check_and_rebalance_threshold(
        self,
        target_allocations: Dict[str, float],
        exchange_id: int,
        threshold: float = 0.05,
        paper_trading: bool = True
    ) -> bool:
        """Check if rebalancing is needed based on threshold and rebalance if needed.

        Args:
            target_allocations: Target allocations
            exchange_id: Exchange ID
            threshold: Drift threshold (default 5%)
            paper_trading: Paper trading flag

        Returns:
            True if rebalancing was triggered
        """
        portfolio = await self.tracker.get_current_portfolio(paper_trading)
        total_value = Decimal(str(portfolio['performance']['current_value']))

        current_allocations = {}
        for pos in portfolio['positions']:
            pos_value = Decimal(str(pos['quantity'])) * Decimal(str(pos['current_price']))
            current_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0

        # Check if any allocation exceeds threshold
        needs_rebalance = False
        for symbol, target_pct in target_allocations.items():
            current_pct = current_allocations.get(symbol, 0.0)
            deviation = abs(target_pct - current_pct)
            if deviation > threshold:
                needs_rebalance = True
                break

        if needs_rebalance:
            return await self.rebalance(target_allocations, exchange_id, paper_trading, threshold)

        return False

    async def check_and_rebalance_time(
        self,
        target_allocations: Dict[str, float],
        exchange_id: int,
        last_rebalance_key: str,
        rebalance_interval_hours: int = 24,
        paper_trading: bool = True
    ) -> bool:
        """Check if rebalancing is needed based on time interval and rebalance if needed.

        Args:
            target_allocations: Target allocations
            exchange_id: Exchange ID
            last_rebalance_key: Key for storing last rebalance time
            rebalance_interval_hours: Hours between rebalances
            paper_trading: Paper trading flag

        Returns:
            True if rebalancing was triggered
        """
        # Check last rebalance time
        async with self.db.get_session() as session:
            from src.core.database import AppState
            stmt = select(AppState).filter_by(key=last_rebalance_key)
            result = await session.execute(stmt)
            state = result.scalar_one_or_none()

            if state:
                last_rebalance = datetime.fromisoformat(state.value) if isinstance(state.value, str) else state.value
                time_since_rebalance = datetime.utcnow() - last_rebalance
                if time_since_rebalance < timedelta(hours=rebalance_interval_hours):
                    return False

            # Time to rebalance
            success = await self.rebalance(target_allocations, exchange_id, paper_trading)

            if success:
                # Update last rebalance time
                from src.core.database import AppState
                if state:
                    state.value = datetime.utcnow().isoformat()
                else:
                    state = AppState(key=last_rebalance_key, value=datetime.utcnow().isoformat())
                    session.add(state)
                await session.commit()

            return success

    async def _record_rebalancing_event(
        self,
        trigger_type: str,
        target_allocations: Dict[str, float],
        before_allocations: Dict[str, float],
        orders_placed: List[int]
        orders_placed: List[int],
        paper_trading: bool = True
    ):
        """Record rebalancing event in database.

        Args:
            trigger_type: Trigger type
            trigger_type: Trigger type (manual, threshold, time)
            target_allocations: Target allocations
            before_allocations: Allocations before rebalancing
            orders_placed: List of order IDs
            paper_trading: Paper trading flag
        """
        session = self.db.get_session()
        try:
            # Get after allocations
            portfolio = self.tracker.get_current_portfolio()
            total_value = portfolio['performance']['current_value']
            after_allocations = {}
            for pos in portfolio['positions']:
                pos_value = pos['quantity'] * pos['current_price']
                after_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0

            event = RebalancingEvent(
                trigger_type=trigger_type,
                target_allocations=target_allocations,
                before_allocations=before_allocations,
                after_allocations=after_allocations,
                orders_placed=orders_placed,
                timestamp=datetime.utcnow()
            )
            session.add(event)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to record rebalancing event: {e}")
        finally:
            session.close()
        async with self.db.get_session() as session:
            try:
                # Get after allocations
                portfolio = await self.tracker.get_current_portfolio(paper_trading)
                total_value = Decimal(str(portfolio['performance']['current_value']))
                after_allocations = {}
                for pos in portfolio['positions']:
                    pos_value = Decimal(str(pos['quantity'])) * Decimal(str(pos['current_price']))
                    after_allocations[pos['symbol']] = float(pos_value / total_value) if total_value > 0 else 0.0

                event = RebalancingEvent(
                    trigger_type=trigger_type,
                    target_allocations=target_allocations,
                    before_allocations=before_allocations,
                    after_allocations=after_allocations,
                    orders_placed=orders_placed,
                    timestamp=datetime.utcnow()
                )
                session.add(event)
                await session.commit()
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to record rebalancing event: {e}")


# Global rebalancing engine
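A minimal standalone sketch (not part of the diff above) of the drift check that check_and_rebalance_threshold performs; the symbols and weights are made up for illustration:

def needs_rebalance(target: dict[str, float], current: dict[str, float], threshold: float = 0.05) -> bool:
    # True as soon as any symbol drifts more than `threshold` away from its target weight.
    return any(
        abs(target_pct - current.get(symbol, 0.0)) > threshold
        for symbol, target_pct in target.items()
    )

# BTC target 60% vs. actual 52% is an 8% drift, above the 5% default threshold.
print(needs_rebalance({"BTC": 0.60, "ETH": 0.40}, {"BTC": 0.52, "ETH": 0.48}))  # True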
@@ -1,10 +1,13 @@
"""Position sizing rules."""

from decimal import Decimal
from typing import Optional
from typing import Optional, Dict, Any
import pandas as pd
import numpy as np
from src.core.config import get_config
from src.core.logger import get_logger
from src.exchanges.base import BaseExchangeAdapter
from src.data.indicators import get_indicators

logger = get_logger(__name__)
@@ -68,24 +71,146 @@ class PositionSizingManager:
        self,
        win_rate: float,
        avg_win: float,
        avg_loss: float
        avg_loss: float,
        fractional: float = 0.25
    ) -> Decimal:
        """Calculate position size using Kelly Criterion.
        """Calculate position size using Kelly Criterion with fractional Kelly.

        Args:
            win_rate: Win rate (0.0 to 1.0)
            avg_win: Average win amount
            avg_loss: Average loss amount
            fractional: Fractional Kelly multiplier (0.25 = quarter Kelly, 0.5 = half Kelly)

        Returns:
            Kelly percentage
            Fractional Kelly percentage
        """
        if avg_loss == 0:
        if avg_loss == 0 or avg_win == 0:
            return Decimal(0)

        # Kelly formula: (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
        kelly = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
        # Use fractional Kelly (half) for safety
        return Decimal(str(kelly / 2))

        # Apply fractional Kelly and ensure non-negative
        fractional_kelly = max(0.0, kelly * fractional)

        return Decimal(str(fractional_kelly))
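    # Illustrative arithmetic for the fractional Kelly sizing above (not part of
    # this commit). With win_rate=0.55, avg_win=100, avg_loss=80:
    #   kelly = (0.55 * 100 - 0.45 * 80) / 100 = (55 - 36) / 100 = 0.19
    # and with the default fractional=0.25 (quarter Kelly) the method returns
    #   Decimal("0.0475"), i.e. risk about 4.75% of capital per trade.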
    def calculate_volatility_adjusted_size(
        self,
        symbol: str,
        price: Decimal,
        balance: Decimal,
        base_risk_percent: Optional[Decimal] = None,
        volatility_multiplier: Optional[float] = None,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> Decimal:
        """Calculate position size adjusted for volatility (ATR-based).

        Lower volatility = larger positions, higher volatility = smaller positions.

        Args:
            symbol: Trading symbol
            price: Entry price
            balance: Available balance
            base_risk_percent: Base risk percentage (uses config default if None)
            volatility_multiplier: Volatility adjustment factor (None = auto-calculate)
            exchange_adapter: Exchange adapter (optional)

        Returns:
            Volatility-adjusted position size
        """
        if base_risk_percent is None:
            base_risk_percent = Decimal(str(
                self.config.get("risk.position_size_percent", 2.0)
            )) / 100

        # Get base position size
        base_size = self.calculate_size(symbol, price, balance, base_risk_percent, exchange_adapter)

        # For now, return base size (volatility calculation would need historical data)
        # This is a placeholder - full implementation would fetch ATR from market data
        if volatility_multiplier is not None:
            adjusted_size = base_size * Decimal(str(volatility_multiplier))
            return max(Decimal(0), adjusted_size)

        return base_size
    def calculate_regime_aware_size(
        self,
        symbol: str,
        price: Decimal,
        balance: Decimal,
        market_regime: str,
        base_risk_percent: Optional[Decimal] = None,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> Decimal:
        """Calculate position size adjusted for market regime.

        Args:
            symbol: Trading symbol
            price: Entry price
            balance: Available balance
            market_regime: Market regime ('trending_up', 'trending_down', 'ranging', 'high_volatility', etc.)
            base_risk_percent: Base risk percentage
            exchange_adapter: Exchange adapter (optional)

        Returns:
            Regime-aware position size
        """
        if base_risk_percent is None:
            base_risk_percent = Decimal(str(
                self.config.get("risk.position_size_percent", 2.0)
            )) / 100

        # Regime-based multipliers
        regime_multipliers = {
            'trending_up': 1.2,      # Increase size in uptrends
            'trending_down': 0.5,    # Reduce size in downtrends
            'ranging': 0.8,          # Slightly reduce in ranging markets
            'high_volatility': 0.6,  # Reduce size in high volatility
            'low_volatility': 1.1,   # Increase size in low volatility
            'breakout': 1.15,        # Increase size on breakouts
            'reversal': 0.7,         # Reduce size on reversals
        }

        multiplier = regime_multipliers.get(market_regime, 1.0)
        adjusted_risk = base_risk_percent * Decimal(str(multiplier))

        return self.calculate_size(symbol, price, balance, adjusted_risk, exchange_adapter)
    def calculate_confidence_based_size(
        self,
        symbol: str,
        price: Decimal,
        balance: Decimal,
        confidence: float,
        base_risk_percent: Optional[Decimal] = None,
        exchange_adapter: Optional[BaseExchangeAdapter] = None
    ) -> Decimal:
        """Calculate position size based on ML model confidence.

        Args:
            symbol: Trading symbol
            price: Entry price
            balance: Available balance
            confidence: Model confidence (0.0 to 1.0)
            base_risk_percent: Base risk percentage
            exchange_adapter: Exchange adapter (optional)

        Returns:
            Confidence-adjusted position size
        """
        if base_risk_percent is None:
            base_risk_percent = Decimal(str(
                self.config.get("risk.position_size_percent", 2.0)
            )) / 100

        # Scale position size by confidence (0.5x to 1.5x)
        confidence_multiplier = 0.5 + (confidence * 1.0)  # Maps [0,1] to [0.5, 1.5]
        adjusted_risk = base_risk_percent * Decimal(str(confidence_multiplier))

        return self.calculate_size(symbol, price, balance, adjusted_risk, exchange_adapter)

    def validate_position_size(
        self,
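For the confidence-based sizing above, the multiplier is linear in the model confidence; a quick sanity check with illustrative numbers (not part of the diff):

confidence = 0.8
confidence_multiplier = 0.5 + confidence * 1.0   # 1.3
adjusted_risk = 0.02 * confidence_multiplier     # 0.026, i.e. 2.6% of available balance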
332
src/risk/var_calculator.py
Normal file
@@ -0,0 +1,332 @@
"""Value at Risk (VaR) calculation methods."""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import select
|
||||
from src.core.database import get_database, Trade, PortfolioSnapshot
|
||||
from src.core.logger import get_logger
|
||||
|
||||
try:
|
||||
from scipy import stats
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
import warnings
|
||||
warnings.warn("scipy not available, using approximate z-scores")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class VaRCalculator:
|
||||
"""Calculate Value at Risk (VaR) using multiple methods."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize VaR calculator."""
|
||||
self.db = get_database()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def calculate_historical_var(
|
||||
self,
|
||||
portfolio_value: Decimal,
|
||||
confidence_level: float = 0.95,
|
||||
holding_period_days: int = 1,
|
||||
lookback_days: int = 252
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate Historical VaR.
|
||||
|
||||
Args:
|
||||
portfolio_value: Current portfolio value
|
||||
confidence_level: Confidence level (0.95 = 95%)
|
||||
holding_period_days: Holding period in days
|
||||
lookback_days: Lookback period for historical data
|
||||
|
||||
Returns:
|
||||
Dictionary with VaR metrics
|
||||
"""
|
||||
# Get historical portfolio returns
|
||||
returns = await self._get_portfolio_returns(lookback_days)
|
||||
|
||||
if len(returns) < 10:
|
||||
return {"error": "Insufficient historical data"}
|
||||
|
||||
# Calculate VaR
|
||||
var_percentile = (1 - confidence_level) * 100
|
||||
var_return = np.percentile(returns, var_percentile)
|
||||
|
||||
# Scale for holding period
|
||||
scaled_var_return = var_return * np.sqrt(holding_period_days)
|
||||
|
||||
var_amount = float(portfolio_value) * abs(scaled_var_return)
|
||||
|
||||
return {
|
||||
'method': 'historical',
|
||||
'confidence_level': confidence_level,
|
||||
'holding_period_days': holding_period_days,
|
||||
'var_amount': var_amount,
|
||||
'var_percent': abs(scaled_var_return) * 100,
|
||||
'var_return': float(scaled_var_return),
|
||||
'lookback_days': lookback_days
|
||||
}
|
||||
|
||||
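    # Illustrative sketch (not part of this commit) of the percentile step above,
    # with made-up daily returns:
    #
    #   import numpy as np
    #   returns = np.array([-0.031, -0.012, 0.004, 0.008, -0.022, 0.015, 0.001,
    #                       -0.006, 0.011, -0.017])
    #   var_return = np.percentile(returns, 5)   # 5th percentile for 95% confidence
    #   # var_return is about -0.027, so a $100,000 portfolio has a one-day
    #   # historical VaR of roughly $2,700 at 95% confidence.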
    async def calculate_parametric_var(
        self,
        portfolio_value: Decimal,
        confidence_level: float = 0.95,
        holding_period_days: int = 1,
        lookback_days: int = 252
    ) -> Dict[str, float]:
        """Calculate Parametric (Variance-Covariance) VaR.

        Assumes returns are normally distributed.

        Args:
            portfolio_value: Current portfolio value
            confidence_level: Confidence level (0.95 = 95%)
            holding_period_days: Holding period in days
            lookback_days: Lookback period for historical data

        Returns:
            Dictionary with VaR metrics
        """
        # Get historical portfolio returns
        returns = await self._get_portfolio_returns(lookback_days)

        if len(returns) < 10:
            return {"error": "Insufficient historical data"}

        # Calculate mean and std
        mean_return = np.mean(returns)
        std_return = np.std(returns)

        # Z-score for confidence level
        if SCIPY_AVAILABLE:
            z_score = stats.norm.ppf(1 - confidence_level)
        else:
            # Approximate z-scores for common confidence levels
            # (negated so the fallback matches the sign of norm.ppf(1 - confidence_level))
            z_scores = {0.90: 1.28, 0.95: 1.65, 0.99: 2.33, 0.995: 2.58}
            z_score = -z_scores.get(confidence_level, 1.65)

        # Calculate VaR
        var_return = mean_return + z_score * std_return

        # Scale for holding period
        scaled_var_return = var_return * np.sqrt(holding_period_days)

        var_amount = float(portfolio_value) * abs(scaled_var_return)

        return {
            'method': 'parametric',
            'confidence_level': confidence_level,
            'holding_period_days': holding_period_days,
            'var_amount': var_amount,
            'var_percent': abs(scaled_var_return) * 100,
            'var_return': float(scaled_var_return),
            'mean_return': float(mean_return),
            'std_return': float(std_return),
            'z_score': float(z_score),
            'lookback_days': lookback_days
        }
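    # Illustrative arithmetic for the parametric method above (not part of this
    # commit). With daily mean_return = 0.0005, std_return = 0.02 and a 95%
    # confidence level (z about -1.65):
    #   var_return = 0.0005 + (-1.65) * 0.02 = -0.0325
    # so a $100,000 portfolio has a one-day parametric VaR of roughly $3,250;
    # over a 5-day holding period the figure scales by sqrt(5), about 2.24x.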
    async def calculate_monte_carlo_var(
        self,
        portfolio_value: Decimal,
        confidence_level: float = 0.95,
        holding_period_days: int = 1,
        num_simulations: int = 10000,
        lookback_days: int = 252
    ) -> Dict[str, float]:
        """Calculate Monte Carlo VaR.

        Args:
            portfolio_value: Current portfolio value
            confidence_level: Confidence level (0.95 = 95%)
            holding_period_days: Holding period in days
            num_simulations: Number of Monte Carlo simulations
            lookback_days: Lookback period for historical data

        Returns:
            Dictionary with VaR metrics
        """
        # Get historical portfolio returns
        returns = await self._get_portfolio_returns(lookback_days)

        if len(returns) < 10:
            return {"error": "Insufficient historical data"}

        # Estimate parameters
        mean_return = np.mean(returns)
        std_return = np.std(returns)

        # Generate Monte Carlo simulations
        np.random.seed(42)  # For reproducibility
        simulated_returns = np.random.normal(
            mean_return * holding_period_days,
            std_return * np.sqrt(holding_period_days),
            num_simulations
        )

        # Calculate VaR from simulations
        var_percentile = (1 - confidence_level) * 100
        var_return = np.percentile(simulated_returns, var_percentile)

        var_amount = float(portfolio_value) * abs(var_return)

        return {
            'method': 'monte_carlo',
            'confidence_level': confidence_level,
            'holding_period_days': holding_period_days,
            'var_amount': var_amount,
            'var_percent': abs(var_return) * 100,
            'var_return': float(var_return),
            'num_simulations': num_simulations,
            'lookback_days': lookback_days
        }
    async def calculate_cvar(
        self,
        portfolio_value: Decimal,
        confidence_level: float = 0.95,
        holding_period_days: int = 1,
        lookback_days: int = 252
    ) -> Dict[str, float]:
        """Calculate Conditional VaR (CVaR) / Expected Shortfall.

        CVaR is the expected loss given that the loss exceeds VaR.

        Args:
            portfolio_value: Current portfolio value
            confidence_level: Confidence level (0.95 = 95%)
            holding_period_days: Holding period in days
            lookback_days: Lookback period for historical data

        Returns:
            Dictionary with CVaR metrics
        """
        # Get historical portfolio returns
        returns = await self._get_portfolio_returns(lookback_days)

        if len(returns) < 10:
            return {"error": "Insufficient historical data"}

        # Scale for holding period
        scaled_returns = returns * np.sqrt(holding_period_days)

        # Calculate VaR threshold
        var_percentile = (1 - confidence_level) * 100
        var_threshold = np.percentile(scaled_returns, var_percentile)

        # Calculate CVaR (average of losses beyond VaR)
        tail_returns = scaled_returns[scaled_returns <= var_threshold]

        if len(tail_returns) == 0:
            cvar_return = var_threshold
        else:
            cvar_return = np.mean(tail_returns)

        cvar_amount = float(portfolio_value) * abs(cvar_return)

        return {
            'method': 'cvar',
            'confidence_level': confidence_level,
            'holding_period_days': holding_period_days,
            'cvar_amount': cvar_amount,
            'cvar_percent': abs(cvar_return) * 100,
            'cvar_return': float(cvar_return),
            'var_threshold': float(var_threshold),
            'lookback_days': lookback_days
        }
    async def calculate_all_var_methods(
        self,
        portfolio_value: Decimal,
        confidence_level: float = 0.95,
        holding_period_days: int = 1
    ) -> Dict[str, Any]:
        """Calculate VaR using all methods.

        Args:
            portfolio_value: Current portfolio value
            confidence_level: Confidence level
            holding_period_days: Holding period in days

        Returns:
            Dictionary with all VaR methods
        """
        historical_var = await self.calculate_historical_var(
            portfolio_value, confidence_level, holding_period_days
        )
        parametric_var = await self.calculate_parametric_var(
            portfolio_value, confidence_level, holding_period_days
        )
        monte_carlo_var = await self.calculate_monte_carlo_var(
            portfolio_value, confidence_level, holding_period_days
        )
        cvar = await self.calculate_cvar(
            portfolio_value, confidence_level, holding_period_days
        )

        return {
            'historical': historical_var,
            'parametric': parametric_var,
            'monte_carlo': monte_carlo_var,
            'cvar': cvar,
            'portfolio_value': float(portfolio_value),
            'confidence_level': confidence_level,
            'holding_period_days': holding_period_days
        }
    async def _get_portfolio_returns(self, lookback_days: int) -> np.ndarray:
        """Get historical portfolio returns.

        Args:
            lookback_days: Number of days to look back

        Returns:
            Array of daily returns
        """
        try:
            async with self.db.get_session() as session:
                end_date = datetime.utcnow()
                start_date = end_date - timedelta(days=lookback_days)

                stmt = select(PortfolioSnapshot).filter(
                    PortfolioSnapshot.timestamp >= start_date,
                    PortfolioSnapshot.timestamp <= end_date
                ).order_by(PortfolioSnapshot.timestamp)

                result = await session.execute(stmt)
                snapshots = result.scalars().all()

                if len(snapshots) < 2:
                    return np.array([])

                # Extract portfolio values
                values = [float(snapshot.total_value) for snapshot in snapshots]

                # Calculate daily returns
                returns = []
                for i in range(1, len(values)):
                    if values[i-1] > 0:
                        daily_return = (values[i] - values[i-1]) / values[i-1]
                        returns.append(daily_return)

                return np.array(returns)

        except Exception as e:
            self.logger.error(f"Error getting portfolio returns: {e}")
            return np.array([])
# Global VaR calculator
_var_calculator: Optional[VaRCalculator] = None


def get_var_calculator() -> VaRCalculator:
    """Get global VaR calculator instance."""
    global _var_calculator
    if _var_calculator is None:
        _var_calculator = VaRCalculator()
    return _var_calculator
324
src/trading/execution_algorithms.py
Normal file
@@ -0,0 +1,324 @@
"""Execution algorithms for TWAP, VWAP, and order book impact modeling."""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from src.core.logger import get_logger
|
||||
from src.exchanges.base import BaseExchangeAdapter
|
||||
from src.core.database import OrderType, OrderSide
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ExecutionAlgorithm:
|
||||
"""Base class for execution algorithms."""
|
||||
|
||||
def calculate_slice_size(
|
||||
self,
|
||||
total_quantity: Decimal,
|
||||
num_slices: int
|
||||
) -> Decimal:
|
||||
"""Calculate size for each slice.
|
||||
|
||||
Args:
|
||||
total_quantity: Total quantity to execute
|
||||
num_slices: Number of slices
|
||||
|
||||
Returns:
|
||||
Quantity per slice
|
||||
"""
|
||||
if num_slices <= 0:
|
||||
return total_quantity
|
||||
|
||||
return total_quantity / Decimal(str(num_slices))
|
||||
|
||||
    async def estimate_slippage(
        self,
        quantity: Decimal,
        symbol: str,
        exchange_adapter: BaseExchangeAdapter,
        side: OrderSide
    ) -> Decimal:
        """Estimate slippage for order size.

        Args:
            quantity: Order quantity
            symbol: Trading symbol
            exchange_adapter: Exchange adapter
            side: Order side

        Returns:
            Estimated slippage as percentage (e.g., 0.001 for 0.1%)
        """
        try:
            # Get order book
            order_book = await exchange_adapter.get_order_book(symbol, depth=10)

            if not order_book:
                return Decimal("0.001")  # Default 0.1% slippage

            # Calculate market impact
            bids = order_book.get('bids', [])
            asks = order_book.get('asks', [])

            if side == OrderSide.BUY:
                levels = asks  # Buying from asks
            else:
                levels = bids  # Selling to bids

            if not levels:
                return Decimal("0.001")

            # Calculate average price for quantity
            remaining = quantity
            total_cost = Decimal(0)

            for price, size in levels:
                if remaining <= 0:
                    break

                slice_size = min(remaining, Decimal(str(size)))
                total_cost += slice_size * Decimal(str(price))
                remaining -= slice_size

            if quantity == 0:
                return Decimal("0.001")

            avg_price = total_cost / quantity

            # Get mid price
            if bids and asks:
                mid_price = (Decimal(str(bids[0][0])) + Decimal(str(asks[0][0]))) / Decimal("2")

                if mid_price > 0:
                    slippage = abs(avg_price - mid_price) / mid_price
                    return slippage

            return Decimal("0.001")

        except Exception as e:
            logger.warning(f"Error estimating slippage: {e}")
            return Decimal("0.001")
class TWAPAlgorithm(ExecutionAlgorithm):
    """Time-Weighted Average Price execution algorithm."""

    async def execute_twap(
        self,
        symbol: str,
        total_quantity: Decimal,
        duration_minutes: int,
        num_slices: int,
        side: OrderSide,
        exchange_adapter: BaseExchangeAdapter
    ) -> List[Dict[str, Any]]:
        """Execute order using TWAP algorithm.

        Args:
            symbol: Trading symbol
            total_quantity: Total quantity to execute
            duration_minutes: Execution duration in minutes
            num_slices: Number of slices
            side: Order side
            exchange_adapter: Exchange adapter

        Returns:
            List of order slices with timing
        """
        slice_size = self.calculate_slice_size(total_quantity, num_slices)
        slice_interval = duration_minutes / num_slices

        orders = []
        for i in range(num_slices):
            slice_start_time = datetime.utcnow() + timedelta(minutes=i * slice_interval)
            orders.append({
                'symbol': symbol,
                'quantity': slice_size,
                'side': side,
                'order_type': OrderType.MARKET,
                'execute_at': slice_start_time,
                'slice_number': i + 1,
                'total_slices': num_slices
            })

        return orders
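# Illustrative sketch (not part of this commit) of the schedule produced by
# execute_twap above: 1.2 units over 60 minutes in 6 slices gives six market-order
# slices of 0.2 each, scheduled 10 minutes apart:
#
#   slice_size = Decimal("1.2") / Decimal("6")   # Decimal("0.2")
#   slice_interval = 60 / 6                      # 10.0 minutes
#   # execute_at = now, now+10m, now+20m, ... now+50m
# The quantities and timings are hypothetical; real runs depend on the caller's
# parameters.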
class VWAPAlgorithm(ExecutionAlgorithm):
    """Volume-Weighted Average Price execution algorithm."""

    async def calculate_vwap_reference(
        self,
        symbol: str,
        timeframe: str,
        exchange_adapter: BaseExchangeAdapter,
        lookback_periods: int = 20
    ) -> Optional[Decimal]:
        """Calculate VWAP reference price from historical data.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe (e.g., '1m', '5m')
            exchange_adapter: Exchange adapter
            lookback_periods: Number of periods to look back

        Returns:
            VWAP reference price or None
        """
        try:
            # For now, use current price as proxy
            # Full implementation would fetch historical OHLCV data
            ticker = await exchange_adapter.get_ticker(symbol)
            current_price = ticker.get('last', Decimal(0))

            if current_price > 0:
                return current_price

            return None

        except Exception as e:
            logger.error(f"Error calculating VWAP reference: {e}")
            return None

    async def execute_vwap(
        self,
        symbol: str,
        total_quantity: Decimal,
        duration_minutes: int,
        num_slices: int,
        side: OrderSide,
        exchange_adapter: BaseExchangeAdapter,
        target_vwap: Optional[Decimal] = None
    ) -> List[Dict[str, Any]]:
        """Execute order using VWAP algorithm.

        Args:
            symbol: Trading symbol
            total_quantity: Total quantity to execute
            duration_minutes: Execution duration in minutes
            num_slices: Number of slices
            side: Order side
            exchange_adapter: Exchange adapter
            target_vwap: Target VWAP (if None, calculates reference)

        Returns:
            List of order slices
        """
        # Calculate target VWAP if not provided
        if target_vwap is None:
            target_vwap = await self.calculate_vwap_reference(symbol, "1m", exchange_adapter)
            if target_vwap is None:
                # Fall back to TWAP
                twap = TWAPAlgorithm()
                return await twap.execute_twap(
                    symbol, total_quantity, duration_minutes, num_slices, side, exchange_adapter
                )

        # Distribute slices based on volume profile
        # For simplicity, use equal slices with volume-adjusted timing
        slice_size = self.calculate_slice_size(total_quantity, num_slices)
        slice_interval = duration_minutes / num_slices

        orders = []
        for i in range(num_slices):
            slice_start_time = datetime.utcnow() + timedelta(minutes=i * slice_interval)
            orders.append({
                'symbol': symbol,
                'quantity': slice_size,
                'side': side,
                'order_type': OrderType.MARKET,
                'execute_at': slice_start_time,
                'target_price': target_vwap,
                'slice_number': i + 1,
                'total_slices': num_slices
            })

        return orders
class ExecutionAnalyzer:
    """Analyzes execution quality and order book impact."""

    async def analyze_execution(
        self,
        executed_orders: List[Dict[str, Any]],
        symbol: str,
        exchange_adapter: BaseExchangeAdapter
    ) -> Dict[str, Any]:
        """Analyze execution quality.

        Args:
            executed_orders: List of executed orders
            symbol: Trading symbol
            exchange_adapter: Exchange adapter

        Returns:
            Execution analysis metrics
        """
        if not executed_orders:
            return {"error": "No executed orders"}

        # Calculate metrics
        total_quantity = sum(Decimal(str(order.get('quantity', 0))) for order in executed_orders)
        total_cost = sum(
            Decimal(str(order.get('quantity', 0))) * Decimal(str(order.get('price', 0)))
            for order in executed_orders
        )

        if total_quantity == 0:
            return {"error": "Total quantity is zero"}

        avg_price = total_cost / total_quantity

        # Get reference price (e.g., arrival price or VWAP)
        ticker = await exchange_adapter.get_ticker(symbol)
        reference_price = ticker.get('last', avg_price)

        # Calculate slippage
        slippage = (avg_price - reference_price) / reference_price if reference_price > 0 else Decimal(0)

        # Calculate market impact
        # Market impact = (execution_price - arrival_price) / arrival_price
        market_impact = abs(slippage)

        return {
            'total_quantity': float(total_quantity),
            'total_cost': float(total_cost),
            'avg_execution_price': float(avg_price),
            'reference_price': float(reference_price),
            'slippage': float(slippage),
            'slippage_bps': float(slippage * Decimal("10000")),  # Basis points
            'market_impact': float(market_impact),
            'num_slices': len(executed_orders)
        }
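# Illustrative arithmetic for the slippage figures above (not part of this
# commit): if the average execution price is 100.25 against a reference price of
# 100.00, slippage = (100.25 - 100.00) / 100.00 = 0.0025, which the analyzer
# reports as 25 basis points (0.0025 * 10000).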
# Global execution algorithms
_twap_algorithm: Optional[TWAPAlgorithm] = None
_vwap_algorithm: Optional[VWAPAlgorithm] = None
_execution_analyzer: Optional[ExecutionAnalyzer] = None


def get_twap_algorithm() -> TWAPAlgorithm:
    """Get global TWAP algorithm instance."""
    global _twap_algorithm
    if _twap_algorithm is None:
        _twap_algorithm = TWAPAlgorithm()
    return _twap_algorithm


def get_vwap_algorithm() -> VWAPAlgorithm:
    """Get global VWAP algorithm instance."""
    global _vwap_algorithm
    if _vwap_algorithm is None:
        _vwap_algorithm = VWAPAlgorithm()
    return _vwap_algorithm


def get_execution_analyzer() -> ExecutionAnalyzer:
    """Get global execution analyzer instance."""
    global _execution_analyzer
    if _execution_analyzer is None:
        _execution_analyzer = ExecutionAnalyzer()
    return _execution_analyzer