Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from datetime import datetime | |
| from llm_handler import LLMHandler | |
| from memory_manager import MemoryManager | |
| from tool_executor import ToolExecutor | |
| from character_learner import CharacterLearner | |
| from audio_handler import AudioHandler | |
| class ConversationalAgent: | |
| def __init__(self, model_name: str | None = None): | |
| # Allow dynamic model override from UI; fall back to env / default. | |
| self.llm_handler = LLMHandler(model_override=model_name) | |
| self.memory_manager = MemoryManager() | |
| self.tool_executor = ToolExecutor() | |
| self.character_learner = CharacterLearner(self.memory_manager) | |
| self.audio_handler = AudioHandler() | |
| self.user_id = os.getenv("USER_NAME", "User") | |
| self.memory_manager.initialize_user_profile(self.user_id) | |
| def process_message(self, message, history, use_voice=False): | |
| if not message or not message.strip(): | |
| return history, "" | |
| try: | |
| self.memory_manager.store_conversation(self.user_id, message, "user") | |
| learned_traits = self.character_learner.extract_and_learn(self.user_id, message, "user") | |
| relevant_memories = self.memory_manager.get_relevant_memories(self.user_id, message, limit=5) | |
| user_profile = self.memory_manager.get_user_profile(self.user_id) | |
| context = self._build_context(message, relevant_memories, user_profile) | |
| tools_needed = self._should_use_tools(message) | |
| tool_results = "" | |
| if tools_needed: | |
| tool_results = self.tool_executor.execute_tools(message) | |
| if tool_results: | |
| context += f"\n\nTool Results:\n{tool_results}" | |
| full_response = "" | |
| for chunk in self.llm_handler.generate_streaming(context): | |
| full_response += chunk | |
| self.memory_manager.store_conversation(self.user_id, full_response, "assistant") | |
| self.character_learner.extract_and_learn(self.user_id, full_response, "assistant") | |
| audio_output = None | |
| if use_voice and full_response: | |
| audio_output = self.audio_handler.text_to_speech(full_response) | |
| final_history = history + [[message, full_response]] | |
| yield final_history, "", audio_output | |
| except Exception as e: | |
| print(f"Error processing message: {str(e)}") | |
| error_history = history + [[message, f"I apologize, but I encountered an error: {str(e)}"]] | |
| yield error_history, "", None | |
| def process_voice_input(self, audio, history): | |
| if audio is None: | |
| return history, "" | |
| try: | |
| text = self.audio_handler.speech_to_text(audio) | |
| if text: | |
| return history, text | |
| else: | |
| return history, "" | |
| except Exception as e: | |
| print(f"Error processing voice input: {str(e)}") | |
| return history, "" | |
| def _build_context(self, message, memories, user_profile): | |
| context_parts = [] | |
| system_prompt = os.getenv("SYSTEM_PROMPT", "You are a helpful, friendly AI assistant.") | |
| context_parts.append(f"System: {system_prompt}") | |
| if user_profile: | |
| profile_info = f"\n\nUser Profile for {self.user_id}:" | |
| if user_profile.get('learned_traits'): | |
| traits = __import__('json').loads(user_profile['learned_traits']) | |
| if traits.get('interests'): | |
| profile_info += f"\nInterests: {', '.join(traits['interests'][:5])}" | |
| if traits.get('background'): | |
| profile_info += f"\nBackground: {traits['background']}" | |
| context_parts.append(profile_info) | |
| if memories: | |
| context_parts.append("\n\nRelevant past context:") | |
| for mem in memories[:3]: | |
| role = mem['role'].capitalize() | |
| msg = mem['message'][:200] | |
| context_parts.append(f"{role}: {msg}") | |
| context_parts.append(f"\n\nCurrent User Message: {message}") | |
| context_parts.append("\nAssistant:") | |
| return "\n".join(context_parts) | |
| def _should_use_tools(self, message): | |
| tool_keywords = ['search', 'find', 'google', 'what is', 'who is', 'calculate', 'compute', 'run', 'execute', 'code', 'add task', 'create task', 'workflow', 'automate'] | |
| message_lower = message.lower() | |
| return any(keyword in message_lower for keyword in tool_keywords) | |
| def get_memory_stats(self): | |
| profile = self.memory_manager.get_user_profile(self.user_id) | |
| if not profile: | |
| return "No profile data yet." | |
| stats = [f"**User:** {self.user_id}", f"**Profile Created:** {profile.get('created_at', 'Unknown')}"] | |
| if profile.get('learned_traits'): | |
| traits = __import__('json').loads(profile['learned_traits']) | |
| stats.append("\n**Learned Information:**") | |
| if traits.get('interests'): | |
| stats.append(f"- Interests ({len(traits['interests'])}): {', '.join(traits['interests'][:5])}") | |
| if traits.get('background'): | |
| stats.append(f"- Background: {traits['background']}") | |
| if traits.get('communication_style'): | |
| stats.append(f"- Communication Style: {traits['communication_style']}") | |
| if traits.get('expertise'): | |
| stats.append(f"- Expertise Areas: {', '.join(traits['expertise'][:3])}") | |
| stats.append(f"\n**Total Conversations:** {self.memory_manager.get_conversation_count(self.user_id)}") | |
| return "\n".join(stats) | |
def create_interface():
    """Create and configure the Gradio interface with LLM selection.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Read the preferred model once so the state and the dropdown cannot drift.
    default_model = os.getenv("PREFERRED_MODEL", "google/gemini-2.0-flash-exp")

    # NOTE(review): the nesting of rows/columns/tabs below was reconstructed
    # from statement order — confirm it matches the intended layout.
    with gr.Blocks(title="Personal AI Assistant", theme=gr.themes.Soft()) as demo:
        # Global state: selected model and agent instance.
        # The agent is re-created when the model changes (see apply_model).
        model_state = gr.State(default_model)
        agent_state = gr.State(ConversationalAgent(default_model))

        gr.Markdown(
            """
            # π€ Personal AI Assistant
            Your intelligent companion that learns about you over time and helps with various tasks.
            **Features:**
            - π¬ Natural conversation with memory
            - π€ Voice input and output
            - π§ Learns your preferences and interests
            - π§ Can search the web, execute code, and trigger workflows
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                # NOTE(review): avatar_images entries are documented as file
                # paths / URLs; confirm a plain string is accepted here.
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_label=True,
                    avatar_images=(None, "π€"),
                    type="messages"
                )
                with gr.Row():
                    with gr.Column(scale=4):
                        msg_input = gr.Textbox(
                            label="Type your message...",
                            placeholder="Ask me anything...",
                            lines=2,
                            show_label=False
                        )
                    with gr.Column(scale=1):
                        audio_input = gr.Audio(
                            sources=["microphone"],
                            type="filepath",
                            label="π€ Voice",
                            show_label=True
                        )
                with gr.Row():
                    with gr.Column(scale=4):
                        submit_btn = gr.Button("Send π¬", variant="primary")
                        voice_btn = gr.Button("Send with Voice π")
                    with gr.Column(scale=1):
                        clear_btn = gr.Button("Clear ποΈ")

            with gr.Column(scale=1):
                audio_output = gr.Audio(
                    label="Voice Response",
                    autoplay=True,
                    type="numpy"
                )

                # Settings / model selection + memory stats
                with gr.Tab("Settings"):
                    gr.Markdown("### π§ LLM Settings")
                    with gr.Row():
                        llm_model = gr.Dropdown(
                            label="Select LLM model (via OpenRouter)",
                            choices=[
                                "google/gemini-2.0-flash-exp",
                                "anthropic/claude-3.5-sonnet",
                                "anthropic/claude-3.5-haiku",
                                "openai/gpt-4.1-mini",
                            ],
                            value=default_model,
                        )
                        apply_model_btn = gr.Button("Apply Model")

                with gr.Tab("Memory Stats"):
                    gr.Markdown("### π Memory Stats")
                    stats_display = gr.Markdown("Click 'Refresh Stats' to view")
                    refresh_btn = gr.Button("Refresh Stats π")

        # --- Helper to (re)build agent when model changes ---
        def build_agent(model_name):
            """Build an agent for *model_name*, falling back to the default model."""
            try:
                return ConversationalAgent(model_name)
            except Exception as e:
                # Surface configuration errors (e.g., missing OPENROUTER_API_KEY)
                print(f"Error initializing ConversationalAgent with model '{model_name}': {e}")
                # Fallback to default without crashing UI
                return ConversationalAgent()

        # --- Stats updater uses current agent instance ---
        def update_stats(agent_obj):
            """Return the Markdown stats text of the current agent."""
            return agent_obj.get_memory_stats()

        # --- Core chat handlers using agent_state ---
        # process_message is a generator. The handlers must be generator
        # functions themselves (yield from) for Gradio to iterate the updates;
        # returning the generator object would hand Gradio a single opaque value.
        def respond(message, history, agent_obj):
            """Text-only reply: yields (history, textbox, audio) updates."""
            yield from agent_obj.process_message(message, history, use_voice=False)

        def respond_with_voice(message, history, agent_obj):
            """Reply with synthesized speech: yields (history, textbox, audio)."""
            yield from agent_obj.process_message(message, history, use_voice=True)

        def process_audio(audio, history, agent_obj):
            """Transcribe recorded audio into the message textbox."""
            return agent_obj.process_voice_input(audio, history)

        def clear_history():
            """Reset the chat window and the message textbox."""
            return [], ""

        # --- Wire chat events to use current agent_state ---
        chat_inputs = [msg_input, chatbot, agent_state]
        chat_outputs = [chatbot, msg_input, audio_output]

        msg_input.submit(respond, inputs=chat_inputs, outputs=chat_outputs)
        submit_btn.click(respond, inputs=chat_inputs, outputs=chat_outputs)
        voice_btn.click(respond_with_voice, inputs=chat_inputs, outputs=chat_outputs)

        audio_input.change(
            process_audio,
            inputs=[audio_input, chatbot, agent_state],
            outputs=[chatbot, msg_input],
        )
        clear_btn.click(
            clear_history,
            outputs=[chatbot, msg_input],
        )

        # Model apply: update model_state + agent_state
        def apply_model(selected_model, _old_agent):
            """Swap in a fresh agent for the selected model; the old one is dropped."""
            new_agent = build_agent(selected_model)
            return selected_model, new_agent

        apply_model_btn.click(
            apply_model,
            inputs=[llm_model, agent_state],
            outputs=[model_state, agent_state],
        )

        # Stats button uses current agent_state
        refresh_btn.click(
            update_stats,
            inputs=[agent_state],
            outputs=[stats_display],
        )

        # Load stats on startup
        demo.load(
            update_stats,
            inputs=[agent_state],
            outputs=[stats_display],
        )

    return demo
| # ... (rest of your file is unchanged) | |
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860 without creating a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)