Building an AI Agent That Learns from Feedback

I built REFLEX, a research agent that combines reinforcement learning (RL), RAG (retrieval-augmented generation), and a modern agent framework to create something that genuinely improves with every conversation.

I spent three weeks building an AI research assistant. It worked reasonably well: users could ask a question, get an answer, and move on. But something bothered me.

Every conversation was the same. The agent never got better. It made the same mistakes, gave similar answers, and never learned what users actually wanted.

That's when it hit me: what if the agent could learn from every interaction? What if it got smarter over time, not through retraining, but through live feedback from real users?

So I built REFLEX, a research agent that combines reinforcement learning (RL), RAG (retrieval-augmented generation), and a modern agent framework to create something that genuinely improves with every conversation.

Here's what makes it different: when you rate a response, that feedback doesn't disappear. It becomes training data. The agent learns which skills work, which approaches users prefer, and how to do better next time.

I'll walk you through how I built it, step by step. You'll learn how to combine reinforcement learning with RAG, how to implement a skill library that accumulates knowledge, and how to create a feedback loop that actually makes your agent smarter.

We'll cover the architecture, the code, and the decisions that make it all work. By the end, you'll understand how to build an AI agent that doesn't just answer questions, but gets better at answering them.

1. The Problem with Static AI

Most AI assistants work like this: you ask a question, they generate an answer, and that's it. The conversation ends. Nothing is learned. Nothing improves.

I tested this with a simple chatbot. I asked it the same question ten times and got ten similar answers. The mistakes in the tenth response were the same as the mistakes in the first.

That's the problem with static systems. They're frozen in time. They can't adapt, and they can't learn from their mistakes.

Think about how humans learn. We try something. We get feedback. We adjust. We try again. Every iteration makes us better.

AI agents should work the same way, but most don't. They're trained once, deployed, and that's it. No feedback loop. No improvement mechanism. No way to get smarter.

This is especially frustrating for a research assistant. Research questions are nuanced. An approach that works for one user may not work for another. The agent needs to learn these preferences. It needs to learn which approaches lead to better results.

That's where reinforcement learning comes in. RL lets the agent learn from experience. Every interaction becomes a learning opportunity. Every piece of feedback becomes training data.

RL alone isn't enough, though. You also need a way to retrieve relevant information. That's where RAG comes in. RAG gives the agent a knowledge base: it can search documents, find relevant context, and use that information to generate better answers.

Combine RL and RAG and you get a powerful agent: one that has access to knowledge and learns from experience. That's what REFLEX does.

Here's a simple example. A static agent will likely always answer a question the same way. It doesn't know whether a user prefers detailed explanations or quick summaries. It doesn't learn that some topics need a web search while others can be answered from the knowledge base.

REFLEX learns these patterns. It tracks which skills work best for which kinds of questions. It remembers what users rate highly. Over time, it gets better at matching the right approach to the right question.

The difference is dramatic. After 50 conversations, a static agent is exactly what it was on day one. After 50 conversations, REFLEX has learned 50 new things. It is noticeably better.

That's why we need self-improving agents. Not because static agents are bad; they're fine for simple tasks. But for complex research, nuanced questions, and real-world use cases, we need agents that can adapt and improve.

2. The Architecture

Before diving into the code, here's the architecture. REFLEX has five main components working together.

First, the Agent Core. This is the brain. It uses the Agno framework with Claude Sonnet 4 as the LLM. The agent can use tools like web search, access the knowledge base, and maintain conversation context.

Second, the RL Trainer. This handles learning. It stores trajectories, computes advantages, and updates the agent's policy based on feedback. It uses a GRPO-style approach with prioritized experience replay.

Third, the RAG system. This provides knowledge retrieval. It uses LanceDB for vector storage, supports hybrid search (semantic plus keyword), and can do multi-hop retrieval for complex queries.

Fourth, the Skill Library. This stores learned skills. When the agent does something well, that approach becomes a skill. Skills are ranked by success rate and reused for similar questions.

Fifth, the Feedback Loop. Users rate responses. Those ratings become rewards. Rewards update skills. Skills improve the agent. The loop continues.

Here's how data flows through the system:

The user asks a question. The frontend sends it to the FastAPI backend. The backend calls the agent core. The agent retrieves relevant skills from the skill library and augments the query with skill context. If needed, it searches the knowledge base. It uses web search for up-to-date information. Then it generates a response.

The response streams back to the user in real time. The user rates it. The feedback goes to the RL trainer. The trainer computes a reward. The reward updates the skill statistics. Skills that performed well get boosted. The agent is recreated with the updated skills.

That's the complete loop: query → response → feedback → learning → improvement.
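
To make the loop concrete before we get to the real code, here is a tiny self-contained simulation. None of this is REFLEX code; it just mimics the query → response → feedback → learning cycle with one fake skill and hard-coded stand-ins for the agent core and the rating sliders.

# Illustrative simulation of the loop, not the project's actual glue code
skills = {"summarize": {"success_rate": 0.5, "usage": 0}}

def answer(query: str) -> str:
    return f"(answer to: {query})"        # stand-in for the agent core

def rate(response: str) -> float:
    return 0.8                            # stand-in for the user's feedback sliders

for query in ["What is GRPO?", "Explain GAE"]:
    response = answer(query)              # query -> response
    reward = rate(response)               # feedback
    s = skills["summarize"]               # learning: nudge the skill statistics
    s["usage"] += 1
    s["success_rate"] += 0.1 * (reward - s["success_rate"])
    print(query, "->", round(s["success_rate"], 3))
# improvement: the updated statistics shape which skill gets picked next time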

The tech stack is deliberately modern. FastAPI for the backend, because it's fast and has excellent async support. Agno v2 for the agent framework, because it's built for 2025 with up-to-date best practices. LanceDB for vectors, because it's simple and supports hybrid search. Plain HTML/CSS/JS for the frontend, because sometimes you don't need a framework.

Everything is designed to work together. The agent has access to knowledge. The RL system learns from feedback. Skills accumulate over time. The result is an agent that genuinely improves.

3. Building the Agent Core

At the heart of REFLEX is the SelfImprovingResearchAgent class. This is where everything comes together. Let me show you how it's built.

When you initialize the agent, it sets up several key components:

    def __init__(  
        self,  
        api_key: Optional[str] = None,  
        openai_api_key: Optional[str] = None,  
        db_path: str = None  
    ):  
        if db_path is None:  
            # Use absolute path to ensure it works in Docker  
            db_path = os.path.join(os.getcwd(), "data", "db", "agent.db")  
        logger.info("Initializing SelfImprovingResearchAgent...")  
        # Load .env from root if not already loaded  
        from dotenv import load_dotenv  
        import pathlib  
        root_env = pathlib.Path(__file__).parent.parent / ".env"  
        if root_env.exists():  
            load_dotenv(root_env)  

        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")  
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")  

        logger.debug(f"API keys loaded: Anthropic={bool(self.api_key)}, OpenAI={bool(self.openai_api_key)}")  

        # Initialize skill library  
        self.skill_library = SkillLibrary()  
        logger.debug(f"Skill library initialized with {len(self.skill_library.skills)} skills")  

        # Initialize adaptive reward weights  
        reward_weights_path = os.path.join(os.getcwd(), "data", "reward_weights.json")  
        if AdaptiveRewardWeights:  
            self.adaptive_reward_weights = AdaptiveRewardWeights()  
            self.adaptive_reward_weights.load(reward_weights_path)  
            logger.info("Adaptive reward weights initialized")  
        else:  
            self.adaptive_reward_weights = None  
            logger.debug("Using fixed reward weights")  

        # Initialize trajectory buffer with persistent storage  
        trajectory_db_path = os.path.join(os.path.dirname(db_path), "trajectories.db")  
        self.trajectory_buffer = TrajectoryBuffer(  
            max_size=10000,  
            use_prioritized=True,  
            db_path=trajectory_db_path  
        )  
        logger.debug("Trajectory buffer initialized with prioritized replay")  

        # Load existing trajectories from database  
        try:  
            self.trajectory_buffer.load_from_db(limit=1000)  
        except Exception as e:  
            logger.warning(f"Could not load trajectories from DB: {e}")  

        # Initialize pending trajectories cache (session_id -> trajectory)  
        self.pending_trajectories: Dict[str, Dict[str, Any]] = {}  

        # Initialize Critic Agent  
        if self.api_key:  
            self.critic = CriticAgent(api_key=self.api_key)  
            logger.info("Critic Agent initialized")  
        else:  
            self.critic = None  
            logger.warning("Critic Agent NOT initialized (missing API key)")  

        # Training statistics  
        self.training_stats = {  
            'total_tasks': 0,  
            'successful_tasks': 0,  
            'average_reward': 0.0,  
            'improvement_rate': 0.0  
        }  

        # Setup database for memory and storage  
        self.db = SqliteDb(db_file=db_path)  

        # Setup knowledge base (optional - can be None for POC)  
        self.knowledge = None  
        self.enhanced_rag = None  # Enhanced RAG system  
        self.knowledge_urls = []  # Store URLs for knowledge base  
        self.lancedb_path = os.path.join(os.getcwd(), "data", "db", "lancedb")  

        # Load saved URLs from file  
        self._load_knowledge_urls()  

        if self.openai_api_key and Knowledge and LanceDb and OpenAIEmbedder and WebsiteReader:  
            try:  
                logger.info("Initializing knowledge base with LanceDB...")  
                if self.knowledge_urls:  
                    # Create embedder  
                    embedder = OpenAIEmbedder(  
                        id="text-embedding-3-small",  
                        api_key=self.openai_api_key  
                    )  

                    # Create vector database  
                    vector_db = LanceDb(  
                        uri=self.lancedb_path,  
                        table_name="research_docs",  
                        search_type=SearchType.hybrid,  
                        embedder=embedder  
                    )  

                    # Create website reader  
                    website_reader = WebsiteReader()  

                    # Create knowledge base with website reader  
                    self.knowledge = Knowledge(  
                        name="Research Knowledge Base",  
                        description="Knowledge base for research documents",  
                        vector_db=vector_db,  
                        readers={"url": website_reader, "website": website_reader}  
                    )  

                    # Load URLs into knowledge base  
                    logger.info(f"Loading {len(self.knowledge_urls)} URLs into knowledge base...")  
                    for url in self.knowledge_urls:  
                        try:  
                            self.knowledge.add_content(url=url)  
                            logger.debug(f"Loaded URL: {url}")  
                        except Exception as e:  
                            logger.warning(f"Failed to load URL {url}: {e}")  

                    logger.info(f"Knowledge base initialized with {len(self.knowledge_urls)} URLs")  

                    # Initialize Enhanced RAG if available  
                    if ENHANCED_RAG_AVAILABLE and EnhancedRAG:  
                        try:  
                            cohere_key = os.getenv("COHERE_API_KEY")  
                            self.enhanced_rag = EnhancedRAG(  
                                knowledge_base=self.knowledge,  
                                use_reranking=bool(cohere_key),  
                                use_multi_hop=True,  
                                use_query_expansion=False,  # Can be enabled if needed  
                                cohere_api_key=cohere_key  
                            )  
                            logger.info("Enhanced RAG system initialized")  
                        except Exception as e:  
                            logger.warning(f"Could not initialize Enhanced RAG: {e}")  
                else:  
                    # Initialize empty knowledge base  
                    embedder = OpenAIEmbedder(  
                        id="text-embedding-3-small",  
                        api_key=self.openai_api_key  
                    )  
                    vector_db = LanceDb(  
                        uri=self.lancedb_path,  
                        table_name="research_docs",  
                        search_type=SearchType.hybrid,  
                        embedder=embedder  
                    )  
                    website_reader = WebsiteReader()  
                    self.knowledge = Knowledge(  
                        name="Research Knowledge Base",  
                        description="Knowledge base for research documents",  
                        vector_db=vector_db,  
                        readers={"url": website_reader, "website": website_reader}  
                    )  
                    logger.info("Knowledge base initialized but no URLs configured")  

                    # Initialize Enhanced RAG even with empty knowledge base  
                    if ENHANCED_RAG_AVAILABLE and EnhancedRAG:  
                        try:  
                            cohere_key = os.getenv("COHERE_API_KEY")  
                            self.enhanced_rag = EnhancedRAG(  
                                knowledge_base=self.knowledge,  
                                use_reranking=bool(cohere_key),  
                                use_multi_hop=True,  
                                use_query_expansion=False,  
                                cohere_api_key=cohere_key  
                            )  
                            logger.info("Enhanced RAG system initialized (empty knowledge base)")  
                        except Exception as e:  
                            logger.warning(f"Could not initialize Enhanced RAG: {e}")  
            except Exception as e:  
                logger.warning(f"Knowledge base not initialized: {e}", exc_info=True)  
                self.knowledge = None  
        else:  
            missing = []  
            if not self.openai_api_key:  
                missing.append("OpenAI API key")  
            if not Knowledge:  
                missing.append("Knowledge class")  
            if not LanceDb:  
                missing.append("LanceDb")  
            if not OpenAIEmbedder:  
                missing.append("OpenAIEmbedder")  
            if not WebsiteReader:  
                missing.append("WebsiteReader")  
            logger.info(f"Knowledge base disabled (missing: {', '.join(missing)})")  

        # Create the main agent  
        self.agent = self._create_agent()

The initialization does several important things: it loads API keys, sets up the skill library, creates the trajectory buffer for RL training, initializes the knowledge base, and creates the main agent.

The skill library is the key piece. It stores learned skills in a JSON file. Each skill has a name, description, context, success rate, and usage count. When the agent needs to answer a question, it retrieves the relevant skills and uses them to augment the query.
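
The Skill dataclass itself isn't shown in this article. Based on the fields referenced throughout the code (name, description, context, success_rate, usage_count, average_reward), it presumably looks roughly like this:

from dataclasses import dataclass

@dataclass
class Skill:
    """A learned approach, persisted as JSON by the SkillLibrary (field set assumed)"""
    name: str
    description: str
    context: str
    success_rate: float = 0.0
    usage_count: int = 0
    average_reward: float = 0.0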

Here's how the skill library works:

class SkillLibrary:  
    """Manages learned skills for self-improvement"""  

    def __init__(self, storage_path: str = None):  
        if storage_path is None:  
            storage_path = os.path.join(os.getcwd(), "data", "skills", "skills.json")  
        self.storage_path = storage_path  
        self.skills: Dict[str, Skill] = {}  
        self.total_global_usage = 0  
        self.load_skills()  

    def add_skill(self, skill: Skill):  
        """Add or update a skill"""  
        self.skills[skill.name] = skill  
        self.save_skills()  

    def get_relevant_skills(self, query: str, top_k: int = 3) -> List[Skill]:  
        """Get most relevant skills for a query using UCB-based selection"""  
        scored_skills = []  
        query_words = set(query.lower().split())  

        # UCB Exploration Constant  
        c = 0.5  

        for skill in self.skills.values():  
            desc_words = set(skill.description.lower().split())  
            overlap = len(query_words.intersection(desc_words))  

            if overlap == 0:  
                continue  

            # UCB Formula: Score = Relevance * (AverageReward + c * sqrt(ln(TotalUsage) / SkillUsage))  
            if skill.usage_count > 0 and self.total_global_usage > 0:  
                exploration_term = c * np.sqrt(np.log(self.total_global_usage) / skill.usage_count)  
                exploration_term = min(exploration_term, 2.0)  # Cap exploration  
                reward_term = skill.average_reward  
            else:  
                # High score for never-used skills to encourage exploration  
                exploration_term = 1.0  
                reward_term = 0.5  

            score = overlap * (reward_term + exploration_term)  
            scored_skills.append((score, skill))  

        scored_skills.sort(reverse=True, key=lambda x: x[0])  
        return [skill for _, skill in scored_skills[:top_k]]  

    def update_skill_stats(self, name: str, reward: float, success: bool):  
        """Update skill statistics after use"""  
        self.total_global_usage += 1  
        if name in self.skills:  
            skill = self.skills[name]  
            skill.usage_count += 1  
            skill.average_reward = (  
                (skill.average_reward * (skill.usage_count - 1) + reward) /   
                skill.usage_count  
            )  
            if success:  
                skill.success_rate = (  
                    (skill.success_rate * (skill.usage_count - 1) + 1.0) /   
                    skill.usage_count  
                )  
            else:  
                skill.success_rate = (  
                    (skill.success_rate * (skill.usage_count - 1)) /   
                    skill.usage_count  
                )  
            self.save_skills()

The skill library uses a UCB (Upper Confidence Bound) algorithm to balance exploration and exploitation. It ranks skills by relevance and success rate, but also gives a bonus to skills that have been used less often. This ensures the agent doesn't just keep reusing the most popular skills.
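
To see how the UCB scoring plays out, here is a small standalone example with two hypothetical skills and made-up statistics (the names and numbers are purely illustrative):

import numpy as np

c = 0.5                      # exploration constant, as in get_relevant_skills
total_global_usage = 100

# skill A is popular and strong; skill B has rarely been tried
for name, avg_reward, usage in [("summarize_papers", 0.80, 60), ("compare_benchmarks", 0.60, 4)]:
    exploration = min(c * np.sqrt(np.log(total_global_usage) / usage), 2.0)
    score = avg_reward + exploration          # assuming equal keyword overlap
    print(f"{name}: exploitation={avg_reward:.2f}, exploration={exploration:.2f}, score={score:.2f}")

# summarize_papers:   exploitation=0.80, exploration=0.14, score≈0.94
# compare_benchmarks: exploitation=0.60, exploration=0.54, score≈1.14
# The rarely used skill wins despite its lower average reward: that is the exploration bonus at work.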

The trajectory buffer stores the agent's actions and responses. This is what the RL system trains on. It supports prioritized experience replay, which means high-value trajectories get sampled more often.

The knowledge base uses LanceDB for vector storage. It supports hybrid search, combining semantic and keyword matching, which works better than either approach on its own.

Now let's look at how the agent itself is created:

    def _create_agent(self) -> Agent:  
        """Create the research agent with all capabilities"""  

        base_instructions = [  
            "You are a research assistant with self-improvement capabilities.",  
            "Use web search to find current information when needed.",  
            "Be thorough in your research and provide well-sourced answers.",  
            "Learn from feedback to improve your performance over time."  
        ]  

        # Add skill context  
        if self.skill_library.skills:  
            skill_context = "\n\nLearned Skills:\n"  
            for skill in list(self.skill_library.skills.values())[:5]:  
                skill_context += (  
                    f"- {skill.name}: {skill.description} "  
                    f"(success rate: {skill.success_rate:.2f})\n"  
                )  
            base_instructions.append(skill_context)  

        agent_config = {  
            "name": "Research Agent",  
            "model": Claude(id="claude-sonnet-4-20250514", api_key=self.api_key),  
            "instructions": base_instructions,  
            "tools": [DuckDuckGoTools()],  
            "db": self.db,  
            "add_history_to_context": True,  
            "markdown": True  
        }  

        # Add knowledge if available  
        if self.knowledge:  
            agent_config["knowledge"] = self.knowledge  
            agent_config["search_knowledge"] = True  

        logger.debug(f"Creating agent with config keys: {list(agent_config.keys())}")  
        return Agent(**agent_config)

This method creates the Agno agent with all its capabilities. It sets the base instructions, adds the learned skills to the context, configures the model (Claude Sonnet 4), adds tools (web search), and attaches the knowledge base.

The key insight here is that skills are injected into the agent's instructions. That means the agent can see which approaches have worked in the past. It's like handing the agent a playbook of what has succeeded before.

The agent also uses a Critic Agent to evaluate its own responses before the user weighs in:

class CriticAgent:  
    """Evaluates agent responses to provide dense reward signals"""  

    def __init__(self, api_key: str):  
        self.agent = Agent(  
            name="Critic",  
            model=Claude(id="claude-3-haiku-20240307", api_key=api_key),  
            instructions=[  
                "You are an expert AI critic. Your job is to evaluate the quality of research assistant responses.",  
                "You will be given a user query and the agent's response.",  
                "Evaluate based on: Accuracy, Completeness, Relevance, and Citation quality.",  
                "Output ONLY a single float number between 0.0 and 1.0 representing the score.",  
                "0.0 is terrible, 1.0 is perfect."  
            ],  
            markdown=False  
        )  

    def evaluate(self, query: str, response: str) -> float:  
        """Evaluate a response and return a score between 0.0 and 1.0"""  
        if not response or not response.strip():  
            return 0.0  

        try:  
            prompt = f"""  
            User Query: {query}  

            Agent Response:  
            {response[:4000]}   

            Rate this response from 0.0 to 1.0.  
            Return ONLY the number.  
            """  

            result = self.agent.run(prompt)  
            content = result.content.strip()  

            # Extract number from response  
            match = re.search(r"0\.\d+|1\.0|0|1", content)  
            if match:  
                score = float(match.group())  
                return max(0.0, min(1.0, score))  
            return 0.5  # Default if parsing fails  

        except Exception as e:  
            logger.error(f"Critic evaluation failed: {e}")  
            return 0.5  # Default on error

The critic provides immediate feedback before the user rates anything. This creates a dense reward signal that helps the agent learn faster.

When a user asks a question, the agent retrieves relevant skills, augments the query with skill context, and generates a response. The response streams back to the user in real time. The user rates it. The feedback goes to the RL trainer. The trainer computes a reward. The reward updates the skill statistics. Skills that performed well get boosted. The agent is recreated with the updated skills.

It's a positive feedback loop. Good skills get used more. Skills that get used get better. The agent improves over time.

Next, let's look at how the reinforcement learning system makes that improvement happen.

4. Reinforcement Learning in Action

Reinforcement learning is what makes REFLEX self-improving. Every user interaction becomes a learning opportunity. Here's how it works.

When a user rates a response, that rating becomes a reward signal. The reward has several components: task success (did the agent accomplish the task?), quality score (how good was the response?), efficiency score (how quickly did it respond?), and user feedback (overall satisfaction).

These components are combined into a single reward:

@dataclass  
class RewardSignal:  
    """Reward signal for RL training"""  
    task_success: float  
    quality_score: float  
    efficiency_score: float  
    user_feedback: float  
    critic_score: float = 0.0  
    adaptive_weights: Optional[AdaptiveRewardWeights] = None  

    def compute_total_reward(self) -> float:  
        """Compute weighted total reward"""  
        if self.adaptive_weights:  
            total, _ = self.adaptive_weights.compute_reward(  
                self.task_success,  
                self.quality_score,  
                self.efficiency_score,  
                self.user_feedback,  
                self.critic_score  
            )  
            return total  
        else:  
            # Fallback to fixed weights  
            weights = {  
                'task_success': 0.35,  
                'quality_score': 0.25,  
                'efficiency_score': 0.1,  
                'user_feedback': 0.1,  
                'critic_score': 0.2  
            }  
            return (  
                weights['task_success'] * self.task_success +  
                weights['quality_score'] * self.quality_score +  
                weights['efficiency_score'] * self.efficiency_score +  
                weights['user_feedback'] * self.user_feedback +  
                weights['critic_score'] * self.critic_score  
            )

The reward is weighted. Task success counts for 35%, because completing the task matters most. Quality counts for 25%, because good answers matter. Efficiency counts for 10%, because speed is nice but not critical. User feedback counts for 10%, because satisfaction matters. The critic's score counts for 20%, because internal evaluation helps learning.
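
As a quick sanity check of the fixed weights, here is a worked example using the RewardSignal dataclass from the listing above (the component values are made up):

signal = RewardSignal(
    task_success=1.0,       # the task was completed
    quality_score=0.8,
    efficiency_score=0.6,
    user_feedback=0.9,
    critic_score=0.7,
)

# 0.35*1.0 + 0.25*0.8 + 0.10*0.6 + 0.10*0.9 + 0.20*0.7
#   = 0.35 + 0.20 + 0.06 + 0.09 + 0.14 = 0.84
print(signal.compute_total_reward())  # -> 0.84 (with adaptive_weights left unset)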

This reward is stored in a trajectory. A trajectory contains the query, the response, the tools used, the skills applied, and the reward. Trajectories are stored in a buffer for training.

The buffer uses prioritized experience replay, which means high-reward trajectories get sampled more often. It's like focusing on your best games when learning a sport: you learn more from your successes than from your failures.

Here's how prioritized experience replay works:

class PrioritizedReplayBuffer:  
    """  
    Prioritized Experience Replay Buffer  
    Samples trajectories based on TD-error or advantage magnitude  
    """  

    def __init__(self, capacity: int = 10000, alpha: float = 0.6, beta: float = 0.4, beta_increment: float = 0.001):  
        """  
        Args:  
            capacity: Maximum number of trajectories  
            alpha: Priority exponent (0 = uniform, 1 = full prioritization)  
            beta: Importance sampling exponent (starts at beta, increases to 1.0)  
            beta_increment: How much to increment beta per sample  
        """  
        self.capacity = capacity  
        self.alpha = alpha  
        self.beta = beta  
        self.beta_increment = beta_increment  
        self.tree = SumTree(capacity)  
        self.max_priority = 1.0  
        self.position = 0  

    def add(self, trajectory: Dict[str, Any], priority: Optional[float] = None):  
        """Add trajectory with priority"""  
        if priority is None:  
            # Use advantage magnitude or TD-error as priority  
            priority = abs(trajectory.get('advantage', trajectory.get('reward', 1.0)))  
            priority = max(priority, 1e-6)  # Minimum priority  

        # Scale by alpha  
        priority = (priority + 1e-6) ** self.alpha  
        self.max_priority = max(self.max_priority, priority)  

        self.tree.add(priority, trajectory)  

    def sample(self, batch_size: int) -> Tuple[List[Dict[str, Any]], List[int], np.ndarray]:  
        """  
        Sample batch of trajectories  
        Returns: (batch, indices, importance_weights)  
        """  
        batch = []  
        indices = []  
        priorities = []  

        segment = self.tree.total / batch_size  

        for i in range(batch_size):  
            a = segment * i  
            b = segment * (i + 1)  
            s = np.random.uniform(a, b)  
            idx, data, priority = self.tree.get(s)  
            batch.append(data)  
            indices.append(idx)  
            priorities.append(priority)  

        # Compute importance sampling weights  
        probabilities = np.array(priorities) / self.tree.total  
        weights = (self.capacity * probabilities) ** (-self.beta)  
        weights = weights / weights.max()  # Normalize  

        self.beta = min(1.0, self.beta + self.beta_increment)  

        return batch, indices, weights

The buffer uses a SumTree data structure for efficient sampling. Trajectories with higher advantages get sampled more often, and importance sampling weights correct for the bias this introduces.
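
The SumTree itself isn't shown in the article. Here is a minimal sketch of the data structure as it is commonly implemented, matching the add/get/total interface the buffer relies on; this is my reconstruction, not the project's actual code.

import numpy as np

class SumTree:
    """Binary tree whose leaves hold priorities and whose internal nodes hold sums,
    so drawing a sample proportional to priority takes O(log n)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.nodes = np.zeros(2 * capacity - 1)   # internal nodes followed by leaves
        self.data = [None] * capacity             # stored trajectories
        self.write = 0                            # next leaf slot to overwrite

    @property
    def total(self) -> float:
        return float(self.nodes[0])               # root holds the sum of all priorities

    def add(self, priority: float, data) -> None:
        leaf = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(leaf, priority)
        self.write = (self.write + 1) % self.capacity  # overwrite oldest when full

    def update(self, idx: int, priority: float) -> None:
        change = priority - self.nodes[idx]
        self.nodes[idx] = priority
        while idx != 0:                            # propagate the change up to the root
            idx = (idx - 1) // 2
            self.nodes[idx] += change

    def get(self, s: float):
        """Descend to the leaf whose cumulative priority range contains s."""
        idx = 0
        while 2 * idx + 1 < len(self.nodes):       # while idx still has children
            left = 2 * idx + 1
            if s <= self.nodes[left]:
                idx = left
            else:
                s -= self.nodes[left]
                idx = left + 1
        return idx, self.data[idx - self.capacity + 1], float(self.nodes[idx])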

Here's how the trajectory buffer works:

class TrajectoryBuffer:  
    """Stores agent trajectories for RL training with prioritized replay support"""  

    def __init__(self, max_size: int = 1000, use_prioritized: bool = True, db_path: Optional[str] = None):  
        self.max_size = max_size  
        self.use_prioritized = use_prioritized and PrioritizedReplayBuffer is not None  

        if self.use_prioritized:  
            self.prioritized_buffer = PrioritizedReplayBuffer(capacity=max_size)  
            self.trajectories: List[Dict[str, Any]] = []  # Keep for compatibility  
        else:  
            self.trajectories: List[Dict[str, Any]] = []  
            self.prioritized_buffer = None  

        # Persistent storage  
        self.db = None  
        if db_path:  
            try:  
                self.db = TrajectoryDatabase(db_path)  
                logger.info("Trajectory persistent storage enabled")  
            except Exception as e:  
                logger.warning(f"Could not initialize trajectory database: {e}")  

        # PPO trainer for advanced advantage computation  
        self.ppo_trainer = PPOTrainer() if PPOTrainer else None  

    def add_trajectory(self, trajectory: Dict[str, Any]):  
        """Add a trajectory to the buffer"""  
        # Compute priority if using prioritized replay  
        priority = None  
        if self.use_prioritized and 'advantage' in trajectory:  
            priority = abs(trajectory['advantage'])  

        if self.use_prioritized and self.prioritized_buffer:  
            self.prioritized_buffer.add(trajectory, priority)  
        else:  
            self.trajectories.append(trajectory)  
            if len(self.trajectories) > self.max_size:  
                self.trajectories.pop(0)  

        # Save to persistent storage  
        if self.db:  
            try:  
                self.db.save_trajectory(trajectory)  
            except Exception as e:  
                logger.warning(f"Could not save trajectory to DB: {e}")  

    def get_batch(self, batch_size: int = 32) -> List[Dict[str, Any]]:  
        """Get a batch of trajectories (prioritized if enabled)"""  
        if self.use_prioritized and self.prioritized_buffer:  
            batch, indices, weights = self.prioritized_buffer.sample(batch_size)  
            # Store weights for importance sampling  
            for traj, weight in zip(batch, weights):  
                traj['importance_weight'] = float(weight)  
            return batch  
        else:  
            # Fallback to random sampling  
            if len(self.trajectories) < batch_size:  
                return self.trajectories.copy()  
            indices = np.random.choice(len(self.trajectories), batch_size, replace=False)  
            return [self.trajectories[i] for i in indices]  

    def compute_advantages(self, trajectories: List[Dict[str, Any]], use_ppo: bool = True) -> List[float]:  
        """Compute advantages using PPO-style GAE or group-relative approach"""  
        if use_ppo and self.ppo_trainer:  
            advantages = self.ppo_trainer.compute_advantages(trajectories, use_gae=True)  
            # Store advantages in trajectories  
            for traj, adv in zip(trajectories, advantages):  
                traj['advantage'] = adv  
            return advantages  
        else:  
            # Fallback to group-relative approach  
            rewards = [t.get('reward', 0.0) for t in trajectories]  
            mean_reward = np.mean(rewards)  
            std_reward = np.std(rewards) + 1e-8  
            advantages = [(r - mean_reward) / std_reward for r in rewards]  
            # Store advantages  
            for traj, adv in zip(trajectories, advantages):  
                traj['advantage'] = adv  
            return advantages

The buffer can hold up to 10,000 trajectories. When training happens, it samples a batch (32 trajectories by default). With prioritized replay enabled, it samples more of the high-reward trajectories, which makes training more efficient.

Advantages are computed with PPO-style Generalized Advantage Estimation (GAE). This estimates whether a trajectory was better or worse than average. A positive advantage means the trajectory was good; a negative advantage means it wasn't.
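
The PPOTrainer's compute_advantages method isn't reproduced here. For reference, this is a minimal sketch of standard GAE, assuming per-step rewards and value estimates are available; the project's actual implementation may differ.

def compute_gae(rewards, values, gamma: float = 0.99, lam: float = 0.95):
    """Generalized Advantage Estimation for one episode.
    `values` needs len(rewards) + 1 entries (the last one is the bootstrap value)."""
    advantages = []
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]   # TD error at step t
        gae = delta + gamma * lam * gae                          # discounted sum of TD errors
        advantages.append(gae)
    advantages.reverse()
    return advantages

# Example: a short 3-step episode
print(compute_gae([0.0, 0.0, 1.0], [0.2, 0.4, 0.6, 0.0]))  # -> roughly [0.73, 0.57, 0.40]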

During training, skills with positive advantages get boosted: their success rates go up. Skills with negative advantages decay slightly: their success rates go down. That's how the agent learns which approaches work.

Here's the training iteration:

    def train_iteration(self, batch_size: int = 32, use_ppo: bool = True):  
        """Perform a training iteration with PPO-style updates"""  
        logger.info(f"Starting training iteration: batch_size={batch_size}, available_trajectories={len(self.trajectory_buffer)}")  

        if len(self.trajectory_buffer) < batch_size:  
            logger.warning(f"Not enough trajectories for training (have {len(self.trajectory_buffer)}, need {batch_size})")  
            return  

        # Get batch (prioritized if enabled)  
        batch = self.trajectory_buffer.get_batch(batch_size)  
        logger.debug(f"Selected batch of {len(batch)} trajectories")  

        # Compute advantages using PPO-style GAE  
        advantages = self.trajectory_buffer.compute_advantages(batch, use_ppo=use_ppo)  
        logger.debug(f"Computed advantages: mean={np.mean(advantages):.3f}, std={np.std(advantages):.3f}")  

        # Update priorities for prioritized replay  
        if self.trajectory_buffer.use_prioritized:  
            self.trajectory_buffer.update_priorities(batch, advantages)  

        # Update skill weights based on advantages  
        updated_skills = 0  
        for traj, advantage in zip(batch, advantages):  
            for skill_name in traj.get('relevant_skills', []):  
                if skill_name in self.skill_library.skills:  
                    skill = self.skill_library.skills[skill_name]  
                    if advantage > 0:  
                        # PPO-style soft update  
                        boost_factor = 1.0 + min(0.1, advantage * 0.05)  # Cap boost  
                        old_rate = skill.success_rate  
                        skill.success_rate = min(1.0, skill.success_rate * boost_factor)  
                        logger.debug(f"Boosted skill {skill_name}: {old_rate:.3f} -> {skill.success_rate:.3f} (advantage={advantage:.3f})")  
                        updated_skills += 1  
                    elif advantage < -0.5:  # Significant negative advantage  
                        # Slight decay for poor performance  
                        old_rate = skill.success_rate  
                        skill.success_rate = max(0.0, skill.success_rate * 0.98)  
                        logger.debug(f"Decayed skill {skill_name}: {old_rate:.3f} -> {skill.success_rate:.3f}")  

        self.skill_library.save_skills()  
        logger.info(f"Training iteration completed. Batch size: {batch_size}, skills updated: {updated_skills}")

The training procedure is simple. Get a batch of trajectories. Compute advantages. Update skill success rates based on the advantages. Save the updated skills. The agent is then recreated with the updated skills.

This is a GRPO-style approach. GRPO (Group Relative Policy Optimization) groups similar trajectories and computes relative advantages. It's simpler than full PPO but still effective for skill-based learning.

The key insight is that we're not training a neural network. We're updating skill statistics. Skills with high success rates get used more. Skills with low success rates get used less. Over time, the agent naturally gravitates toward better approaches.
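
As a tiny example of the group-relative fallback shown earlier, normalizing three trajectory rewards against their group mean looks like this:

import numpy as np

rewards = [0.9, 0.5, 0.7]
mean, std = np.mean(rewards), np.std(rewards) + 1e-8
advantages = [(r - mean) / std for r in rewards]
print([round(a, 2) for a in advantages])  # -> [1.22, -1.22, 0.0]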

Pretty neat, right?

5. The Enhanced RAG System

RAG (retrieval-augmented generation) gives the agent access to a knowledge base. Instead of relying only on the LLM's training data, the agent can search documents, find relevant information, and use it to generate better answers.

REFLEX uses LanceDB for vector storage. LanceDB is fast, supports hybrid search, and is easy to integrate. Here's the knowledge base setup:

        if self.openai_api_key and Knowledge and LanceDb and OpenAIEmbedder and WebsiteReader:  
            try:  
                logger.info("Initializing knowledge base with LanceDB...")  
                if self.knowledge_urls:  
                    # Create embedder  
                    embedder = OpenAIEmbedder(  
                        id="text-embedding-3-small",  
                        api_key=self.openai_api_key  
                    )  

                    # Create vector database  
                    vector_db = LanceDb(  
                        uri=self.lancedb_path,  
                        table_name="research_docs",  
                        search_type=SearchType.hybrid,  
                        embedder=embedder  
                    )  

                    # Create website reader  
                    website_reader = WebsiteReader()  

                    # Create knowledge base with website reader  
                    self.knowledge = Knowledge(  
                        name="Research Knowledge Base",  
                        description="Knowledge base for research documents",  
                        vector_db=vector_db,  
                        readers={"url": website_reader, "website": website_reader}  
                    )  

                    # Load URLs into knowledge base  
                    logger.info(f"Loading {len(self.knowledge_urls)} URLs into knowledge base...")  
                    for url in self.knowledge_urls:  
                        try:  
                            self.knowledge.add_content(url=url)  
                            logger.debug(f"Loaded URL: {url}")  
                        except Exception as e:  
                            logger.warning(f"Failed to load URL {url}: {e}")  

                    logger.info(f"Knowledge base initialized with {len(self.knowledge_urls)} URLs")

The knowledge base creates embeddings with OpenAI's text-embedding-3-small model and stores them in LanceDB. When a user asks a question, the query is embedded and similar documents are retrieved.

The important part here is hybrid search. LanceDB supports both semantic search (finding conceptually similar content) and keyword search (finding exact terms). Hybrid search combines the two for better results.

Semantic search is great for finding related concepts. If you ask about "machine learning", it will find documents about "neural networks" and "deep learning" even when those exact words aren't in your query.

Keyword search is great for finding specific terms. If you ask about the "GRPO algorithm", it will find documents that mention "GRPO" even if the surrounding context is different.

Hybrid search gives you the best of both: it finds documents that are conceptually similar and contain the relevant keywords.
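
LanceDB does the blending internally when SearchType.hybrid is configured, as in the code above. Conceptually, though, the idea is just to combine two scores. Here is a toy illustration of that blending (my own simplification, not LanceDB's actual algorithm):

from collections import Counter

def keyword_overlap(query: str, doc: str) -> float:
    """Crude keyword score: fraction of query terms that appear in the document."""
    q_terms = set(query.lower().split())
    d_terms = set(doc.lower().split())
    return len(q_terms & d_terms) / max(len(q_terms), 1)

def hybrid_score(semantic_similarity: float, query: str, doc: str, alpha: float = 0.7) -> float:
    """Blend a vector-similarity score with a keyword score."""
    return alpha * semantic_similarity + (1 - alpha) * keyword_overlap(query, doc)

# A document that mentions "GRPO" explicitly gets a boost even if its embedding
# similarity is only moderate:
print(hybrid_score(0.62, "GRPO algorithm", "GRPO is a group relative policy optimization method"))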

REFLEX also supports enhanced RAG features, including multi-hop retrieval and reranking:

class EnhancedRAG:  
    """  
    Enhanced RAG system combining reranking, multi-hop retrieval, and query expansion  
    """  

    def __init__(  
        self,  
        knowledge_base,  
        use_reranking: bool = True,  
        use_multi_hop: bool = True,  
        use_query_expansion: bool = False,  
        cohere_api_key: Optional[str] = None  
    ):  
        """  
        Args:  
            knowledge_base: Agno Knowledge instance  
            use_reranking: Enable Cohere reranking  
            use_multi_hop: Enable multi-hop retrieval  
            use_query_expansion: Enable query expansion  
            cohere_api_key: Cohere API key for reranking  
        """  
        self.knowledge_base = knowledge_base  

        # Initialize components  
        self.reranker = None  
        if use_reranking and COHERE_AVAILABLE:  
            try:  
                self.reranker = Reranker(api_key=cohere_api_key)  
                logger.info("Reranking enabled")  
            except Exception as e:  
                logger.warning(f"Could not initialize reranker: {e}")  

        self.multi_hop = None  
        if use_multi_hop:  
            self.multi_hop = MultiHopRetriever(knowledge_base, max_hops=3)  
            logger.info("Multi-hop retrieval enabled")  

        self.query_expander = None  
        if use_query_expansion:  
            self.query_expander = QueryExpander()  
            logger.info("Query expansion enabled")  

        logger.info("Enhanced RAG system initialized")  

    def retrieve(  
        self,  
        query: str,  
        top_k: int = 5,  
        context: Optional[str] = None  
    ) -> Dict[str, Any]:  
        """  
        Enhanced retrieval with all enabled features  

        Args:  
            query: User query  
            top_k: Number of results to return  
            context: Additional context  

        Returns:  
            Dictionary with retrieved documents and metadata  
        """  
        # Step 1: Query expansion (if enabled)  
        queries = [query]  
        if self.query_expander:  
            queries = self.query_expander.expand(query, num_variations=2)  
            logger.debug(f"Expanded query into {len(queries)} variations")  

        # Step 2: Multi-hop retrieval (if enabled)  
        all_results = []  
        retrieval_metadata = {}  

        if self.multi_hop:  
            for q in queries[:1]:  # Use original query for multi-hop  
                hop_results = self.multi_hop.retrieve(q, context)  
                all_results.extend(hop_results['results'])  
                retrieval_metadata = hop_results  
        else:  
            # Simple single-hop retrieval  
            try:  
                if hasattr(self.knowledge_base, 'search'):  
                    for q in queries:  
                        results = self.knowledge_base.search(q, limit=top_k)  
                        all_results.extend(results)  
                else:  
                    logger.warning("Knowledge base does not support search")  
            except Exception as e:  
                logger.error(f"Error in retrieval: {e}")  

        # Step 3: Reranking (if enabled)  
        if self.reranker and all_results:  
            # Extract text from results  
            doc_texts = []  
            for result in all_results:  
                if hasattr(result, 'content'):  
                    doc_texts.append(result.content)  
                elif isinstance(result, str):  
                    doc_texts.append(result)  
                else:  
                    doc_texts.append(str(result))  

            if doc_texts:  
                reranked = self.reranker.rerank(query, doc_texts, top_n=top_k)  

                # Reorder results based on reranking  
                reranked_results = []  
                reranked_indices = {text: idx for idx, (text, score) in enumerate(reranked)}  

                for result in all_results:  
                    result_text = str(result.content if hasattr(result, 'content') else result)  
                    if result_text in reranked_indices:  
                        reranked_results.append((result, reranked[reranked_indices[result_text]][1]))  

                # Sort by rerank score  
                reranked_results.sort(key=lambda x: x[1], reverse=True)  
                all_results = [r[0] for r in reranked_results[:top_k]]  
                retrieval_metadata['reranked'] = True  
                retrieval_metadata['rerank_scores'] = [r[1] for r in reranked_results[:top_k]]  

        return {  
            'results': all_results[:top_k],  
            'metadata': retrieval_metadata,  
            'num_results': len(all_results[:top_k])  
        }

Multi-hop retrieval is powerful. Instead of doing one search, it does several. The first search finds initial results. Those results are analyzed and a refined query is generated. The second search uses the refined query. The process continues for up to three hops.

Here's the multi-hop retriever implementation:

class MultiHopRetriever:
    """
    Multi-hop retrieval system
    Performs iterative retrieval with query refinement
    """
    
    def __init__(self, knowledge_base, max_hops: int = 3):
        """
        Args:
            knowledge_base: Agno Knowledge instance
            max_hops: Maximum number of retrieval iterations
        """
        self.knowledge_base = knowledge_base
        self.max_hops = max_hops
        logger.info(f"Multi-hop retriever initialized with max_hops={max_hops}")
    
    def retrieve(
        self,
        initial_query: str,
        context: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Perform multi-hop retrieval
        
        Args:
            initial_query: Initial user query
            context: Additional context from conversation
            
        Returns:
            Dictionary with retrieved information and retrieval path
        """
        all_results = []
        current_query = initial_query
        retrieval_path = []
        
        for hop in range(self.max_hops):
            logger.debug(f"Retrieval hop {hop + 1}/{self.max_hops} with query: {current_query[:100]}...")
            
            # Perform retrieval
            try:
                # Use Agno's knowledge search
                if hasattr(self.knowledge_base, 'search'):
                    results = self.knowledge_base.search(current_query, limit=5)
                else:
                    # Fallback: direct vector search
                    results = []
                
                if results:
                    all_results.extend(results)
                    retrieval_path.append({
                        'hop': hop + 1,
                        'query': current_query,
                        'results_count': len(results)
                    })
                    
                    # Refine query for next hop based on results
                    if hop < self.max_hops - 1:
                        # Extract key information from results to refine query
                        result_texts = [str(r.content if hasattr(r, 'content') else r) for r in results[:3]]
                        current_query = self._refine_query(initial_query, result_texts)
                else:
                    break  # No more results, stop hopping
                    
            except Exception as e:
                logger.error(f"Error in hop {hop + 1}: {e}")
                break
        
        return {
            'results': all_results,
            'retrieval_path': retrieval_path,
            'num_hops': len(retrieval_path)
        }
    
    def _refine_query(self, original_query: str, context_docs: List[str]) -> str:
        """Refine query based on retrieved context"""
        # Simple refinement: add key terms from context
        # In production, you might use an LLM to generate refined queries
        context_text = " ".join(context_docs)[:500]  # Limit context length
        return f"{original_query} {context_text[:200]}"

Each hop refines the search query based on the previous hop's results, letting the system drill deeper into complex topics.

This matters for complex questions. If you ask "What are the latest advances in reinforcement learning for language models?", the first hop might find general RL papers. The second hop might find language-model-specific papers. The third hop might find the most recent advances. Each hop sharpens the search.

Reranking improves relevance. After documents are retrieved, they are reordered using Cohere's rerank API. This uses a model specialized for judging relevance, which works better than raw similarity. The top results after reranking are much more likely to actually be useful.

Here's how the reranker works:

class Reranker:
    """
    Reranking component using Cohere API
    Improves retrieval relevance by reranking initial results
    """
    
    def __init__(self, api_key: Optional[str] = None, model: str = "rerank-english-v3.0"):
        """
        Args:
            api_key: Cohere API key (from env if None)
            model: Reranking model to use
        """
        if not COHERE_AVAILABLE:
            raise ImportError("Cohere package not installed. Install with: pip install cohere")
        
        self.api_key = api_key or os.getenv("COHERE_API_KEY")
        if not self.api_key:
            logger.warning("Cohere API key not found, reranking disabled")
            self.client = None
        else:
            self.client = cohere.Client(self.api_key)
            self.model = model
            logger.info(f"Reranker initialized with model: {model}")
    
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: Optional[int] = None
    ) -> List[Tuple[str, float]]:
        """
        Rerank documents based on query relevance
        
        Args:
            query: Search query
            documents: List of document texts to rerank
            top_n: Number of top results to return (None = all)
            
        Returns:
            List of (document, relevance_score) tuples, sorted by relevance
        """
        if not self.client or not documents:
            return [(doc, 1.0) for doc in documents]
        
        try:
            # Cohere rerank API
            results = self.client.rerank(
                model=self.model,
                query=query,
                documents=documents,
                top_n=top_n or len(documents)
            )
            
            # Extract results
            reranked = [
                (documents[result.index], result.relevance_score)
                for result in results.results
            ]
            
            logger.debug(f"Reranked {len(documents)} documents, returning top {len(reranked)}")
            return reranked
            
        except Exception as e:
            logger.error(f"Error in reranking: {e}")
            # Fallback: return original order
            return [(doc, 1.0) for doc in documents]

The reranker reorders the initial search results by their true relevance to the query rather than just semantic similarity. This noticeably improves the quality of the retrieved information.

The knowledge base is managed through the API. Users can add URLs, delete URLs, and reload the knowledge base, which makes it easy to keep it up to date.
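
The management routes themselves aren't shown in this article. A sketch of what an add-URL endpoint could look like, reusing the global agent_instance and the knowledge.add_content call from the initialization code, is roughly this (the route path and request model are my own naming):

from fastapi import HTTPException
from pydantic import BaseModel

class AddUrlRequest(BaseModel):
    url: str

@app.post("/api/knowledge/urls")
async def add_knowledge_url(request: AddUrlRequest):
    """Hypothetical endpoint: add a URL to the knowledge base at runtime"""
    if not agent_instance or not agent_instance.knowledge:
        raise HTTPException(status_code=503, detail="Knowledge base not available")
    try:
        agent_instance.knowledge.add_content(url=request.url)   # same call used at startup
        agent_instance.knowledge_urls.append(request.url)
        return {"status": "success", "url": request.url}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Could not add URL: {e}")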

When the agent needs information, it searches the knowledge base first. If the knowledge base doesn't have enough, it falls back to web search. That gives you the best of both: fast access to known information and the ability to discover new information.

The RAG system makes the agent knowledgeable. The RL system makes it learn. Together they create an agent that both knows things and learns to use that knowledge better.

Now, here's how the feedback loop ties everything together.

6. The Feedback Loop

The feedback loop is what makes REFLEX self-improving. Without it, the agent is just another chatbot. With it, every interaction makes the agent better.

Here's how it works. The user asks a question. The agent generates a response. The user rates the response with sliders: task success, quality, efficiency, and overall satisfaction. Those ratings become the reward signal.

The reward signal is processed by the provide_feedback method:

def provide_feedback(
    self,
    trajectory: Dict[str, Any],
    reward_signal: RewardSignal,
    learned_skill: Optional[Skill] = None
):
    """Process feedback and update the agent"""

    # Check for a cached trajectory with more detail (e.g. the critic score)
    session_id = trajectory.get('session_id')
    if session_id and session_id in self.pending_trajectories:
        logger.info(f"Using cached trajectory for session {session_id}")
        cached = self.pending_trajectories.pop(session_id)
        # Merge cached data into the provided trajectory (prefer the cached critic_score)
        trajectory.update(cached)
        # Make sure reward_signal carries the critic_score from the trajectory
        reward_signal.critic_score = cached.get('critic_score', 0.0)

    logger.info(f"Processing feedback: task_success={reward_signal.task_success}, critic_score={reward_signal.critic_score}")

    total_reward = reward_signal.compute_total_reward()
    trajectory['reward'] = total_reward
    trajectory['trajectory_id'] = trajectory.get('session_id', f"traj_{datetime.now().timestamp()}")
    trajectory['timestamp'] = datetime.now().isoformat()
    logger.debug(f"Computed total reward: {total_reward:.3f}")

    # Compute the advantage for prioritized replay
    if self.trajectory_buffer.ppo_trainer:
        # Compute the advantage using PPO-style GAE
        single_traj_advantages = self.trajectory_buffer.compute_advantages([trajectory], use_ppo=True)
        if single_traj_advantages:
            trajectory['advantage'] = single_traj_advantages[0]

    # Add to the trajectory buffer (persisted to the database if enabled)
    self.trajectory_buffer.add_trajectory(trajectory)

    # Update the skill library
    if learned_skill:
        self.skill_library.add_skill(learned_skill)

    # Update statistics for the skills that were used
    for skill_name in trajectory.get('relevant_skills', []):
        success = reward_signal.task_success > 0.7
        self.skill_library.update_skill_stats(skill_name, total_reward, success)

    # Update training statistics
    self.training_stats['total_tasks'] += 1
    if reward_signal.task_success > 0.7:
        self.training_stats['successful_tasks'] += 1

    old_avg = self.training_stats['average_reward']
    n = self.training_stats['total_tasks']
    self.training_stats['average_reward'] = (old_avg * (n - 1) + total_reward) / n

    # Compute the improvement rate
    if n > 10:
        # Get recent trajectories from the buffer
        recent_trajectories = self.trajectory_buffer.get_batch(min(10, len(self.trajectory_buffer)))
        recent_rewards = [t.get('reward', 0.0) for t in recent_trajectories]
        if recent_rewards:
            self.training_stats['improvement_rate'] = np.mean(recent_rewards) - old_avg

    # Update the adaptive reward weights based on performance
    if self.adaptive_reward_weights:
        component_performances = {
            'task_success': reward_signal.task_success,
            'quality_score': reward_signal.quality_score,
            'efficiency_score': reward_signal.efficiency_score,
            'user_feedback': max(0, reward_signal.user_feedback),  # Normalize to 0-1
            'critic_score': reward_signal.critic_score
        }
        self.adaptive_reward_weights.update_weights(total_reward, component_performances)
        # Save the weights periodically
        if n % 50 == 0:
            reward_weights_path = os.path.join(os.getcwd(), "data", "reward_weights.json")
            self.adaptive_reward_weights.save(reward_weights_path)

    # Recreate the agent every 10 tasks
    if self.training_stats['total_tasks'] % 10 == 0:
        logger.info(f"Recreating agent after {self.training_stats['total_tasks']} tasks")
        self.agent = self._create_agent()

This method does several important things. It computes the total reward from the feedback components. It adds the trajectory to the buffer. It updates skill statistics. It updates training statistics. And every 10 tasks, it recreates the agent with the updated skills.

The skill library tracks which skills were used in each trajectory. When feedback arrives, those skills get updated. If the feedback is positive (task success > 0.7), the skill's success rate goes up. If it's negative, the success rate goes down.

This creates a natural selection process. Good skills get used more because of their higher success rates. Bad skills get used less because of their lower success rates. Over time, the agent's skill set evolves toward better approaches.

Users can also create new skills manually. If a response was particularly good, they can tick a box and describe what made it good. That becomes a skill the agent can use in the future.

The feedback endpoint in the API handles this:

@app.post("/api/feedback")
async def submit_feedback(request: FeedbackRequest):
    """
    Submit feedback for a task to train the agent
    """
    logger.info(f"Received feedback: session_id={request.session_id}, task_success={request.task_success}, quality={request.quality_score}")
    
    if not agent_instance:
        logger.error("Agent not initialized when feedback received")
        raise HTTPException(status_code=503, detail="Agent not initialized")
    
    try:
        # Create reward signal with adaptive weights if available
        reward_signal = RewardSignal(
            task_success=request.task_success,
            quality_score=request.quality_score,
            efficiency_score=request.efficiency_score,
            user_feedback=request.user_feedback,
            adaptive_weights=getattr(agent_instance, 'adaptive_reward_weights', None)
        )
        
        # Create skill if provided
        learned_skill = None
        if request.learned_skill:
            learned_skill = Skill(
                name=request.learned_skill.name,
                description=request.learned_skill.description,
                context=request.learned_skill.context,
                success_rate=request.learned_skill.success_rate,
                usage_count=1,
                average_reward=reward_signal.compute_total_reward()
            )
        
        # Get trajectory (simplified - in real implementation, store trajectories)
        trajectory = {
            "query": "feedback_submission",
            "session_id": request.session_id,
            "relevant_skills": []
        }
        
        # Provide feedback to agent
        logger.debug(f"Processing feedback: reward={reward_signal.compute_total_reward()}, skill_added={learned_skill is not None}")
        agent_instance.provide_feedback(
            trajectory=trajectory,
            reward_signal=reward_signal,
            learned_skill=learned_skill
        )
        
        logger.info(f"Feedback processed successfully: total_reward={reward_signal.compute_total_reward()}")
        return {
            "status": "success",
            "message": "Feedback processed successfully",
            "total_reward": reward_signal.compute_total_reward(),
            "skill_added": learned_skill is not None
        }
        
    except Exception as e:
        logger.error(f"Error in feedback endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Feedback error: {str(e)}")

The endpoint receives the feedback, creates a reward signal, optionally creates a new skill, and calls provide_feedback. The agent then updates its internal state.

The key insight is that feedback doesn't disappear. It becomes part of the agent's memory. It influences which skills get used. It shapes how the agent approaches future questions.

That's the feedback loop. The user asks. The agent answers. The user gives feedback. The agent learns. The agent gets better. And the loop continues.

This loop is what makes REFLEX self-improving. Without it, the agent would be static. With it, the agent evolves. Every conversation makes it a little smarter.

Next, let's look at how streaming makes the experience feel instant.

7. Real-Time Streaming

Nobody likes waiting. When you ask a question, you want the answer to start appearing immediately. That's why REFLEX streams responses using Server-Sent Events (SSE).

Instead of waiting for the entire response to be generated, the agent streams it token by token. The user watches the answer appear in real time, like watching someone type.

Here's how the streaming endpoint works:

@app.post("/api/chat/stream")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def chat_stream(request: Request, chat_request: ChatRequest):
    """
    Stream agent response using Server-Sent Events (SSE)
    Uses Agno's streaming capabilities for real-time response generation
    """
    # Validate input
    if not chat_request.message or not chat_request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    
    if len(chat_request.message) > settings.max_message_length:
        raise HTTPException(
            status_code=400,
            detail=f"Message exceeds maximum length of {settings.max_message_length} characters"
        )
    
    # Validate session and user IDs
    if chat_request.session_id and not SecurityValidator.validate_session_id(chat_request.session_id):
        raise HTTPException(status_code=400, detail="Invalid session ID format")
    
    if chat_request.user_id and not SecurityValidator.validate_user_id(chat_request.user_id):
        raise HTTPException(status_code=400, detail="Invalid user ID format")
    
    logger.info(f"Received stream request: session_id={chat_request.session_id}, message_length={len(chat_request.message)}")
    
    if not agent_instance:
        logger.error("Agent not initialized when stream request received")
        raise HTTPException(status_code=503, detail="Agent not initialized")
    
    session_id = chat_request.session_id or "default"
    user_id = chat_request.user_id or "default_user"
    
    async def event_generator():
        full_response_content = ""
        tools_used = []
        relevant_skills = []
        run_response = None
        
        try:
            # Store user message
            if conversation_db:
                conversation_db.save_message(
                    session_id=session_id,
                    user_id=user_id,
                    role="user",
                    content=chat_request.message
                )
            
            # Send initial status
            yield f"data: {json.dumps({'type': 'status', 'status': 'thinking', 'message': 'Analyzing your question...'})}\n\n"
            await asyncio.sleep(0.05)
            
            yield f"data: {json.dumps({'type': 'status', 'status': 'searching', 'message': 'Searching for information...'})}\n\n"
            await asyncio.sleep(0.05)
            
            logger.debug(f"Starting real stream for message: {chat_request.message[:100]}...")
            
            # Use the agent's streaming method
            async for chunk in agent_instance.run_task_stream(
                query=chat_request.message,
                session_id=session_id,
                user_id=user_id
            ):
                if chunk.get('type') == 'status':
                    # Stream status update (tool calls, searching, etc.)
                    status_data = {
                        'type': 'status',
                        'status': chunk.get('status', 'thinking'),
                        'message': chunk.get('message', 'Processing...')
                    }
                    yield f"data: {json.dumps(status_data)}\n\n"
                    
                elif chunk.get('type') == 'content':
                    # Stream content chunk
                    full_response_content += chunk.get('content', '')
                    chunk_data = {
                        'type': 'content',
                        'content': chunk.get('content', ''),
                        'done': False
                    }
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                    
                elif chunk.get('type') == 'done':
                    # Final chunk with metadata
                    full_response_content = chunk.get('accumulated', full_response_content)
                    run_response = chunk.get('full_response')
                    tools_used = chunk.get('tools_used', [])
                    relevant_skills = chunk.get('relevant_skills', [])
                    sources = chunk.get('sources', [])
                    critic_score = chunk.get('critic_score', 0.0)
                    
                    # Send final content update
                    done_data = {
                        'type': 'done',
                        'content': '',
                        'done': True,
                        'tools_used': tools_used,
                        'relevant_skills': relevant_skills,
                        'sources': sources,
                        'critic_score': critic_score
                    }
                    yield f"data: {json.dumps(done_data)}\n\n"
                    
                elif chunk.get('type') == 'error':
                    # Error occurred
                    error_msg = chunk.get('error', 'Unknown error')
                    error_data = {
                        'type': 'error',
                        'error': error_msg,
                        'done': True
                    }
                    yield f"data: {json.dumps(error_data)}\n\n"
                    return
            
            # Store agent response after streaming completes
            if conversation_db and full_response_content:
                conversation_db.save_message(
                    session_id=session_id,
                    user_id=user_id,
                    role="agent",
                    content=full_response_content,
                    tools_used=tools_used,
                    skills_applied=relevant_skills,
                    trajectory_id=session_id
                )
                logger.info(f"Stored agent response: length={len(full_response_content)}")
            
        except Exception as e:
            logger.error(f"Error in stream generator: {str(e)}", exc_info=True)
            error_data = {
                "type": "error",
                "error": str(e),
                "done": True
            }
            yield f"data: {json.dumps(error_data)}\n\n"
    
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        }
    )
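
On the agent side, the endpoint delegates to run_task_stream, which wraps Agno's native async streaming: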

该端点创建一个异步生成器来产生事件。每个事件都是带有 type 字段的 JSON 对象:状态事件显示代理正在做什么(“分析您的问题…”、“查找信息…”),内容事件携带响应文本块,完成事件表示流结束并附带元数据。
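
为了直观感受客户端实际收到的内容,下面是一个用 Python 消费该流式端点的小示例(仅作演示;假设后端运行在 http://localhost:8000,路由为 /api/chat/stream,请求体字段与前端保持一致):

import json
import requests  # pip install requests

resp = requests.post(
    "http://localhost:8000/api/chat/stream",
    json={
        "message": "什么是强化学习?",
        "session_id": "demo_session",   # 示例会话 ID
        "user_id": "demo_user",         # 示例用户 ID
    },
    stream=True,
    timeout=120,
)

for line in resp.iter_lines(decode_unicode=True):
    # SSE 帧形如 "data: {...}",帧与帧之间由空行分隔
    if not line or not line.startswith("data: "):
        continue
    event = json.loads(line[len("data: "):])
    if event.get("type") == "status":
        print(f"[状态] {event.get('message')}")
    elif event.get("type") == "content":
        print(event.get("content"), end="", flush=True)
    elif event.get("type") == "done":
        print(f"\n[完成] 使用的工具: {event.get('tools_used')}")
        break
    elif event.get("type") == "error":
        print(f"\n[错误] {event.get('error')}")
        break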

代理的 run_task_stream 方法处理实际的流式传输:

    async def run_task_stream(
        self,
        query: str,
        session_id: Optional[str] = None,
        user_id: Optional[str] = None
    ):
        """
        Execute a research task with streaming response using Agno's native async streaming
        Yields chunks of the response as they are generated
        
        Uses agent.arun(stream=True) which returns AsyncIterator[RunOutputEvent]
        """
        logger.info(f"Running streaming task: session_id={session_id}, query_length={len(query)}")
        
        # Get relevant skills
        relevant_skills = self.skill_library.get_relevant_skills(query)
        logger.debug(f"Retrieved {len(relevant_skills)} relevant skills")
        
        # Add skill context if relevant
        if relevant_skills:
            skill_text = "\n\nRelevant learned approaches:\n"
            for skill in relevant_skills:
                skill_text += f"- {skill.name}: {skill.context}\n"
            enhanced_query = query + skill_text
        else:
            enhanced_query = query
            
        # Add current datetime context
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        enhanced_query = f"Current Date and Time: {current_time}\n\n{enhanced_query}"
        
        try:
            accumulated_content = ""
            run_response = None
            tools_used = []
            
            logger.debug("Starting Agno native async streaming with arun()...")
            
            # Use Agno's native async streaming with arun(stream=True)
            # This returns an AsyncIterator[RunOutputEvent]
            async for event in self.agent.arun(
                enhanced_query,
                session_id=session_id,
                user_id=user_id,
                stream=True
            ):
                # Process RunOutputEvent objects
                # Each event has an 'event' attribute and possibly 'content'
                
                try:
                    event_type = getattr(event, 'event', None)
                    
                    # Handle different event types
                    if event_type == 'RunContent':
                        # This is actual content to stream
                        content = getattr(event, 'content', None)
                        
                        # Skip reasoning content (internal thinking)
                        if hasattr(event, 'reasoning_content') and event.reasoning_content:
                            continue
                        
                        if content and isinstance(content, str) and content.strip():
                            # Filter out tool execution messages
                            if 'ToolExecution' in content or 'tool_call_id' in content:
                                continue
                            
                            accumulated_content += content
                            yield {
                                'type': 'content',
                                'content': content,
                                'done': False
                            }
                    
                    elif event_type == 'ToolCallStarted':
                        # Tool is being called
                        tool_name = getattr(event, 'tool_name', 'tool')
                        
                        # Generate appropriate status message
                        if 'duckduckgo' in tool_name.lower() or 'search' in tool_name.lower():
                            status_msg = "Searching the web for information..."
                        elif 'knowledge' in tool_name.lower():
                            status_msg = "Searching knowledge base..."
                        else:
                            status_msg = f"Using {tool_name}..."
                        
                        yield {
                            'type': 'status',
                            'status': 'tool_use',
                            'message': status_msg
                        }
                    
                    elif event_type == 'ToolCallCompleted':
                        # Tool call finished
                        tool_name = getattr(event, 'tool_name', 'tool')
                        if tool_name not in tools_used:
                            tools_used.append(tool_name)
                    
                    elif event_type == 'RunOutput':
                        # Final output event
                        run_response = event
                        
                        # Get final content if available
                        final_content = getattr(event, 'content', None)
                        if final_content and isinstance(final_content, str):
                            # Only add if significantly different from accumulated
                            if len(final_content) > len(accumulated_content):
                                diff = final_content[len(accumulated_content):]
                                if diff.strip():
                                    accumulated_content = final_content
                                    yield {
                                        'type': 'content',
                                        'content': diff,
                                        'done': False
                                    }
                        
                        break  # Final event, exit loop
                
                except Exception as e:
                    logger.warning(f"Error processing event: {e}")
                    continue
            
            # If we didn't get a RunOutput event, get the full response
            if not run_response:
                logger.debug("No RunOutput event received, fetching full response...")
                try:
                    run_response = await self.agent.arun(
                        enhanced_query,
                        session_id=session_id,
                        user_id=user_id
                    )
                except Exception as e:
                    logger.warning(f"Could not get full response: {e}")

            # Run Critic evaluation (async safe wrapper)
            critic_score = 0.0
            if self.critic:
                try:
                    logger.debug("Running critic evaluation...")
                    # Run in executor to avoid blocking async loop
                    loop = asyncio.get_running_loop()
                    critic_score = await loop.run_in_executor(
                        None, 
                        lambda: self.critic.evaluate(query, accumulated_content)
                    )
                    logger.info(f"Critic score: {critic_score}")
                except Exception as e:
                    logger.error(f"Error running critic: {e}")
            
            # Extract sources and metadata from final response
            sources = []
            if run_response:
                sources = self._extract_sources(run_response)
                
                # Extract tools used from response if not already captured
                if hasattr(run_response, 'tool_calls') and run_response.tool_calls:
                    for tool_call in run_response.tool_calls:
                        tool_name = getattr(tool_call, 'name', str(tool_call))
                        if tool_name not in tools_used:
                            tools_used.append(tool_name)
            
            # Cache trajectory
            if session_id:
                trajectory = {
                    'query': query,
                    'response': accumulated_content,
                    'tools_used': tools_used,
                    'relevant_skills': [s.name for s in relevant_skills],
                    'sources': sources,
                    'critic_score': critic_score,
                    'session_id': session_id,
                    'user_id': user_id
                }
                logger.debug(f"Caching trajectory for session {session_id}")
                self.pending_trajectories[session_id] = trajectory

            # Send final done event with metadata
            yield {
                'type': 'done',
                'content': '',
                'done': True,
                'accumulated': accumulated_content,
                'tools_used': tools_used,
                'relevant_skills': [s.name for s in relevant_skills],
                'sources': sources,
                'full_response': run_response,
                'critic_score': critic_score
            }
        
        except Exception as e:
            logger.error(f"Error in streaming task: {str(e)}", exc_info=True)
            yield {
                'type': 'error',
                'content': f'Error: {str(e)}',
                'done': True,
                'error': str(e)
            }

该方法使用 Agno 的原生异步流式传输。它遍历代理的事件,过滤掉内部推理,并生成内容块。它还会在使用工具时生成状态更新,让用户知道发生了什么。
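
把这两段代码串起来看:端点里的生成器本质上就是把 run_task_stream 产出的字典块包装成 SSE 帧。下面是一个去掉业务细节的简化示意(用一个假的异步生成器代替真实代理,仅用于说明数据是如何流动的):

import asyncio
import json

async def fake_run_task_stream(query: str):
    """用固定数据模拟 run_task_stream 产出的块(仅作演示)"""
    yield {"type": "status", "status": "tool_use", "message": "正在网络上搜索信息..."}
    for piece in ["强化学习是", "一种通过试错来学习的范式。"]:
        yield {"type": "content", "content": piece, "done": False}
    yield {"type": "done", "content": "", "done": True, "tools_used": ["duckduckgo_search"]}

async def event_generator(query: str):
    """把每个块格式化为一条 SSE 帧(data: <json> 后跟空行)"""
    async for chunk in fake_run_task_stream(query):
        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
        if chunk.get("done"):
            break

async def main():
    async for frame in event_generator("什么是强化学习?"):
        print(frame, end="")

asyncio.run(main())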

前端处理这些事件并实时更新 UI。用户看到响应逐字出现,这感觉更快。

流式传输还提供状态更新。用户先看到“分析您的问题…”,再看到“查找信息…”,最后才是实际的响应。这种透明度让等待感觉更短。

体验与非流式传输完全不同。用户不必盯着加载指示器干等十秒,而是立刻看到进度,响应在一两秒内就开始出现。

这就是流式传输的重要性。它不仅仅是速度。它是感知速度。用户觉得系统是响应的,即使总时间相同。

8、前端架构

前端是纯 HTML、CSS 和 JavaScript。没有框架。没有构建步骤。只是原生 JS,可以在任何地方工作。

我选择这种方法是为了简单。前端不需要 React 或 Vue。它只需要显示消息,处理流式传输,并收集反馈。原生 JS 完美地做到了这一点。

以下是前端如何处理流式传输:

// Send message to agent with streaming
async function sendMessage() {
    const message = elements.messageInput.value.trim();
    if (!message) return;

    // Disable input
    elements.sendBtn.disabled = true;
    elements.messageInput.disabled = true;

    // Add user message
    addMessage('user', message);

    // Clear input
    elements.messageInput.value = '';
    elements.messageInput.style.height = 'auto';

    // Show loading with activity updates
    const loadingId = addLoadingMessage('Analyzing your question...');

    // Create agent message container for streaming
    let agentMessageId = null;
    let accumulatedContent = '';
    let toolsUsed = [];
    let relevantSkills = [];

    try {
        const response = await fetch(`${API_BASE}/chat/stream`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                message,
                session_id: state.sessionId,
                user_id: state.userId
            })
        });

        if (!response.ok) throw new Error('Failed to get response');

        // Handle streaming response
        const reader = response.body.getReader();
        const decoder = new TextDecoder();

        while (true) {
            const { done, value } = await reader.read();
            if (done) break;

            const chunk = decoder.decode(value, { stream: true });
            const lines = chunk.split('\n');

            for (const line of lines) {
                if (line.startsWith('data: ')) {
                    try {
                        const data = JSON.parse(line.slice(6));

                        if (data.type === 'status') {
                            // If agent message doesn't exist yet, create it
                            if (!agentMessageId) {
                                removeLoadingMessage(loadingId);
                                agentMessageId = addStreamingMessage('agent', '');
                            }
                            addThinkingStep(agentMessageId, data.message);
                        } else if (data.type === 'content') {
                            // Remove loading message on first content
                            if (!agentMessageId) {
                                removeLoadingMessage(loadingId);
                                agentMessageId = addStreamingMessage('agent', '');
                            }

                            // Accumulate content
                            accumulatedContent += data.content;

                            // Update streaming message
                            updateStreamingMessage(agentMessageId, accumulatedContent);
                        } else if (data.type === 'done') {
                            if (agentMessageId) {
                                finalizeThinkingProcess(agentMessageId);
                            }
                            // Finalize message
                            toolsUsed = data.tools_used || [];
                            relevantSkills = data.relevant_skills || [];
                            const sources = data.sources || [];

                            // Use accumulated content from done event if available, otherwise use our accumulated
                            const finalContent = data.accumulated || accumulatedContent;

                            // Replace streaming message with final message
                            if (agentMessageId) {
                                replaceStreamingMessage(agentMessageId, finalContent, {
                                    tools: toolsUsed,
                                    skills: relevantSkills,
                                    sources: sources,
                                    critic_score: data.critic_score
                                });
                            } else {
                                // Fallback: create message if streaming didn't work
                                removeLoadingMessage(loadingId);
                                addMessage('agent', finalContent, {
                                    tools: toolsUsed,
                                    skills: relevantSkills,
                                    sources: sources,
                                    critic_score: data.critic_score
                                });
                            }

                            // Update trajectory
                            state.lastTrajectory = {
                                message: accumulatedContent,
                                tools_used: toolsUsed,
                                relevant_skills: relevantSkills,
                                sources: sources,
                                critic_score: data.critic_score
                            };
                            updateTrajectoryInfo(state.lastTrajectory);

                            // Update message count
                            state.messageCount++;
                            elements.messageCountDisplay.textContent = state.messageCount;

                            // Update conversation title if this is the first message
                            if (state.messageCount === 1) {
                                updateConversationTitle(message);
                            }

                            // Update conversation in sidebar
                            updateConversationInSidebar();

                            // Update stats
                            updateStats();
                        } else if (data.type === 'error') {
                            // Handle error
                            removeLoadingMessage(loadingId);
                            if (agentMessageId) {
                                replaceStreamingMessage(agentMessageId, `Error: ${data.error}`, { error: true });
                            } else {
                                addMessage('agent', `Error: ${data.error}`, { error: true });
                            }
                        }
                    } catch (parseError) {
                        console.warn('Failed to parse SSE data:', parseError, line);
                    }
                }
            }
        }

    } catch (error) {
        console.error('Error:', error);
        removeLoadingMessage(loadingId);
        if (agentMessageId) {
            replaceStreamingMessage(agentMessageId, 'Sorry, I encountered an error. Please try again.', { error: true });
        } else {
            addMessage('agent', 'Sorry, I encountered an error. Please try again.', { error: true });
        }
    } finally {
        elements.sendBtn.disabled = false;
        elements.messageInput.disabled = false;
        elements.messageInput.focus();
    }
}

该函数逐块读取流,解析每个事件,并相应地更新 UI:状态事件显示思考步骤,内容事件追加消息文本,完成事件则附带元数据,把流式消息替换为最终消息。

反馈收集很简单:

// Submit feedback
async function submitFeedback() {
    if (!state.lastTrajectory) {
        showFeedbackStatus('Please send a message first', 'error');
        return;
    }

    const feedbackData = {
        session_id: state.sessionId,
        task_success: parseFloat(elements.taskSuccess.value),
        quality_score: parseFloat(elements.quality.value),
        efficiency_score: parseFloat(elements.efficiency.value),
        user_feedback: parseFloat(elements.userFeedback.value)
    };

    // Add skill if checkbox is checked
    if (elements.createSkill.checked) {
        const skillName = document.getElementById('skillName').value.trim();
        const skillDesc = document.getElementById('skillDescription').value.trim();
        const skillContext = document.getElementById('skillContext').value.trim();

        if (!skillName || !skillDesc || !skillContext) {
            showFeedbackStatus('Please fill in all skill fields', 'error');
            return;
        }

        feedbackData.learned_skill = {
            name: skillName,
            description: skillDesc,
            context: skillContext,
            success_rate: feedbackData.task_success
        };
    }

    try {
        const response = await fetch(`${API_BASE}/feedback`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(feedbackData)
        });

        if (!response.ok) throw new Error('Failed to submit feedback');

        const result = await response.json();

        showFeedbackStatus(
            `Feedback submitted! Reward: ${result.total_reward.toFixed(2)}`,
            'success'
        );

        // Reset form
        elements.taskSuccess.value = 0.5;
        elements.quality.value = 0.5;
        elements.efficiency.value = 0.5;
        elements.userFeedback.value = 0;
        elements.createSkill.checked = false;
        elements.skillForm.style.display = 'none';

        // Update values
        elements.taskSuccessValue.textContent = '0.5';
        elements.qualityValue.textContent = '0.5';
        elements.efficiencyValue.textContent = '0.5';
        elements.userFeedbackValue.textContent = '0.0';

        // Update stats
        updateStats();

    } catch (error) {
        console.error('Error submitting feedback:', error);
        showFeedbackStatus('Failed to submit feedback', 'error');
    }
}

用户使用滑块对响应进行评分,可选地创建一个技能,然后提交。反馈被发送到后端,后端处理它并更新代理。
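
在浏览器之外也可以直接调用反馈接口,方便写脚本批量测试。下面是一个 Python 示例(字段与上面前端发送的请求体一致;URL、会话 ID 和各项评分都是示例值):

import requests  # pip install requests

API_BASE = "http://localhost:8000/api"  # 假设后端运行在本地

feedback = {
    "session_id": "demo_session",   # 与发起对话时使用的 session_id 一致
    "task_success": 0.9,            # 示例评分,实际取值范围以前端滑块为准
    "quality_score": 0.8,
    "efficiency_score": 0.7,
    "user_feedback": 0.5,
    # 可选:把这次有效的做法沉淀为一个技能
    "learned_skill": {
        "name": "先查知识库再搜网络",
        "description": "对概念类问题优先检索知识库,必要时再补充网络搜索",
        "context": "用户询问定义或原理类问题时",
        "success_rate": 0.9,
    },
}

resp = requests.post(f"{API_BASE}/feedback", json=feedback, timeout=30)
resp.raise_for_status()
print("总奖励:", resp.json().get("total_reward"))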

统计仪表板显示实时指标:

// Update stats
async function updateStats() {
    try {
        const response = await fetch(`${API_BASE}/stats`);
        if (!response.ok) throw new Error('Failed to fetch stats');

        const stats = await response.json();

        elements.totalTasks.textContent = stats.total_tasks;
        elements.skillCount.textContent = stats.skill_count;
        elements.avgReward.textContent = stats.average_reward.toFixed(2);

        const successRate = stats.total_tasks > 0
            ? ((stats.successful_tasks / stats.total_tasks) * 100).toFixed(1)
            : 0;
        elements.successRate.textContent = `${successRate}%`;

        // Update top skills
        if (stats.top_skills && stats.top_skills.length > 0) {
            elements.topSkillsList.innerHTML = stats.top_skills.map(skill => `
                <div class="skill-item">
                    <div class="skill-name">${skill.name}</div>
                    <div class="skill-meta">
                        <span>${(skill.success_rate * 100).toFixed(0)}% success</span>
                        <span>${skill.usage} uses</span>
                    </div>
                </div>
            `).join('');
        } else {
            elements.topSkillsList.innerHTML = '<p class="empty-state">No skills learned yet</p>';
        }

    } catch (error) {
        console.error('Error updating stats:', error);
    }
}

仪表板显示总任务数、成功率、平均奖励、技能数量以及表现最好的技能,让用户能直观地看到代理是如何随时间改进的。

UI 简洁而实用:专注于对话,让反馈易于提交,并随时展示进度,没有不必要的复杂性,只保留需要的东西。

前端是用户与代理交互的地方,也是“魔法”发生的地方:流式响应、反馈收集、进度展示,全都在这里完成。

接下来,让我们谈谈如何使其适用于生产环境。

10、生产部署

构建一个能正常工作的系统是一回事。构建一个能在生产环境中正常工作的系统是另一回事。以下是我为使 REFLEX 适用于生产所做的工作。

安全性排在首位。用户输入经过验证和清理:

class SecurityValidator:
    """Security validation utilities"""
    
    @staticmethod
    def sanitize_input(text: str, max_length: int = 10000) -> str:
        """
        Sanitize user input to prevent injection attacks
        
        Args:
            text: Input text to sanitize
            max_length: Maximum allowed length
            
        Returns:
            Sanitized text
            
        Raises:
            ValueError: If input exceeds max length
        """
        # Remove null bytes
        text = text.replace('\x00', '')
        
        # Limit length
        if len(text) > max_length:
            raise ValueError(f"Input exceeds maximum length of {max_length}")
        
        # HTML escape for safety (backend sanitization)
        # Note: Frontend should also sanitize with DOMPurify
        text = html.escape(text)
        
        return text.strip()
    
    @staticmethod
    def validate_url(url: str) -> bool:
        """
        Validate URL format and safety
        
        Args:
            url: URL to validate
            
        Returns:
            True if URL is valid and safe, False otherwise
        """
        # Check basic format
        if not url.startswith(('http://', 'https://')):
            logger.warning(f"Invalid URL protocol: {url}")
            return False
        
        # Block local/private addresses to prevent SSRF
        blocked_patterns = [
            r'localhost',
            r'127\.0\.0\.',
            r'192\.168\.',
            r'10\.',
            r'172\.(1[6-9]|2[0-9]|3[01])\.',
            r'::1',
            r'0\.0\.0\.0'
        ]
        
        for pattern in blocked_patterns:
            if re.search(pattern, url, re.IGNORECASE):
                logger.warning(f"Blocked private/local URL: {url}")
                return False
        
        # Check URL length
        if len(url) > 2048:
            logger.warning(f"URL too long: {len(url)} chars")
            return False
        
        return True
    
    @staticmethod
    def validate_session_id(session_id: str) -> bool:
        """
        Validate session ID format
        
        Args:
            session_id: Session ID to validate
            
        Returns:
            True if valid, False otherwise
        """
        # Session IDs should be alphanumeric with underscores/hyphens
        if not re.match(r'^[a-zA-Z0-9_-]+$', session_id):
            logger.warning(f"Invalid session ID format: {session_id}")
            return False
        
        # Reasonable length limits
        if len(session_id) < 3 or len(session_id) > 128:
            logger.warning(f"Invalid session ID length: {len(session_id)}")
            return False
        
        return True
    
    @staticmethod
    def validate_user_id(user_id: str) -> bool:
        """
        Validate user ID format
        
        Args:
            user_id: User ID to validate
            
        Returns:
            True if valid, False otherwise
        """
        # Similar to session ID validation
        if not re.match(r'^[a-zA-Z0-9_-]+$', user_id):
            logger.warning(f"Invalid user ID format: {user_id}")
            return False
        
        if len(user_id) < 3 or len(user_id) > 128:
            logger.warning(f"Invalid user ID length: {len(user_id)}")
            return False
        
        return True

用户输入经过清理以防止注入攻击;URL 经过校验以防止 SSRF(服务器端请求伪造);会话 ID 和用户 ID 也会做格式校验,拒绝不合法的标识符。
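
下面是 SecurityValidator 的一个简单用法示例(假设该类已按上文定义并从后端模块导入;注释中的输出仅为说明预期行为):

# 假设 SecurityValidator 已按上文定义并可导入
clean = SecurityValidator.sanitize_input("<script>alert('x')</script> 请总结这篇论文")
print(clean)  # HTML 已被转义,脚本无法注入

print(SecurityValidator.validate_url("https://example.com/paper.pdf"))  # True
print(SecurityValidator.validate_url("http://127.0.0.1:8000/admin"))    # False,阻止 SSRF
print(SecurityValidator.validate_session_id("session_abc-123"))         # True
print(SecurityValidator.validate_session_id("bad id!"))                 # False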

速率限制防止滥用:

# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

速率限制应用在聊天端点上,防止单个用户压垮系统。默认限制为每分钟 10 次请求,可以通过配置调整。
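
把限流应用到具体端点大致如下(一个简化示意,按 slowapi 的标准用法;"10/minute" 对应上面提到的默认值,真正的流式逻辑被省略):

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

app = FastAPI()
limiter = Limiter(key_func=get_remote_address)  # 按调用方 IP 计数
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/api/chat/stream")
@limiter.limit("10/minute")  # 超出限制时返回 429
async def chat_stream(request: Request):
    # 这里省略了真正的流式逻辑,仅演示限流装饰器的用法
    return {"ok": True}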

CORS 被配置为只允许受信任的源:

# Add CORS middleware with security
logger.info(f"CORS origins: {settings.cors_origins_list}")
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE", "PUT"],
    allow_headers=["Content-Type", "Authorization"],
)

这防止了未经授权的域名访问 API。在生产中,您会将其设置为实际的前端域名。

性能方面做了多项优化。轨迹缓冲区设有上限(10,000 条轨迹)以避免内存问题;轨迹会写入数据库做持久化,内存中只保留最近的一部分。
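
“只保留最近 N 条”这种有界缓冲用很简单的结构就能实现。下面是一个示意(并非 REFLEX 的实际实现,只用来说明思路):

from collections import deque

MAX_TRAJECTORIES = 10_000

# deque(maxlen=N) 在超出容量时会自动丢弃最旧的元素
trajectory_buffer = deque(maxlen=MAX_TRAJECTORIES)

def add_trajectory(trajectory: dict) -> None:
    """新轨迹入队;缓冲区满时最旧的轨迹被挤出(持久化仍由数据库负责)"""
    trajectory_buffer.append(trajectory)

for i in range(12_000):
    add_trajectory({"session_id": f"s{i}", "critic_score": 0.5})

print(len(trajectory_buffer))              # 10000
print(trajectory_buffer[0]["session_id"])  # s2000,最旧的 2000 条已被淘汰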

数据库查询通过索引加快速度:

        # 创建索引以加快查询  
        cursor.execute("""  
            CREATE INDEX IF NOT EXISTS idx_session_id   
            ON messages(session_id)  
        """)  

        cursor.execute("""  
            CREATE INDEX IF NOT EXISTS idx_timestamp   
            ON messages(timestamp)  
        """)

有了 session_id 和 timestamp 上的索引,即使已经存储了数千条消息,检索对话历史依然很快。
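
有了这些索引,按会话拉取历史消息就是一次简单的范围查询。下面是一个示意(数据库路径和列名是根据上文推测的,以实际实现为准):

import sqlite3

conn = sqlite3.connect("data/db/conversations.db")  # 路径为假设值
cursor = conn.cursor()

# 按 session_id 过滤、按 timestamp 排序,两者都能命中上面创建的索引
cursor.execute(
    """
    SELECT role, content, timestamp
    FROM messages
    WHERE session_id = ?
    ORDER BY timestamp ASC
    LIMIT 100
    """,
    ("demo_session",),
)
for role, content, timestamp in cursor.fetchall():
    print(f"[{timestamp}] {role}: {content[:80]}")

conn.close()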

部署通过 Docker 进行:

services:
  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    container_name: reflex-backend
    ports:
      - "8000:8000"
    environment:
      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - COHERE_API_KEY=${COHERE_API_KEY:-}
    volumes:
      - ./data:/app/data
      - ./backend:/app
    env_file:
      - .env
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/api/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s
    networks:
      - reflex-network

  frontend:
    build:
      context: ./frontend
      dockerfile: Dockerfile
    container_name: reflex-frontend
    ports:
      - "3000:80"
    depends_on:
      - backend
    environment:
      - API_BASE_URL=http://localhost:8000/api
    volumes:
      - ./frontend:/usr/share/nginx/html
    restart: unless-stopped
    networks:
      - reflex-network

networks:
  reflex-network:
    driver: bridge

volumes:
  data:
    driver: local

Docker Compose 定义了这两个服务,并负责网络和卷的管理。健康检查确保后端真正可用之后才会被标记为健康。

环境变量从 .env 文件加载,将秘密保留在代码之外。前端使用 Nginx 高效地提供静态文件。

内置了监控功能。统计信息端点提供指标:

@app.get("/api/stats", response_model=AgentStats)
async def get_stats():
    """
    Get training statistics and agent metrics
    """
    if not agent_instance:
        raise HTTPException(status_code=503, detail="Agent not initialized")
    
    try:
        stats = agent_instance.get_stats()
        
        return AgentStats(
            total_tasks=stats['training_stats']['total_tasks'],
            successful_tasks=stats['training_stats']['successful_tasks'],
            average_reward=stats['training_stats']['average_reward'],
            improvement_rate=stats['training_stats']['improvement_rate'],
            skill_count=stats['skill_count'],
            trajectory_count=stats['trajectory_count'],
            top_skills=stats['top_skills']
        )
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Stats error: {str(e)}")

这些指标可以导出到 Prometheus 或 Datadog 等监控系统。目前,它们在 UI 中显示。
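
如果要接入 Prometheus,一个可能的做法是用 prometheus_client 把这些数字暴露为指标(下面只是思路示意,并非仓库中已有的代码;指标名称是随意取的):

from prometheus_client import Gauge, make_asgi_app

# 把统计信息映射为 Prometheus Gauge
TOTAL_TASKS = Gauge("reflex_total_tasks", "Total tasks processed")
AVG_REWARD = Gauge("reflex_average_reward", "Average reward across trajectories")
SKILL_COUNT = Gauge("reflex_skill_count", "Number of learned skills")

def export_stats(stats: dict) -> None:
    """用 agent_instance.get_stats() 返回的结构更新指标"""
    TOTAL_TASKS.set(stats["training_stats"]["total_tasks"])
    AVG_REWARD.set(stats["training_stats"]["average_reward"])
    SKILL_COUNT.set(stats["skill_count"])

# 在 FastAPI 应用上挂载 /metrics 端点,供 Prometheus 抓取
# app.mount("/metrics", make_asgi_app())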

错误处理覆盖了每一条路径:try-except 捕获异常、记录日志,并返回合适的 HTTP 状态码。用户永远不会看到堆栈跟踪,只会看到友好的错误消息。

日志记录根据不同的环境进行配置:

# Configure logging with settings
logging.basicConfig(
    level=getattr(logging, settings.log_level),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

在开发环境中你会看到 DEBUG 日志,在生产环境中只输出 INFO 及以上级别。这让日志量保持可控,同时保留足够的排查信息。

这些面向生产的考虑让 REFLEX 可以服务真实用户:安全机制抵御攻击,性能优化承载负载,部署实现自动化,监控提供可观测性,错误处理避免崩溃。

它并不完美,总有可以继续打磨的地方,但已经足以投入生产使用。

11、仓库

REFLEX 的所有代码都是开源的,可在 GitHub 上找到。该仓库包含构建并运行你自己的自我改进研究代理所需的全部内容。

包括什么:

该仓库包含了我们一直在讨论的完整实现。后端使用 Python 和 FastAPI,前端使用原生 JavaScript,所有内容都通过 Docker 容器化,以便于部署。

要探索的关键文件:

  • backend/agent_core.py - 具有 RL 功能的主要代理实现
  • backend/rl_trainer.py - PPO、优先级重放和高级 RL 算法
  • backend/rag_enhancer.py - 增强的 RAG,具有重新排序和多跳检索
  • backend/main.py - FastAPI 端点和流式传输实现
  • frontend/app.js - 具有实时流的前端 UI
  • docker-compose.yml - 生产部署配置

开始使用:

克隆仓库,设置您的 API 密钥(Anthropic 和 OpenAI),并运行设置脚本。README 中有安装、配置和部署的详细说明。

它有什么特别之处:

这不仅仅是一个演示。这是一个可以直接使用的实际实现。代码有良好的文档,架构是模块化的,您可以添加自己的功能。

想要贡献?发现了 bug?有改进的想法?该仓库欢迎 Pull Request、Issue 和讨论。这是一个学习项目,每一份贡献都会让它对所有人更有价值。


原文链接:Let’s Build a Self-Improving AI Agent That Learns From Your Feedback

汇智网翻译整理,转载请标明出处