import { useState, useEffect, useRef, useCallback } from 'react';
/**
 * React hook that manages one WebSocket chat session with a streaming assistant.
 *
 * A socket is opened when `apiKey`, `userId` and `chatId` are all truthy, and is
 * re-opened whenever any argument changes. Each inbound frame is parsed as JSON
 * and dispatched on its `event_type` field. Malformed frames are logged and
 * ignored; frames with an unknown `event_type` are also ignored.
 *
 * @param {string} apiKey - Sent as the `api_key` query parameter.
 * @param {string} userId - Sent as the `user_id` query parameter.
 * @param {string} chatId - Sent as the `chat_id` query parameter.
 * @param {string} [baseUrl='ws://localhost:8000'] - Server base URL without a
 *   trailing slash; `/ws/chat` is appended to it.
 * @returns {{
 *   connected: boolean,
 *   messages: Array<{id: *, content: *, sender: 'user'|'assistant', timestamp: *}>,
 *   streamingResponse: string,
 *   isStreaming: boolean,
 *   error: string|null,
 *   sendMessage: (content: string, metadata?: object) => void,
 *   reconnect: () => void,
 * }}
 *   `sendMessage` throws `Error('Not connected')` when no socket is open.
 */
function useWebSocketChat(apiKey, userId, chatId, baseUrl = 'ws://localhost:8000') {
  const [connected, setConnected] = useState(false);
  const [messages, setMessages] = useState([]);
  const [streamingResponse, setStreamingResponse] = useState('');
  const [isStreaming, setIsStreaming] = useState(false);
  const [error, setError] = useState(null);
  const ws = useRef(null);

  // Close the tracked socket (if any). Its handlers are detached first so a
  // socket we have abandoned can never update state again. Without this, a
  // late `onclose` from an old socket could mark a newer, open one as
  // disconnected.
  const disconnect = useCallback(() => {
    const socket = ws.current;
    if (!socket) return;
    ws.current = null;
    socket.onopen = null;
    socket.onmessage = null;
    socket.onerror = null;
    socket.onclose = null;
    socket.close();
    // The detached onclose will not run, so record the state change here.
    setConnected(false);
  }, []);

  const connect = useCallback(() => {
    if (!apiKey || !userId || !chatId) return;

    // Both `reconnect` and the mount effect come through here. Never leave a
    // previous socket running alongside the new one.
    disconnect();

    // URLSearchParams percent-encodes values, so credentials containing
    // `&`, `=`, spaces, etc. cannot corrupt the query string.
    // NOTE(review): credentials in a query string may end up in proxy/server
    // logs; confirm the backend cannot accept them another way.
    const params = new URLSearchParams({
      api_key: apiKey,
      user_id: userId,
      chat_id: chatId,
    });
    const socket = new WebSocket(`${baseUrl}/ws/chat?${params}`);
    ws.current = socket;

    socket.onopen = () => {
      setConnected(true);
      setError(null);
    };

    socket.onmessage = (event) => {
      let data;
      try {
        data = JSON.parse(event.data);
      } catch (parseError) {
        // A malformed frame must not throw out of the event handler.
        console.error('Malformed WebSocket message:', parseError);
        return;
      }

      // `?.` guards against a literal `null` payload.
      switch (data?.event_type) {
        case 'connection_established':
          console.log('Connected to chat');
          break;
        case 'message_received':
          setMessages((prev) => [...prev, {
            id: data.message_id,
            content: data.content,
            sender: 'user',
            timestamp: data.timestamp,
          }]);
          break;
        case 'message_stream_started':
          setIsStreaming(true);
          setStreamingResponse('');
          break;
        case 'message_stream_token':
          setStreamingResponse((prev) => prev + data.token);
          break;
        case 'message_stream_ended':
          // The final message is taken from this frame's `content`,
          // not from the accumulated tokens.
          setMessages((prev) => [...prev, {
            id: data.message_id,
            content: data.content,
            sender: 'assistant',
            timestamp: data.timestamp,
          }]);
          setIsStreaming(false);
          setStreamingResponse('');
          break;
        case 'error':
          setError(data.error_message);
          break;
        default:
          // Unknown event types are ignored.
          break;
      }
    };

    socket.onclose = () => {
      setConnected(false);
    };

    socket.onerror = (errorEvent) => {
      setError('Connection error');
      console.error('WebSocket error:', errorEvent);
    };
  }, [apiKey, userId, chatId, baseUrl, disconnect]);

  const sendMessage = useCallback((content, metadata = {}) => {
    if (!ws.current || ws.current.readyState !== WebSocket.OPEN) {
      throw new Error('Not connected');
    }
    const message = { content };
    // Only include `metadata` when it is a non-empty object; this also
    // tolerates an explicit `null`.
    if (metadata && Object.keys(metadata).length > 0) {
      message.metadata = metadata;
    }
    ws.current.send(JSON.stringify(message));
  }, []);

  useEffect(() => {
    connect();
    return disconnect;
  }, [connect, disconnect]);

  return {
    connected,
    messages,
    streamingResponse,
    isStreaming,
    error,
    sendMessage,
    reconnect: connect,
  };
}
// Example usage: a minimal chat component built on useWebSocketChat.
/**
 * Minimal chat UI built on `useWebSocketChat`.
 *
 * Shows the connection status, the message history, the in-progress
 * assistant reply while streaming, and a text input. When the hook reports
 * an error, the whole UI is replaced by the error text and a Reconnect button.
 *
 * @param {{apiKey: string, userId: string, chatId: string}} props - Forwarded
 *   unchanged to `useWebSocketChat`.
 */
function ChatInterface({ apiKey, userId, chatId }) {
  const {
    connected,
    messages,
    streamingResponse,
    isStreaming,
    error,
    sendMessage,
    reconnect,
  } = useWebSocketChat(apiKey, userId, chatId);
  const [inputValue, setInputValue] = useState('');

  const handleSend = () => {
    if (!inputValue.trim()) return;
    try {
      sendMessage(inputValue);
    } catch (sendError) {
      // The socket can close between the last render and this event, and
      // sendMessage then throws. Keep the draft so the user can resend it.
      console.error('Failed to send message:', sendError);
      return;
    }
    setInputValue('');
  };

  const handleKeyDown = (e) => {
    // While an IME composition is active (e.g. CJK input), Enter confirms the
    // composed text and must not submit the message.
    if (e.key === 'Enter' && !e.nativeEvent.isComposing) {
      handleSend();
    }
  };

  if (error) {
    return (
      <div className="error">
        Error: {error}
        <button onClick={reconnect}>Reconnect</button>
      </div>
    );
  }

  return (
    <div className="chat-interface">
      <div className="connection-status">
        Status: {connected ? 'Connected' : 'Disconnected'}
      </div>
      <div className="messages">
        {messages.map((msg) => (
          <div key={msg.id} className={`message ${msg.sender}`}>
            <strong>{msg.sender === 'user' ? 'You' : 'AI'}:</strong>
            <p>{msg.content}</p>
          </div>
        ))}
        {isStreaming && (
          <div className="message assistant streaming">
            <strong>AI:</strong>
            <p>{streamingResponse}<span className="cursor">|</span></p>
          </div>
        )}
      </div>
      <div className="input-area">
        <input
          type="text"
          value={inputValue}
          onChange={(e) => setInputValue(e.target.value)}
          onKeyDown={handleKeyDown}
          disabled={!connected}
          placeholder="Type your message..."
        />
        <button onClick={handleSend} disabled={!connected || !inputValue.trim()}>
          Send
        </button>
      </div>
    </div>
  );
}