52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
|
|
"""Historical data management for backtesting."""
|
||
|
|
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Optional
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
from src.core.database import get_database, MarketData
|
||
|
|
from src.core.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class DataProvider:
    """Serve stored historical market data rows to the backtesting engine."""

    def __init__(self):
        """Acquire the database handle and a logger for this provider."""
        self.db = get_database()
        self.logger = get_logger(__name__)

    def get_data(
        self,
        exchange: str,
        symbol: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime
    ) -> List[MarketData]:
        """Fetch stored market data rows for one instrument and timeframe.

        Rows must match ``exchange``, ``symbol`` and ``timeframe`` exactly and
        have a timestamp within ``[start_date, end_date]`` (both bounds
        inclusive). Results are ordered by timestamp (ascending).

        Args:
            exchange: Exchange name to match.
            symbol: Trading symbol to match.
            timeframe: Timeframe label to match.
            start_date: Earliest timestamp to include.
            end_date: Latest timestamp to include.

        Returns:
            List of matching MarketData objects; empty when nothing matches.
        """
        session = self.db.get_session()
        try:
            query = session.query(MarketData)
            # All conditions are AND-ed together by filter().
            criteria = (
                MarketData.exchange == exchange,
                MarketData.symbol == symbol,
                MarketData.timeframe == timeframe,
                MarketData.timestamp >= start_date,
                MarketData.timestamp <= end_date,
            )
            ordered = query.filter(*criteria).order_by(MarketData.timestamp)
            return ordered.all()
        finally:
            # Always release the session, even if the query raises.
            # NOTE(review): returned rows are detached once the session is
            # closed; lazy-loaded relationships may not be usable by callers.
            session.close()
|
||
|
|
|