Local changes: Updated model training, removed debug instrumentation, and configuration improvements
This commit is contained in:
96
frontend/src/hooks/__tests__/useProviderStatus.test.ts
Normal file
96
frontend/src/hooks/__tests__/useProviderStatus.test.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { ReactNode } from 'react'
|
||||
import { useProviderStatus } from '../useProviderStatus'
|
||||
import * as marketDataApi from '../../api/marketData'
|
||||
|
||||
vi.mock('../../api/marketData')
|
||||
vi.mock('../../components/WebSocketProvider', () => ({
|
||||
useWebSocketContext: () => ({
|
||||
isConnected: true,
|
||||
lastMessage: null,
|
||||
subscribe: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('useProviderStatus', () => {
|
||||
let queryClient: QueryClient
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client= { queryClient } > { children } </QueryClientProvider>
|
||||
)
|
||||
|
||||
it('returns loading state initially', () => {
|
||||
vi.mocked(marketDataApi.marketDataApi.getProviderStatus).mockImplementation(
|
||||
() => new Promise(() => { }) // Never resolves
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useProviderStatus(), { wrapper })
|
||||
|
||||
expect(result.current.isLoading).toBe(true)
|
||||
})
|
||||
|
||||
it('returns provider status after loading', async () => {
|
||||
const mockStatus = {
|
||||
primary_provider: 'CoinGecko',
|
||||
primary_healthy: true,
|
||||
fallback_provider: 'CCXT',
|
||||
fallback_healthy: true,
|
||||
last_check: '2024-01-01T00:00:00Z',
|
||||
}
|
||||
|
||||
vi.mocked(marketDataApi.marketDataApi.getProviderStatus).mockResolvedValue(mockStatus)
|
||||
|
||||
const { result } = renderHook(() => useProviderStatus(), { wrapper })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.status).toEqual(mockStatus)
|
||||
})
|
||||
|
||||
it('returns error when API call fails', async () => {
|
||||
const mockError = new Error('API Error')
|
||||
vi.mocked(marketDataApi.marketDataApi.getProviderStatus).mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useProviderStatus(), { wrapper })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.error).not.toBeNull()
|
||||
})
|
||||
|
||||
expect(result.current.error?.message).toBe('API Error')
|
||||
})
|
||||
|
||||
it('provides refetch function', async () => {
|
||||
const mockStatus = {
|
||||
primary_provider: 'CoinGecko',
|
||||
primary_healthy: true,
|
||||
fallback_provider: 'CCXT',
|
||||
fallback_healthy: true,
|
||||
last_check: '2024-01-01T00:00:00Z',
|
||||
}
|
||||
|
||||
vi.mocked(marketDataApi.marketDataApi.getProviderStatus).mockResolvedValue(mockStatus)
|
||||
|
||||
const { result } = renderHook(() => useProviderStatus(), { wrapper })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.refetch).toBeDefined()
|
||||
expect(typeof result.current.refetch).toBe('function')
|
||||
})
|
||||
})
|
||||
134
frontend/src/hooks/__tests__/useRealtimeData.test.ts
Normal file
134
frontend/src/hooks/__tests__/useRealtimeData.test.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { renderHook, waitFor } from '@testing-library/react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { ReactNode } from 'react'
|
||||
import { useRealtimeData } from '../useRealtimeData'
|
||||
|
||||
const mockSubscribe = vi.fn(() => vi.fn())
|
||||
const mockShowInfo = vi.fn()
|
||||
const mockShowWarning = vi.fn()
|
||||
|
||||
vi.mock('../../components/WebSocketProvider', () => ({
|
||||
useWebSocketContext: () => ({
|
||||
isConnected: true,
|
||||
lastMessage: null,
|
||||
subscribe: mockSubscribe,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../contexts/SnackbarContext', () => ({
|
||||
useSnackbar: () => ({
|
||||
showInfo: mockShowInfo,
|
||||
showWarning: mockShowWarning,
|
||||
showError: vi.fn(),
|
||||
showSuccess: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('useRealtimeData', () => {
|
||||
let queryClient: QueryClient
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client= { queryClient } > { children } </QueryClientProvider>
|
||||
)
|
||||
|
||||
it('subscribes to all message types on mount', () => {
|
||||
renderHook(() => useRealtimeData(), { wrapper })
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('order_update', expect.any(Function))
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('position_update', expect.any(Function))
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('price_update', expect.any(Function))
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('alert_triggered', expect.any(Function))
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('strategy_signal', expect.any(Function))
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('system_event', expect.any(Function))
|
||||
})
|
||||
|
||||
it('handles order update messages', async () => {
|
||||
const unsubscribeMock = vi.fn()
|
||||
let orderHandler: ((message: any) => void) | undefined
|
||||
|
||||
mockSubscribe.mockImplementation((type: string, handler: any) => {
|
||||
if (type === 'order_update') {
|
||||
orderHandler = handler
|
||||
}
|
||||
return unsubscribeMock
|
||||
})
|
||||
|
||||
const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries')
|
||||
|
||||
renderHook(() => useRealtimeData(), { wrapper })
|
||||
|
||||
// Simulate order filled message
|
||||
if (orderHandler) {
|
||||
orderHandler({ type: 'order_update', order_id: '123', status: 'filled' })
|
||||
}
|
||||
|
||||
await waitFor(() => {
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: ['orders'] })
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: ['balance'] })
|
||||
})
|
||||
|
||||
expect(mockShowInfo).toHaveBeenCalledWith('Order 123 filled')
|
||||
})
|
||||
|
||||
it('handles alert triggered messages', async () => {
|
||||
let alertHandler: ((message: any) => void) | undefined
|
||||
|
||||
mockSubscribe.mockImplementation((type: string, handler: any) => {
|
||||
if (type === 'alert_triggered') {
|
||||
alertHandler = handler
|
||||
}
|
||||
return vi.fn()
|
||||
})
|
||||
|
||||
const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries')
|
||||
|
||||
renderHook(() => useRealtimeData(), { wrapper })
|
||||
|
||||
// Simulate alert triggered
|
||||
if (alertHandler) {
|
||||
alertHandler({ type: 'alert_triggered', alert_name: 'BTC Price Alert' })
|
||||
}
|
||||
|
||||
await waitFor(() => {
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({ queryKey: ['alerts'] })
|
||||
})
|
||||
|
||||
expect(mockShowWarning).toHaveBeenCalledWith('Alert triggered: BTC Price Alert')
|
||||
})
|
||||
|
||||
it('unsubscribes from all message types on unmount', () => {
|
||||
const unsubscribeMocks = {
|
||||
order_update: vi.fn(),
|
||||
position_update: vi.fn(),
|
||||
price_update: vi.fn(),
|
||||
alert_triggered: vi.fn(),
|
||||
strategy_signal: vi.fn(),
|
||||
system_event: vi.fn(),
|
||||
}
|
||||
|
||||
mockSubscribe.mockImplementation((type: string) => {
|
||||
return unsubscribeMocks[type as keyof typeof unsubscribeMocks] || vi.fn()
|
||||
})
|
||||
|
||||
const { unmount } = renderHook(() => useRealtimeData(), { wrapper })
|
||||
|
||||
unmount()
|
||||
|
||||
expect(unsubscribeMocks.order_update).toHaveBeenCalled()
|
||||
expect(unsubscribeMocks.position_update).toHaveBeenCalled()
|
||||
expect(unsubscribeMocks.price_update).toHaveBeenCalled()
|
||||
expect(unsubscribeMocks.alert_triggered).toHaveBeenCalled()
|
||||
expect(unsubscribeMocks.strategy_signal).toHaveBeenCalled()
|
||||
expect(unsubscribeMocks.system_event).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
214
frontend/src/hooks/__tests__/useWebSocket.test.ts
Normal file
214
frontend/src/hooks/__tests__/useWebSocket.test.ts
Normal file
@@ -0,0 +1,214 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { useWebSocket } from '../useWebSocket'
|
||||
|
||||
// Mock WebSocket class
|
||||
class MockWebSocket {
|
||||
static CONNECTING = 0
|
||||
static OPEN = 1
|
||||
static CLOSING = 2
|
||||
static CLOSED = 3
|
||||
|
||||
readyState = MockWebSocket.CONNECTING
|
||||
url: string
|
||||
onopen: (() => void) | null = null
|
||||
onmessage: ((event: { data: string }) => void) | null = null
|
||||
onerror: ((error: Event) => void) | null = null
|
||||
onclose: (() => void) | null = null
|
||||
|
||||
constructor(url: string) {
|
||||
this.url = url
|
||||
mockWebSocketInstances.push(this)
|
||||
}
|
||||
|
||||
send = vi.fn()
|
||||
close = vi.fn(() => {
|
||||
this.readyState = MockWebSocket.CLOSED
|
||||
this.onclose?.()
|
||||
})
|
||||
|
||||
// Helper to simulate connection
|
||||
simulateOpen() {
|
||||
this.readyState = MockWebSocket.OPEN
|
||||
this.onopen?.()
|
||||
}
|
||||
|
||||
// Helper to simulate message
|
||||
simulateMessage(data: object) {
|
||||
this.onmessage?.({ data: JSON.stringify(data) })
|
||||
}
|
||||
|
||||
// Helper to simulate close
|
||||
simulateClose() {
|
||||
this.readyState = MockWebSocket.CLOSED
|
||||
this.onclose?.()
|
||||
}
|
||||
|
||||
// Helper to simulate error
|
||||
simulateError() {
|
||||
this.onerror?.(new Event('error'))
|
||||
}
|
||||
}
|
||||
|
||||
let mockWebSocketInstances: MockWebSocket[] = []
|
||||
|
||||
describe('useWebSocket', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
mockWebSocketInstances = []
|
||||
; (globalThis as any).WebSocket = MockWebSocket
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('initializes as disconnected', () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
expect(result.current.isConnected).toBe(false)
|
||||
})
|
||||
|
||||
it('connects and sets isConnected to true', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
// Simulate WebSocket open
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isConnected).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('receives and stores messages', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
const testMessage = { type: 'order_update', order_id: '123' }
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateMessage(testMessage)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.lastMessage).toEqual(expect.objectContaining(testMessage))
|
||||
})
|
||||
})
|
||||
|
||||
it('adds messages to message history', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
const message1 = { type: 'order_update' as const, order_id: '1' }
|
||||
const message2 = { type: 'position_update' as const, position_id: '2' }
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateMessage(message1)
|
||||
mockWebSocketInstances[0].simulateMessage(message2)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.messageHistory).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
it('sends messages when connected', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
const testMessage = { action: 'subscribe', channel: 'prices' }
|
||||
act(() => {
|
||||
result.current.sendMessage(testMessage)
|
||||
})
|
||||
|
||||
expect(mockWebSocketInstances[0].send).toHaveBeenCalledWith(JSON.stringify(testMessage))
|
||||
})
|
||||
|
||||
it('allows subscribing to specific message types', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
const handler = vi.fn()
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
result.current.subscribe('order_update', handler)
|
||||
})
|
||||
|
||||
const testMessage = { type: 'order_update', order_id: '123' }
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateMessage(testMessage)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(handler).toHaveBeenCalledWith(expect.objectContaining(testMessage))
|
||||
})
|
||||
})
|
||||
|
||||
it('cleans up subscription on unsubscribe', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
const handler = vi.fn()
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
let unsubscribe: () => void
|
||||
act(() => {
|
||||
unsubscribe = result.current.subscribe('order_update', handler)
|
||||
})
|
||||
|
||||
act(() => {
|
||||
unsubscribe()
|
||||
})
|
||||
|
||||
const testMessage = { type: 'order_update', order_id: '123' }
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateMessage(testMessage)
|
||||
})
|
||||
|
||||
// Handler should not be called after unsubscribe
|
||||
expect(handler).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('sets isConnected to false on disconnect', async () => {
|
||||
const { result } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isConnected).toBe(true)
|
||||
})
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateClose()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isConnected).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('closes WebSocket on unmount', () => {
|
||||
const { unmount } = renderHook(() => useWebSocket('ws://localhost:8000/ws'))
|
||||
|
||||
act(() => {
|
||||
mockWebSocketInstances[0].simulateOpen()
|
||||
})
|
||||
|
||||
unmount()
|
||||
|
||||
expect(mockWebSocketInstances[0].close).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
60
frontend/src/hooks/useProviderStatus.ts
Normal file
60
frontend/src/hooks/useProviderStatus.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { marketDataApi, ProviderStatus } from '../api/marketData'
|
||||
import { useWebSocketContext } from '../components/WebSocketProvider'
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
export interface ProviderStatusData {
|
||||
status: ProviderStatus | null
|
||||
isLoading: boolean
|
||||
error: Error | null
|
||||
refetch: () => void
|
||||
}
|
||||
|
||||
export function useProviderStatus(): ProviderStatusData {
|
||||
const { isConnected, lastMessage } = useWebSocketContext()
|
||||
const [status, setStatus] = useState<ProviderStatus | null>(null)
|
||||
|
||||
const {
|
||||
data,
|
||||
isLoading,
|
||||
error,
|
||||
refetch,
|
||||
} = useQuery({
|
||||
queryKey: ['provider-status'],
|
||||
queryFn: () => marketDataApi.getProviderStatus(),
|
||||
refetchInterval: 10000, // Refetch every 10 seconds
|
||||
})
|
||||
|
||||
// Update local state when query data changes
|
||||
useEffect(() => {
|
||||
if (data) {
|
||||
setStatus(data)
|
||||
}
|
||||
}, [data])
|
||||
|
||||
// Listen for provider status updates via WebSocket
|
||||
useEffect(() => {
|
||||
if (!isConnected || !lastMessage) return
|
||||
|
||||
try {
|
||||
const message = typeof lastMessage === 'string' ? JSON.parse(lastMessage) : lastMessage
|
||||
|
||||
if (message.type === 'provider_status_update') {
|
||||
// Update status from WebSocket message
|
||||
setStatus((prev) => ({
|
||||
...prev!,
|
||||
...message.data,
|
||||
}))
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore parsing errors
|
||||
}
|
||||
}, [isConnected, lastMessage])
|
||||
|
||||
return {
|
||||
status,
|
||||
isLoading,
|
||||
error: error as Error | null,
|
||||
refetch,
|
||||
}
|
||||
}
|
||||
69
frontend/src/hooks/useRealtimeData.ts
Normal file
69
frontend/src/hooks/useRealtimeData.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { useEffect } from 'react'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useWebSocketContext } from '../components/WebSocketProvider'
|
||||
import { useSnackbar } from '../contexts/SnackbarContext'
|
||||
|
||||
export function useRealtimeData() {
|
||||
const queryClient = useQueryClient()
|
||||
const { isConnected, lastMessage, subscribe } = useWebSocketContext()
|
||||
const { showInfo, showWarning } = useSnackbar()
|
||||
|
||||
useEffect(() => {
|
||||
if (!isConnected) return
|
||||
|
||||
// Subscribe to order updates
|
||||
const unsubscribeOrder = subscribe('order_update', (message) => {
|
||||
queryClient.invalidateQueries({ queryKey: ['orders'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['balance'] })
|
||||
|
||||
if (message.status === 'filled') {
|
||||
showInfo(`Order ${message.order_id} filled`)
|
||||
}
|
||||
})
|
||||
|
||||
// Subscribe to position updates
|
||||
const unsubscribePosition = subscribe('position_update', (message) => {
|
||||
queryClient.invalidateQueries({ queryKey: ['positions'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['portfolio'] })
|
||||
})
|
||||
|
||||
// Subscribe to price updates
|
||||
const unsubscribePrice = subscribe('price_update', (message) => {
|
||||
// Invalidate market data queries for the specific symbol
|
||||
if (message.symbol) {
|
||||
queryClient.invalidateQueries({ queryKey: ['market-data', message.symbol] })
|
||||
}
|
||||
})
|
||||
|
||||
// Subscribe to alert triggers
|
||||
const unsubscribeAlert = subscribe('alert_triggered', (message) => {
|
||||
queryClient.invalidateQueries({ queryKey: ['alerts'] })
|
||||
showWarning(`Alert triggered: ${message.alert_name || 'Unknown alert'}`)
|
||||
})
|
||||
|
||||
// Subscribe to strategy signals
|
||||
const unsubscribeSignal = subscribe('strategy_signal', (message) => {
|
||||
queryClient.invalidateQueries({ queryKey: ['autopilot-status'] })
|
||||
if (message.signal_type) {
|
||||
showInfo(`Strategy signal: ${message.signal_type.toUpperCase()} for ${message.symbol || 'N/A'}`)
|
||||
}
|
||||
})
|
||||
|
||||
// Subscribe to system events
|
||||
const unsubscribeSystem = subscribe('system_event', (message) => {
|
||||
if (message.event_type === 'error') {
|
||||
showWarning(`System event: ${message.message || 'Unknown error'}`)
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
unsubscribeOrder()
|
||||
unsubscribePosition()
|
||||
unsubscribePrice()
|
||||
unsubscribeAlert()
|
||||
unsubscribeSignal()
|
||||
unsubscribeSystem()
|
||||
}
|
||||
}, [isConnected, subscribe, queryClient, showInfo, showWarning])
|
||||
}
|
||||
|
||||
116
frontend/src/hooks/useWebSocket.ts
Normal file
116
frontend/src/hooks/useWebSocket.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { useEffect, useRef, useState, useCallback } from 'react'
|
||||
|
||||
export interface WebSocketMessage {
|
||||
type: 'order_update' | 'position_update' | 'price_update' | 'alert_triggered' | 'strategy_signal' | 'system_event'
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
export function useWebSocket(url: string) {
|
||||
const [isConnected, setIsConnected] = useState(false)
|
||||
const [lastMessage, setLastMessage] = useState<WebSocketMessage | null>(null)
|
||||
const [messageHistory, setMessageHistory] = useState<WebSocketMessage[]>([])
|
||||
const wsRef = useRef<WebSocket | null>(null)
|
||||
const reconnectTimeoutRef = useRef<NodeJS.Timeout>()
|
||||
const messageHandlersRef = useRef<Map<string, (message: WebSocketMessage) => void>>(new Map())
|
||||
const isConnectingRef = useRef(false)
|
||||
|
||||
useEffect(() => {
|
||||
let isMounted = true
|
||||
|
||||
const connect = () => {
|
||||
// Prevent duplicate connections
|
||||
if (isConnectingRef.current || wsRef.current?.readyState === WebSocket.OPEN || wsRef.current?.readyState === WebSocket.CONNECTING) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
isConnectingRef.current = true
|
||||
const ws = new WebSocket(url)
|
||||
wsRef.current = ws
|
||||
|
||||
ws.onopen = () => {
|
||||
if (isMounted) {
|
||||
isConnectingRef.current = false
|
||||
setIsConnected(true)
|
||||
console.log('WebSocket connected')
|
||||
}
|
||||
}
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
if (!isMounted) return
|
||||
try {
|
||||
const message = JSON.parse(event.data) as WebSocketMessage
|
||||
setLastMessage(message)
|
||||
setMessageHistory((prev) => [...prev.slice(-99), message]) // Keep last 100 messages
|
||||
|
||||
// Call registered handlers for this message type
|
||||
const handlers = messageHandlersRef.current.get(message.type)
|
||||
if (handlers) {
|
||||
handlers(message)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to parse WebSocket message:', error)
|
||||
}
|
||||
}
|
||||
|
||||
ws.onerror = (error) => {
|
||||
isConnectingRef.current = false
|
||||
// Only log error if we're still mounted (avoid noise from StrictMode cleanup)
|
||||
if (isMounted) {
|
||||
console.error('WebSocket error:', error)
|
||||
}
|
||||
}
|
||||
|
||||
ws.onclose = () => {
|
||||
isConnectingRef.current = false
|
||||
if (isMounted) {
|
||||
setIsConnected(false)
|
||||
console.log('WebSocket disconnected')
|
||||
// Reconnect after 3 seconds
|
||||
reconnectTimeoutRef.current = setTimeout(() => {
|
||||
if (isMounted) connect()
|
||||
}, 3000)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
isConnectingRef.current = false
|
||||
console.error('Failed to create WebSocket:', error)
|
||||
}
|
||||
}
|
||||
|
||||
connect()
|
||||
|
||||
return () => {
|
||||
isMounted = false
|
||||
isConnectingRef.current = false
|
||||
if (reconnectTimeoutRef.current) {
|
||||
clearTimeout(reconnectTimeoutRef.current)
|
||||
}
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close()
|
||||
wsRef.current = null
|
||||
}
|
||||
}
|
||||
}, [url])
|
||||
|
||||
const sendMessage = useCallback((message: any) => {
|
||||
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
||||
wsRef.current.send(JSON.stringify(message))
|
||||
}
|
||||
}, [])
|
||||
|
||||
const subscribe = useCallback((messageType: string, handler: (message: WebSocketMessage) => void) => {
|
||||
messageHandlersRef.current.set(messageType, handler)
|
||||
return () => {
|
||||
messageHandlersRef.current.delete(messageType)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
lastMessage,
|
||||
messageHistory,
|
||||
sendMessage,
|
||||
subscribe,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user