从零实现反思智能体

我们将从零开始构建一个多AI智能体系统,探索基本的智能体设计模式:反射、工具使用、规划和多智能体设置。

从零实现反思智能体

阿瑟·C·克拉克第三定律指出,“任何足够先进的技术都与魔法无异”。这正是当今许多 AI 框架给人的感觉。GitHub Copilot、Claude Desktop、OpenAI Operator 和 Perplexity Comet 等工具正在自动化处理五年前看似不可能实现的日常任务。更令人惊叹的是,只需几行代码,我们就能构建出复杂的AI工具:它们可以搜索文件、浏览网页、点击链接,甚至还能进行购物。这真的感觉像魔法一样。

虽然我真心相信数据魔法的存在,但我并不相信魔法。我发现了解事物的构建方式以及其底层运行机制令人兴奋(而且通常很有帮助)。因此,我决定分享一系列关于智能体AI设计概念的文章,帮助你理解这些神奇工具的工作原理。

为了深入理解,我们将从零开始构建一个多AI智能体系统。我们将避免使用CrewAI或smolagents之类的框架,而是直接使用基础模型API。在此过程中,我们将探索基本的智能体设计模式:反射、工具使用、规划和多智能体设置。然后,我们将结合所有这些知识来构建一个能够回答复杂数据相关问题的多AI智能体系统。

正如理查德·费曼所说:“我无法创造的东西,我就无法理解。” 所以,让我们开始构建吧!在本文中,我们将重点讨论反射设计模式。但首先,让我们弄清楚反射究竟是什么。

1、什么是反射

让我们反思一下我们(人类)通常是如何处理任务的。假设我需要向我的产品经理分享最近一次功能发布的结果。我可能会快速写一个草稿,然后从头到尾读一两遍,确保所有内容一致、信息充足且没有拼写错误。

或者我们再举一个例子:编写 SQL 查询。我可能会一步一步地写,并在过程中检查中间结果;或者(如果足够简单)我会一次性写完,执行查询,查看结果(检查错误或结果是否符合我的预期),然后根据反馈调整查询。我可能会重新运行查询,检查结果,并不断迭代,直到它正确为止。

所以我们很少会一口气写完长篇文本。我们通常会反复修改、回顾和调整。正是这些反馈循环帮助我们提升工作质量。

LLM采用不同的方法。如果你向 LLM 提出一个问题,默认情况下,它会逐个生成答案,而 LLM 无法回顾自己的结果并修复任何问题。但在智能体 AI 设置中,我们也可以为 LLM 创建反馈循环,例如让 LLM 回顾并改进自己的答案,或者与它分享外部反馈(例如 SQL 执行结果)。这正是反思的意义所在。听起来很简单,但它可以带来显著更好的结果。

大量研究表明了反思的益处:

  • “自我改进:基于自我反馈的迭代改进”,Madaan 等人。 (2023) 的研究表明,自我改进在从对话回复生成到数学推理等各种任务中,性能提升了约 20%。
图片来自 Madaan 等人的论文“Self-Refine: Iterative Refinement with Self-Feedback”。
  • 在 Shinn 等人 (2023) 的论文“Reflexion: Language Agents with Verbal Reinforcement Learning”中,作者在 HumanEval 编码基准测试中取得了 91% 的 pass@1 准确率,超越了之前最先进的 GPT-4(其准确率仅为 80%)。他们还发现,Reflexion 在 HotPotQA 基准测试(一个基于维基百科的问答数据集,要求智能体解析内容并对多个支持文档进行推理)中显著优于所有基线方法。
图片来自 Shinn 等人的论文“Reflexion: Language Agents with Verbal Reinforcement Learning”。
  • Gou 等人 (2024) 的论文“CRITIC: Large Language Models Can Self-Correct with Tool-Interactive Critiquing”(《评论:大型语言模型可通过工具交互式评论进行自我纠错》)着重探讨了外部反馈的影响,使大型语言模型能够利用外部工具来验证和纠正自身的输出。该方法在回答自由形式问题到解决数学问题等各种任务中,准确率均提高了 10% 至 30%。

反思在智能体系统中尤为重要,因为它可以在流程的许多步骤中进行纠正:

  • 当用户提出问题时问题在于,LLM 可以使用反射来评估请求是否可行。
  • 当 LLM 制定初始计划时,它可以使用反射来再次检查该计划是否合理,以及是否能够帮助实现目标。
  • 在每次执行步骤或工具调用之后,代理可以评估其是否按计划进行,以及是否值得调整计划。
  • 当计划完全执行完毕后,代理可以进行反射,以查看是否真正实现了目标并解决了任务。

显然,反射可以显著提高准确性。然而,其中也存在一些值得讨论的权衡。反射可能需要多次调用 LLM 以及其他系统,这会导致延迟和成本增加。因此,在实际应用中,值得考虑质量的提升是否足以抵消用户流程中的成本和延迟。

2、框架中的反射

由于反射无疑能为 AI 代理带来价值,因此它被广泛应用于流行的框架中。让我们来看一些示例。

反思的概念最早由姚等人(2022)在论文《ReAct:语言模型中推理与行动的协同作用》中提出。ReAct 框架将推理(通过显式思维轨迹进行反思)和行动(在环境中执行与任务相关的动作)两个阶段交错进行。在这个框架中,推理指导行动的选择,而行动则产生新的观察结果,这些结果又为进一步的推理提供信息。推理阶段本身就是反思和规划的结合。

该框架非常流行,因此现在有很多现成的实现,例如:

  • Databricks 的 DSPy 框架包含一个 ReAct 类
  • 在 LangGraph 中,可以使用 create_react_agent 函数
  • HuggingFace 的 smolagents 库中的代码代理也基于 ReAct 架构

3、从零开始实现反思

现在我们已经学习了理论并探索了现有的实现,是时候动手实践,自己构建一些东西了。在 ReAct 方法中,智能体在每个步骤都使用反思,将规划与反思相结合。然而,为了更清晰地理解反思的影响,我们将单独考察它。

例如,我们将使用文本到 SQL 的转换:我们会向 LLM 提出一个问题,并期望它返回一个有效的 SQL 查询。我们将使用航班延误数据集和 ClickHouse SQL 方言。

我们将首先使用不带任何反思的直接生成作为基准。然后,我们将尝试使用反思,例如要求模型对 SQL 进行评估和改进,或者向其提供额外的反馈。之后,我们将衡量答案的质量,以查看反思是否真的能带来更好的结果。

3.1 直接生成

我们将从最直接的方法——直接生成开始,即要求 LLM 生成能够回答用户查询的 SQL。

但是,在深入实现之前,让我们先进行一些设置。我们将使用基础模型 API。我更倾向于使用 Anthropic API,但您可以根据自己的喜好选择模型,因为 API 通常都很相似。首先,我们来安装 Python 包。

pip install anthropic

我们需要为 Anthropic API 指定 API 密钥。

import os
os.environ['ANTHROPIC_API_KEY'] = config['ANTHROPIC_API_KEY']

下一步是初始化客户端,一切就绪。

import anthropic
client = anthropic.Anthropic()

现在我们可以使用此客户端向 LLM 发送消息。接下来,我们编写一个函数,根据用户查询生成 SQL。我已经指定了包含基本说明和数据模式详细信息的系统提示。我还创建了一个函数,用于将系统提示和用户查询发送到 LLM。

base_sql_system_prompt = '''
You are a senior SQL developer and your task is to help generate a SQL query based on user requirements. 
You are working with ClickHouse database. Specify the format (Tab Separated With Names) in the SQL query output to ensure that column names are included in the output.
Do not use count(*) in your queries since it's a bad practice with columnar databases, prefer using count().
Ensure that the query is syntactically correct and optimized for performance, taking into account ClickHouse specific features (i.e. that ClickHouse is a columnar database and supports functions like ARRAY JOIN, SAMPLE, etc.).
Return only the SQL query without any additional explanations or comments.

You will be working with flight_data table which has the following schema:

Column Name | Data Type | Null % | Example Value | Description
--- | --- | --- | --- | ---
year | Int64 | 0.0 | 2024 | Year of flight
month | Int64 | 0.0 | 1 | Month of flight (1–12)
day_of_month | Int64 | 0.0 | 1 | Day of the month
day_of_week | Int64 | 0.0 | 1 | Day of week (1=Monday … 7=Sunday)
fl_date | datetime64[ns] | 0.0 | 2024-01-01 00:00:00 | Flight date (YYYY-MM-DD)
op_unique_carrier | object | 0.0 | 9E | Unique carrier code
op_carrier_fl_num | float64 | 0.0 | 4814.0 | Flight number for reporting airline
origin | object | 0.0 | JFK | Origin airport code
origin_city_name | object | 0.0 | "New York, NY" | Origin city name
origin_state_nm | object | 0.0 | New York | Origin state name
dest | object | 0.0 | DTW | Destination airport code
dest_city_name | object | 0.0 | "Detroit, MI" | Destination city name
dest_state_nm | object | 0.0 | Michigan | Destination state name
crs_dep_time | Int64 | 0.0 | 1252 | Scheduled departure time (local, hhmm)
dep_time | float64 | 1.31 | 1247.0 | Actual departure time (local, hhmm)
dep_delay | float64 | 1.31 | -5.0 | Departure delay in minutes (negative if early)
taxi_out | float64 | 1.35 | 31.0 | Taxi out time in minutes
wheels_off | float64 | 1.35 | 1318.0 | Wheels-off time (local, hhmm)
wheels_on | float64 | 1.38 | 1442.0 | Wheels-on time (local, hhmm)
taxi_in | float64 | 1.38 | 7.0 | Taxi in time in minutes
crs_arr_time | Int64 | 0.0 | 1508 | Scheduled arrival time (local, hhmm)
arr_time | float64 | 1.38 | 1449.0 | Actual arrival time (local, hhmm)
arr_delay | float64 | 1.61 | -19.0 | Arrival delay in minutes (negative if early)
cancelled | int64 | 0.0 | 0 | Cancelled flight indicator (0=No, 1=Yes)
cancellation_code | object | 98.64 | B | Reason for cancellation (if cancelled)
diverted | int64 | 0.0 | 0 | Diverted flight indicator (0=No, 1=Yes)
crs_elapsed_time | float64 | 0.0 | 136.0 | Scheduled elapsed time in minutes
actual_elapsed_time | float64 | 1.61 | 122.0 | Actual elapsed time in minutes
air_time | float64 | 1.61 | 84.0 | Flight time in minutes
distance | float64 | 0.0 | 509.0 | Distance between origin and destination (miles)
carrier_delay | int64 | 0.0 | 0 | Carrier-related delay in minutes
weather_delay | int64 | 0.0 | 0 | Weather-related delay in minutes
nas_delay | int64 | 0.0 | 0 | National Air System delay in minutes
security_delay | int64 | 0.0 | 0 | Security delay in minutes
late_aircraft_delay | int64 | 0.0 | 0 | Late aircraft delay in minutes
'''

def generate_direct_sql(rec):
  # making an LLM call
  message = client.messages.create(
    model = "claude-3-5-haiku-latest",
    # I chose smaller model so that it's easier for us to see the impact 
    max_tokens = 8192,
    system=base_sql_system_prompt,
    messages = [
        {'role': 'user', 'content': rec['question']}
    ]
  )

  sql  = message.content[0].text
  
  # cleaning the output
  if sql.endswith('```'):
    sql = sql[:-3]
  if sql.startswith('```sql'):
    sql = sql[6:]
  return sql

就这样。现在让我们测试一下文本转 SQL 的解决方案。我创建了一个包含 20 个问答对的小型评估集,我们可以用它来检查系统是否运行良好。以下是一个示例:

{
'question': 'What was the highest speed in mph?',
'answer': '''
    select max(distance / (air_time / 60)) as max_speed 
    from flight_data 
    where air_time > 0 
    format TabSeparatedWithNames'''
}

让我们使用文本转 SQL 函数为测试集中所有用户查询生成 SQL 语句。

# load evaluation set
with open('./data/flight_data_qa_pairs.json', 'r') as f:
    qa_pairs = json.load(f)
qa_pairs_df = pd.DataFrame(qa_pairs)


tmp = []
# executing LLM for each question in our eval set
for rec in tqdm.tqdm(qa_pairs_df.to_dict('records')):
    llm_sql = generate_direct_sql(rec)
    tmp.append(
        {
            'id': rec['id'],
            'llm_direct_sql': llm_sql
        }
    )

llm_direct_df = pd.DataFrame(tmp)
direct_result_df = qa_pairs_df.merge(llm_direct_df, on = 'id')

现在我们有了答案,下一步是评估质量。

3.2 评估质量

遗憾的是,这种情况没有唯一的正确答案,所以我们不能直接将 LLM 生成的 SQL 与参考答案进行比较。我们需要找到一种评估质量的方法。

有些质量方面我们可以使用客观标准进行检查,但要检查 LLM 是否返回了正确答案,我们需要使用 LLM。因此,我将结合以下几种方法:

  • 首先,我们将使用客观标准来检查 SQL 中是否指定了正确的格式(我们指示 LLM 使用 TabSeparatedWithNames)。
  • 其次,我们可以执行生成的查询,看看 ClickHouse 是否返回执行错误。
  • 最后,我们可以创建一个 LLM 评判器,将生成的查询的输出与我们的参考答案进行比较,并检查它们是否不同。

让我们从执行 SQL 开始。值得注意的是,我们的 get_clickhouse_data 函数不会抛出异常。相反,它会返回解释错误的文本,该文本稍后可由 LLM 处理。

CH_HOST = 'http://localhost:8123' # default address 
import requests
import pandas as pd
import tqdm

# function to execute SQL query
def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  r = requests.post(host, params = {'query': query}, 
    timeout = connection_timeout)
  if r.status_code == 200:
      return r.text
  else: 
      return 'Database returned the following error:n' + r.text

# getting the results of SQL execution
direct_result_df['llm_direct_output'] = direct_result_df['llm_direct_sql'].apply(get_clickhouse_data)
direct_result_df['answer_output'] = direct_result_df['answer'].apply(get_clickhouse_data)

下一步是创建一个 LLM 评判器。为此,我采用了一种链式思维方法,提示 LLM 在给出最终答案之前提供其推理过程。这给了模型时间来思考问题,从而提高了响应质量。

llm_judge_system_prompt = '''
You are a senior analyst and your task is to compare two SQL query results and determine if they are equivalent. 
Focus only on the data returned by the queries, ignoring any formatting differences. 
Take into account the initial user query and information needed to answer it. For example, if user asked for the average distance, and both queries return the same average value but in one of them there's also a count of records, you should consider them equivalent, since both provide the same requested information.

Answer with a JSON of the following structure:
{
  'reasoning': '<your reasoning here, 1-3 sentences on why you think they are equivalent or not>', 
  'equivalence': <true|false>
}
Ensure that ONLY JSON is in the output. 

You will be working with flight_data table which has the following schema:
Column Name | Data Type | Null % | Example Value | Description
--- | --- | --- | --- | ---
year | Int64 | 0.0 | 2024 | Year of flight
month | Int64 | 0.0 | 1 | Month of flight (1–12)
day_of_month | Int64 | 0.0 | 1 | Day of the month
day_of_week | Int64 | 0.0 | 1 | Day of week (1=Monday … 7=Sunday)
fl_date | datetime64[ns] | 0.0 | 2024-01-01 00:00:00 | Flight date (YYYY-MM-DD)
op_unique_carrier | object | 0.0 | 9E | Unique carrier code
op_carrier_fl_num | float64 | 0.0 | 4814.0 | Flight number for reporting airline
origin | object | 0.0 | JFK | Origin airport code
origin_city_name | object | 0.0 | "New York, NY" | Origin city name
origin_state_nm | object | 0.0 | New York | Origin state name
dest | object | 0.0 | DTW | Destination airport code
dest_city_name | object | 0.0 | "Detroit, MI" | Destination city name
dest_state_nm | object | 0.0 | Michigan | Destination state name
crs_dep_time | Int64 | 0.0 | 1252 | Scheduled departure time (local, hhmm)
dep_time | float64 | 1.31 | 1247.0 | Actual departure time (local, hhmm)
dep_delay | float64 | 1.31 | -5.0 | Departure delay in minutes (negative if early)
taxi_out | float64 | 1.35 | 31.0 | Taxi out time in minutes
wheels_off | float64 | 1.35 | 1318.0 | Wheels-off time (local, hhmm)
wheels_on | float64 | 1.38 | 1442.0 | Wheels-on time (local, hhmm)
taxi_in | float64 | 1.38 | 7.0 | Taxi in time in minutes
crs_arr_time | Int64 | 0.0 | 1508 | Scheduled arrival time (local, hhmm)
arr_time | float64 | 1.38 | 1449.0 | Actual arrival time (local, hhmm)
arr_delay | float64 | 1.61 | -19.0 | Arrival delay in minutes (negative if early)
cancelled | int64 | 0.0 | 0 | Cancelled flight indicator (0=No, 1=Yes)
cancellation_code | object | 98.64 | B | Reason for cancellation (if cancelled)
diverted | int64 | 0.0 | 0 | Diverted flight indicator (0=No, 1=Yes)
crs_elapsed_time | float64 | 0.0 | 136.0 | Scheduled elapsed time in minutes
actual_elapsed_time | float64 | 1.61 | 122.0 | Actual elapsed time in minutes
air_time | float64 | 1.61 | 84.0 | Flight time in minutes
distance | float64 | 0.0 | 509.0 | Distance between origin and destination (miles)
carrier_delay | int64 | 0.0 | 0 | Carrier-related delay in minutes
weather_delay | int64 | 0.0 | 0 | Weather-related delay in minutes
nas_delay | int64 | 0.0 | 0 | National Air System delay in minutes
security_delay | int64 | 0.0 | 0 | Security delay in minutes
late_aircraft_delay | int64 | 0.0 | 0 | Late aircraft delay in minutes
'''

llm_judge_user_prompt_template = '''
Here is the initial user query:
{user_query}

Here is the SQL query generated by the first analyst: 
SQL: 
{sql1} 

Database output: 
{result1}

Here is the SQL query generated by the second analyst:
SQL:
{sql2}

Database output:
{result2}
'''

def llm_judge(rec, field_to_check):
  # construct the user prompt 
  user_prompt = llm_judge_user_prompt_template.format(
    user_query = rec['question'],
    sql1 = rec['answer'],
    result1 = rec['answer_output'],
    sql2 = rec[field_to_check + '_sql'],
    result2 = rec[field_to_check + '_output']
  )
  
  # make an LLM call
  message = client.messages.create(
      model = "claude-sonnet-4-5",
      max_tokens = 8192,
      temperature = 0.1,
      system = llm_judge_system_prompt,
      messages=[
          {'role': 'user', 'content': user_prompt}
      ]
  )
  data = message.content[0].text
  
  # Strip markdown code blocks
  data = data.strip()
  if data.startswith('```json'):
      data = data[7:]
  elif data.startswith('```'):
      data = data[3:]
  if data.endswith('```'):
      data = data[:-3]
  
  data = data.strip()
  return json.loads(data)

现在,让我们运行 LLM 评测器来获取结果。

tmp = []

for rec in tqdm.tqdm(direct_result_df.to_dict('records')):
  try:
    judgment = llm_judge(rec, 'llm_direct')
  except Exception as e:
    print(f"Error processing record {rec['id']}: {e}")
    continue
  tmp.append(
    {
      'id': rec['id'],
      'llm_judge_reasoning': judgment['reasoning'],
      'llm_judge_equivalence': judgment['equivalence']
    }
  )

judge_df = pd.DataFrame(tmp)
direct_result_df = direct_result_df.merge(judge_df, on = 'id')

让我们来看一个例子,了解 LLM 判断器是如何工作的。

# user query 
In 2024, what percentage of time all airplanes spent in the air?

# correct answer 
select (sum(air_time) / sum(actual_elapsed_time)) * 100 as percentage_in_air 
where year = 2024
from flight_data 
format TabSeparatedWithNames

percentage_in_air
81.43582596894757

# generated by LLM answer 
SELECT 
    round(sum(air_time) / (sum(air_time) + sum(taxi_out) + sum(taxi_in)) * 100, 2) as air_time_percentage
FROM flight_data
WHERE year = 2024
FORMAT TabSeparatedWithNames

air_time_percentage
81.39

# LLM judge response
{
 'reasoning': 'Both queries calculate the percentage of time airplanes 
    spent in the air, but use different denominators. The first query 
    uses actual_elapsed_time (which includes air_time + taxi_out + taxi_in 
    + any ground delays), while the second uses only (air_time + taxi_out 
    + taxi_in). The second query is approach is more accurate for answering 
    "time airplanes spent in the air" as it excludes ground delays. 
    However, the results are very close (81.44% vs 81.39%), suggesting minimal 
    impact. These are materially different approaches that happen to yield 
    similar results',
 'equivalence': FALSE
}

推理是合理的,因此我们可以相信我们的判断。现在,让我们检查所有 LLM 生成的查询。

def get_llm_accuracy(sql, output, equivalence): 
    problems = []
    if 'format tabseparatedwithnames' not in sql.lower():
        problems.append('No format specified in SQL')
    if 'Database returned the following error' in output:
        problems.append('SQL execution error')
    if not equivalence and ('SQL execution error' not in problems):
        problems.append('Wrong answer provided')
    if len(problems) == 0:
        return 'No problems detected'
    else:
        return ' + '.join(problems)

direct_result_df['llm_direct_sql_quality_heuristics'] = direct_result_df.apply(
    lambda row: get_llm_accuracy(row['llm_direct_sql'], row['llm_direct_output'], row['llm_judge_equivalence']), axis=1)

LLM 在 70% 的情况下返回了正确答案,这还不错。但它肯定还有改进的空间,因为它经常给出错误的答案,或者未能正确指定格式(有时会导致 SQL 执行错误)。

3.3 添加反射步骤

为了提高解决方案的质量,我们尝试添加一个反射步骤,让模型审查并改进其答案。

对于反射调用,我将保留相同的系统提示,因为它包含了有关 SQL 和数据模式的所有必要信息。但我会调整用户消息,以显示初始用户查询和生成的 SQL,并请求 LLM 对其进行评估和改进。

simple_reflection_user_prompt_template = '''
Your task is to assess the SQL query generated by another analyst and propose improvements if necessary.
Check whether the query is syntactically correct and optimized for performance. 
Pay attention to nuances in data (especially time stamps types, whether to use total elapsed time or time in the air, etc).
Ensure that the query answers the initial user question accurately. 
As the result return the following JSON: 
{{
  'reasoning': '<your reasoning here, 2-4 sentences on why you made changes or not>', 
  'refined_sql': '<the improved SQL query here>'
}}
Ensure that ONLY JSON is in the output and nothing else. Ensure that the output JSON is valid. 

Here is the initial user query:
{user_query}

Here is the SQL query generated by another analyst: 
{sql} 
'''

def simple_reflection(rec) -> str:
  # constructing a user prompt
  user_prompt = simple_reflection_user_prompt_template.format(
    user_query=rec['question'],
    sql=rec['llm_direct_sql']
  )
  
  # making an LLM call
  message = client.messages.create(
    model="claude-3-5-haiku-latest",
    max_tokens = 8192,
    system=base_sql_system_prompt,
    messages=[
        {'role': 'user', 'content': user_prompt}
    ]
  )

  data  = message.content[0].text

  # strip markdown code blocks
  data = data.strip()
  if data.startswith('```json'):
    data = data[7:]
  elif data.startswith('```'):
    data = data[3:]
  if data.endswith('```'):
    data = data[:-3]
  
  data = data.strip()
  return json.loads(data.replace('\n', ' '))

让我们使用反射来优化查询并测量准确率。最终结果的质量并没有显著提高,正确率仍然只有 70%。

让我们来看一些具体的例子来理解发生了什么。首先,LLM 成功修复了一些问题,要么是通过修正格式,要么是通过添加缺失的逻辑来处理零值。

然而,也有一些情况下,LLM 使答案过于复杂。最初的 SQL 语句是正确的(与黄金答案匹配),但 LLM 随后决定对其进行“改进”。其中一些改进是合理的(例如,考虑空值或排除已取消的航班)。然而,出于某种原因,它决定使用 ClickHouse 采样,即使我们的数据量不大,而且我们的表不支持采样。结果,改进后的查询返回了执行错误:数据库返回以下错误:代码:141。DB::Exception:存储 default.flight_data 不支持采样。(SAMPLING_NOT_SUPPORTED)。

3.4 使用外部反馈进行反思

反思并没有显著提高准确性。这可能是因为我们没有提供任何有助于模型生成更佳结果的额外信息。让我们尝试向模型提供外部反馈:

  • 格式是否正确指定的检查结果
  • 数据库输出(数据或错误信息)

让我们为此编写一个提示,并生成一个新的 SQL 版本。

feedback_reflection_user_prompt_template = '''
Your task is to assess the SQL query generated by another analyst and propose improvements if necessary.
Check whether the query is syntactically correct and optimized for performance. 
Pay attention to nuances in data (especially time stamps types, whether to use total elapsed time or time in the air, etc).
Ensure that the query answers the initial user question accurately. 

As the result return the following JSON: 
{{
  'reasoning': '<your reasoning here, 2-4 sentences on why you made changes or not>', 
  'refined_sql': '<the improved SQL query here>'
}}
Ensure that ONLY JSON is in the output and nothing else. Ensure that the output JSON is valid. 


Here is the initial user query:
{user_query}

Here is the SQL query generated by another analyst: 
{sql} 

Here is the database output of this query: 
{output}

We run an automatic check on the SQL query to check whether it has fomatting issues. Here's the output: 
{formatting}
'''

def feedback_reflection(rec) -> str:
  # define message for formatting 
  if 'No format specified in SQL' in rec['llm_direct_sql_quality_heuristics']:
    formatting = 'SQL missing formatting. Specify "format TabSeparatedWithNames" to ensure that column names are also returned'
  else: 
    formatting = 'Formatting is correct'

  # constructing a user prompt
  user_prompt = feedback_reflection_user_prompt_template.format(
    user_query = rec['question'],
    sql = rec['llm_direct_sql'],
    output = rec['llm_direct_output'],
    formatting = formatting
  )

  # making an LLM call 
  message = client.messages.create(
    model = "claude-3-5-haiku-latest",
    max_tokens = 8192,
    system = base_sql_system_prompt,
    messages = [
        {'role': 'user', 'content': user_prompt}
    ]
  )
  data  = message.content[0].text

  # strip markdown code blocks
  data = data.strip()
  if data.startswith('```json'):
    data = data[7:]
  elif data.startswith('```'):
    data = data[3:]
  if data.endswith('```'):
    data = data[:-3]
  
  data = data.strip()
  return json.loads(data.replace('\n', ' '))

运行准确率测试后,我们可以看到准确率显著提高:17 个正确答案(85% 准确率),而之前只有 14 个(70% 准确率)。

如果我们检查 LLM 修复问题的案例,可以看到它能够纠正格式、解决 SQL 执行错误,甚至修改业务逻辑(例如,使用通话时间计算速度)。

我们再进行一些错误分析,看看LLM在哪些情况下出错。在下表中,我们可以看到LLM在定义某些时间戳、错误计算总时间或使用总时间而非飞行时间进行速度计算方面存在问题。然而,有些差异比较棘手:

  • 在最后一个查询中,时间段没有明确定义,因此LLM使用2010-2023年是合理的。我不认为这是一个错误,而是应该调整评估结果。
  • 另一个例子是如何定义航空公司的速度:avg(距离/时间) 或 sum(距离)/sum(时间)。由于用户查询或系统提示中没有指定任何内容(假设我们没有预定义的计算方法),因此这两个选项都是有效的。

总的来说,我认为我们取得了相当不错的结果。我们最终达到的 85% 准确率意味着显著提升了 15 个百分点。您或许可以进行不止一次迭代,开展 2-3 轮反思,但值得评估的是,何时在您的特定情况下达到收益递减点。这种情况,因为每次迭代都会增加成本和延迟。

您可以在 GitHub 上找到完整代码。

4、结束语

是时候总结一下了。在本文中,我们开启了探索智能体 AI 系统工作原理的旅程。为了弄清这一点,我们将实现一个仅使用 API 调用基础模型的多智能体文本到数据工具。在此过程中,我们将逐步讲解关键的设计模式:从今天的反思开始,然后是工具使用、规划和多智能体协调。

在本文中,我们从最基本的模式——反思——开始。反思是任何智能体流程的核心,因为 LLM 需要反思其在实现最终目标方面的进展。

反思是一种相对简单的模式。我们只需让相同或不同的模型分析结果并尝试改进它。正如我们在实践中所了解到的,与模型共享外部反馈(例如静态检查的结果或数据库输出)可以显著提高准确性。多项研究以及我们自身使用文本转 SQL 代理的经验都证明了反思的益处。然而,准确率的提升是有代价的:由于多次 API 调用,需要消耗更多令牌并导致更高的延迟。


原文链接:Agentic AI from First Principles: Reflection

汇智网翻译整理,转载请标明出处