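Reflection#
Reflection is a design pattern in which an LLM generation is followed by a reflection, which is itself another LLM generation conditioned on the output of the first. For example, given a task to write code, the first LLM can generate a code snippet, and the second LLM can generate a critique of that snippet.
In the context of AutoGen and agents, reflection can be implemented as a pair of agents, where the first agent generates a message and the second agent generates a response to that message. The two agents keep interacting until they reach a stopping condition, such as a maximum number of iterations or an approval from the second agent.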
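Before bringing in any agent framework, the generate/critique loop and its two stopping conditions can be sketched in plain Python. This is a minimal illustration only: `reflect`, `toy_generate`, and `toy_critique` are hypothetical names, and the two toy functions stand in for real LLM calls.

```python
from typing import Callable, Optional, Tuple


def reflect(
    task: str,
    generate: Callable[[str, Optional[str]], str],
    critique: Callable[[str, str], Tuple[bool, str]],
    max_rounds: int = 3,
) -> Tuple[str, bool, int]:
    """Alternate generation and critique until approval or max_rounds is reached."""
    feedback: Optional[str] = None
    draft = ""
    for round_number in range(1, max_rounds + 1):
        # First generation: produce (or revise) a draft, using feedback if any.
        draft = generate(task, feedback)
        # Second generation: the reflection on that draft.
        approved, feedback = critique(task, draft)
        if approved:
            return draft, True, round_number
    # Stopping condition: iteration budget exhausted without approval.
    return draft, False, max_rounds


# Hypothetical stand-ins for the two LLM calls.
def toy_generate(task: str, feedback: Optional[str]) -> str:
    return "draft v2 (validated)" if feedback else "draft v1"


def toy_critique(task: str, draft: str) -> Tuple[bool, str]:
    if "validated" in draft:
        return True, "Looks good."
    return False, "Add input validation."


result = reflect("Sum the even numbers in a list.", toy_generate, toy_critique)
print(result)
```

With these toy functions the loop ends on the second round, because the first critique requests a revision and the second approves it.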
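Let's implement a simple reflection design pattern using AutoGen agents. There will be two agents: a coder agent and a reviewer agent. The coder agent generates a code snippet, and the reviewer agent generates a critique of it.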
Message Protocol#
Before we define the agents, we first need to define the message protocol they use.
from dataclasses import dataclass


@dataclass
class CodeWritingTask:
    task: str


@dataclass
class CodeWritingResult:
    task: str
    code: str
    review: str


@dataclass
class CodeReviewTask:
    session_id: str
    code_writing_task: str
    code_writing_scratchpad: str
    code: str


@dataclass
class CodeReviewResult:
    review: str
    session_id: str
    approved: bool
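The set of messages above defines the protocol for our example reflection design pattern:
1. The application sends a CodeWritingTask message to the coder agent.
2. The coder agent generates a CodeReviewTask message and sends it to the reviewer agent.
3. The reviewer agent generates a CodeReviewResult message and sends it back to the coder agent.
4. Depending on the CodeReviewResult message, if the code is approved, the coder agent sends a CodeWritingResult message back to the application; otherwise, it sends another CodeReviewTask message to the reviewer agent, and the process continues.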
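The four steps can also be traced without any LLM or runtime. The sketch below redefines the message dataclasses so it runs on its own, and uses a hypothetical helper, `run_protocol`, that replays the protocol for a scripted list of reviewer verdicts. The session id `"s1"` and the placeholder code strings are made up for illustration.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class CodeWritingTask:
    task: str


@dataclass
class CodeWritingResult:
    task: str
    code: str
    review: str


@dataclass
class CodeReviewTask:
    session_id: str
    code_writing_task: str
    code_writing_scratchpad: str
    code: str


@dataclass
class CodeReviewResult:
    review: str
    session_id: str
    approved: bool


Message = Union[CodeWritingTask, CodeWritingResult, CodeReviewTask, CodeReviewResult]


def run_protocol(task: str, verdicts: List[bool]) -> List[Message]:
    """Replay the protocol for a scripted sequence of reviewer verdicts."""
    log: List[Message] = [CodeWritingTask(task=task)]  # step 1: application -> coder
    for attempt, approved in enumerate(verdicts, start=1):
        code = f"# attempt {attempt}"
        # Step 2: coder -> reviewer.
        log.append(
            CodeReviewTask(
                session_id="s1",
                code_writing_task=task,
                code_writing_scratchpad=code,
                code=code,
            )
        )
        # Step 3: reviewer -> coder.
        review = CodeReviewResult(
            review="approve" if approved else "revise", session_id="s1", approved=approved
        )
        log.append(review)
        if approved:
            # Step 4 (approved branch): coder -> application.
            log.append(CodeWritingResult(task=task, code=code, review=review.review))
            break
        # Step 4 (rejected branch): loop back and send another CodeReviewTask.
    return log


names = [type(m).__name__ for m in run_protocol("demo", [False, True])]
print(names)
```

One rejection followed by one approval yields two CodeReviewTask/CodeReviewResult pairs before the final CodeWritingResult.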
We can visualize the message protocol with a data flow diagram:
[Figure: data flow diagram of the message protocol]
Agents#
Now, let's define the agents for the reflection design pattern.
import json
import re
import uuid
from typing import Dict, List, Union
from autogen_core.base import MessageContext, TopicId
from autogen_core.components import RoutedAgent, default_subscription, message_handler
from autogen_core.components.models import (
    AssistantMessage,
    ChatCompletionClient,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
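We use the Broadcast API to implement the design pattern. The agents follow the publish/subscribe model. The coder agent subscribes to CodeWritingTask and CodeReviewResult messages, and publishes CodeReviewTask and CodeWritingResult messages.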
@default_subscription
class CoderAgent(RoutedAgent):
    """An agent that performs code writing tasks."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A code writing agent.")
        self._system_messages: List[LLMMessage] = [
            SystemMessage(
                content="""You are a proficient coder. You write code to solve problems.
Work with the reviewer to improve your code.
Always put all finished code in a single Markdown code block.
For example:
```python
def hello_world():
    print("Hello, World!")
```
Respond using the following format:
Thoughts: <Your comments>
Code: <Your code>
""",
            )
        ]
        self._model_client = model_client
        self._session_memory: Dict[str, List[CodeWritingTask | CodeReviewTask | CodeReviewResult]] = {}

    @message_handler
    async def handle_code_writing_task(self, message: CodeWritingTask, ctx: MessageContext) -> None:
        # Store the messages in a temporary memory for this request only.
        session_id = str(uuid.uuid4())
        self._session_memory.setdefault(session_id, []).append(message)
        # Generate a response using the chat completion API.
        response = await self._model_client.create(
            self._system_messages + [UserMessage(content=message.task, source=self.metadata["type"])],
            cancellation_token=ctx.cancellation_token,
        )
        assert isinstance(response.content, str)
        # Extract the code block from the response.
        code_block = self._extract_code_block(response.content)
        if code_block is None:
            raise ValueError("Code block not found.")
        # Create a code review task.
        code_review_task = CodeReviewTask(
            session_id=session_id,
            code_writing_task=message.task,
            code_writing_scratchpad=response.content,
            code=code_block,
        )
        # Store the code review task in the session memory.
        self._session_memory[session_id].append(code_review_task)
        # Publish a code review task.
        await self.publish_message(code_review_task, topic_id=TopicId("default", self.id.key))

    @message_handler
    async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:
        # Store the review result in the session memory.
        self._session_memory[message.session_id].append(message)
        # Obtain the request from previous messages.
        review_request = next(
            m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewTask)
        )
        assert review_request is not None
        # Check if the code is approved.
        if message.approved:
            # Publish the code writing result.
            await self.publish_message(
                CodeWritingResult(
                    code=review_request.code,
                    task=review_request.code_writing_task,
                    review=message.review,
                ),
                topic_id=TopicId("default", self.id.key),
            )
            print("Code Writing Result:")
            print("-" * 80)
            print(f"Task:\n{review_request.code_writing_task}")
            print("-" * 80)
            print(f"Code:\n{review_request.code}")
            print("-" * 80)
            print(f"Review:\n{message.review}")
            print("-" * 80)
        else:
            # Create a list of LLM messages to send to the model.
            messages: List[LLMMessage] = [*self._system_messages]
            for m in self._session_memory[message.session_id]:
                if isinstance(m, CodeReviewResult):
                    messages.append(UserMessage(content=m.review, source="Reviewer"))
                elif isinstance(m, CodeReviewTask):
                    messages.append(AssistantMessage(content=m.code_writing_scratchpad, source="Coder"))
                elif isinstance(m, CodeWritingTask):
                    messages.append(UserMessage(content=m.task, source="User"))
                else:
                    raise ValueError(f"Unexpected message type: {m}")
            # Generate a revision using the chat completion API.
            response = await self._model_client.create(messages, cancellation_token=ctx.cancellation_token)
            assert isinstance(response.content, str)
            # Extract the code block from the response.
            code_block = self._extract_code_block(response.content)
            if code_block is None:
                raise ValueError("Code block not found.")
            # Create a new code review task.
            code_review_task = CodeReviewTask(
                session_id=message.session_id,
                code_writing_task=review_request.code_writing_task,
                code_writing_scratchpad=response.content,
                code=code_block,
            )
            # Store the code review task in the session memory.
            self._session_memory[message.session_id].append(code_review_task)
            # Publish a new code review task.
            await self.publish_message(code_review_task, topic_id=TopicId("default", self.id.key))

    def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
        pattern = r"```(\w+)\n(.*?)\n```"
        # Search for the pattern in the markdown text
        match = re.search(pattern, markdown_text, re.DOTALL)
        # Extract the language and code block if a match is found
        if match:
            return match.group(2)
        return None
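The code-block extraction can be tried in isolation. The sketch below is a standalone copy of the same regular expression, not the agent method itself. `FENCE` is a helper constant introduced here only so the triple-backtick delimiter does not clash with this example's own code fence.

```python
import re
from typing import Optional

FENCE = "`" * 3  # the Markdown code-fence delimiter, built programmatically


def extract_code_block(markdown_text: str) -> Optional[str]:
    """Return the body of the first fenced block that carries a language tag."""
    pattern = FENCE + r"(\w+)\n(.*?)\n" + FENCE
    match = re.search(pattern, markdown_text, re.DOTALL)
    return match.group(2) if match else None


tagged = "Thoughts: ok\nCode:\n" + FENCE + "python\ndef f():\n    return 1\n" + FENCE
untagged = FENCE + "\nprint(1)\n" + FENCE
two_blocks = FENCE + "python\na = 1\n" + FENCE + "\ntext\n" + FENCE + "python\nb = 2\n" + FENCE

print(repr(extract_code_block(tagged)))
print(repr(extract_code_block(untagged)))
print(repr(extract_code_block(two_blocks)))
```

Two properties of this pattern are worth knowing. Because it requires `\w+` right after the opening fence, a block without a language tag is not matched, in which case the coder agent would raise its "Code block not found." error. When a response holds several blocks, only the first one is returned.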
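A few things to note about CoderAgent:
- It uses chain-of-thought prompting in its system message.
- It stores the message history of each CodeWritingTask in a dictionary, so every task has its own history.
- When making an LLM inference request with its model client, it converts the message history into a list of autogen_core.components.models.LLMMessage objects to pass to the model client.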
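The per-task history and its conversion into chat messages can be illustrated without the AutoGen types. In this sketch `Task`, `Draft`, and `Review` are hypothetical stand-ins for CodeWritingTask, CodeReviewTask, and CodeReviewResult, and plain `(role, content)` tuples stand in for LLMMessage objects.

```python
import uuid
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass
class Task:  # hypothetical stand-in for CodeWritingTask
    text: str


@dataclass
class Draft:  # hypothetical stand-in for CodeReviewTask
    scratchpad: str


@dataclass
class Review:  # hypothetical stand-in for CodeReviewResult
    text: str


Entry = Union[Task, Draft, Review]


def to_chat(history: List[Entry]) -> List[Tuple[str, str]]:
    """Map stored entries to (role, content) pairs, in the order they were stored."""
    chat: List[Tuple[str, str]] = []
    for entry in history:
        if isinstance(entry, Task):
            chat.append(("user", entry.text))
        elif isinstance(entry, Draft):
            # The coder's own earlier output is replayed as an assistant turn.
            chat.append(("assistant", entry.scratchpad))
        elif isinstance(entry, Review):
            # Reviewer feedback is presented to the model as a user turn.
            chat.append(("user", entry.text))
        else:
            raise ValueError(f"Unexpected entry: {entry}")
    return chat


# One history list per session id, so concurrent tasks never mix.
memory: Dict[str, List[Entry]] = {}
session_a, session_b = str(uuid.uuid4()), str(uuid.uuid4())
memory.setdefault(session_a, []).append(Task("task A"))
memory.setdefault(session_b, []).append(Task("task B"))
memory[session_a].append(Draft("draft A1"))
memory[session_a].append(Review("review A1"))

chat_a = to_chat(memory[session_a])
chat_b = to_chat(memory[session_b])
print(chat_a)
print(chat_b)
```

Keying the dictionary by session id is what lets a review result, which carries only its `session_id`, find the history it belongs to.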
The reviewer agent subscribes to CodeReviewTask messages and publishes CodeReviewResult messages.
@default_subscription
class ReviewerAgent(RoutedAgent):
    """An agent that performs code review tasks."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A code reviewer agent.")
        self._system_messages: List[LLMMessage] = [
            SystemMessage(
                content="""You are a code reviewer. You focus on correctness, efficiency and safety of the code.
Respond using the following JSON format:
{
    "correctness": "<Your comments>",
    "efficiency": "<Your comments>",
    "safety": "<Your comments>",
    "approval": "<APPROVE or REVISE>",
    "suggested_changes": "<Your comments>"
}
""",
            )
        ]
        self._session_memory: Dict[str, List[CodeReviewTask | CodeReviewResult]] = {}
        self._model_client = model_client

    @message_handler
    async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None:
        # Format the prompt for the code review.
        # Gather the previous feedback if available.
        previous_feedback = ""
        if message.session_id in self._session_memory:
            previous_review = next(
                (m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewResult)),
                None,
            )
            if previous_review is not None:
                previous_feedback = previous_review.review
        # Store the messages in a temporary memory for this request only.
        self._session_memory.setdefault(message.session_id, []).append(message)
        prompt = f"""The problem statement is: {message.code_writing_task}
The code is:
```
{message.code}
```
Previous feedback:
{previous_feedback}
Please review the code. If previous feedback was provided, see if it was addressed.
"""
        # Generate a response using the chat completion API.
        response = await self._model_client.create(
            self._system_messages + [UserMessage(content=prompt, source=self.metadata["type"])],
            cancellation_token=ctx.cancellation_token,
            json_output=True,
        )
        assert isinstance(response.content, str)
        # TODO: use structured generation library e.g. guidance to ensure the response is in the expected format.
        # Parse the response JSON.
        review = json.loads(response.content)
        # Construct the review text.
        review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
        approved = review["approval"].lower().strip() == "approve"
        result = CodeReviewResult(
            review=review_text,
            session_id=message.session_id,
            approved=approved,
        )
        # Store the review result in the session memory.
        self._session_memory[message.session_id].append(result)
        # Publish the review result.
        await self.publish_message(result, topic_id=TopicId("default", self.id.key))
ReviewerAgent uses JSON mode when making an LLM inference request, and it also uses chain-of-thought prompting in its system message.
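How a JSON-mode reply turns into review text and an approval flag can be checked on its own. The helper below, `parse_review`, is a hypothetical standalone variant of the parsing step. Unlike the agent code, it uses `dict.get` so that a reply missing the "approval" key counts as not approved instead of raising a KeyError. That is a design choice of this sketch, not the tutorial's behavior.

```python
import json
from typing import Tuple


def parse_review(raw: str) -> Tuple[str, bool]:
    """Turn a JSON review string into (review_text, approved)."""
    review = json.loads(raw)
    review_text = "Code review:\n" + "\n".join(f"{k}: {v}" for k, v in review.items())
    # Only the exact word "approve" (any case, surrounding spaces ignored) counts.
    approved = str(review.get("approval", "")).lower().strip() == "approve"
    return review_text, approved


text, ok = parse_review(json.dumps({"correctness": "fine", "approval": "APPROVE"}))
print(text)
print(ok)
print(parse_review(json.dumps({"approval": "REVISE"}))[1])
```

The comparison is deliberately strict: a value such as "approved" or "LGTM" is treated as a request for revision, which sends the coder around the loop again rather than accepting code by accident.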
Logging#
Turn on logging to see the messages exchanged between the agents.
import logging
logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
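The effect of these two calls follows from the standard library's logger hierarchy: a logger with no level of its own inherits the level of its nearest configured ancestor. The short check below uses only the `logging` module. The name `some_other_library` is a made-up example of an unrelated logger.

```python
import logging

logging.basicConfig(level=logging.WARNING)
logging.getLogger("autogen_core").setLevel(logging.DEBUG)

# A child logger such as "autogen_core.events" has no level of its own,
# so it inherits DEBUG from its parent "autogen_core".
events = logging.getLogger("autogen_core.events")
print(events.getEffectiveLevel() == logging.DEBUG)

# An unrelated logger falls back to the root logger's level instead.
other = logging.getLogger("some_other_library")
print(logging.getLevelName(other.getEffectiveLevel()))
```

This keeps third-party libraries quiet at the root level while still surfacing the INFO-level message traffic that the agents emit under the `autogen_core` namespace.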
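Running the Design Pattern#
Let's test the design pattern with a coding task.
Since all the agents are decorated with the default_subscription() class decorator, they automatically subscribe to the default topic when created.
We publish a CodeWritingTask message to the default topic to start the reflection process.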
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.components import DefaultTopicId
from autogen_ext.models import OpenAIChatCompletionClient
runtime = SingleThreadedAgentRuntime()
await ReviewerAgent.register(
    runtime, "ReviewerAgent", lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model="gpt-4o-mini"))
)
await CoderAgent.register(
    runtime, "CoderAgent", lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model="gpt-4o-mini"))
)
runtime.start()
await runtime.publish_message(
    message=CodeWritingTask(task="Write a function to find the sum of all even numbers in a list."),
    topic_id=DefaultTopicId(),
)
# Keep processing messages until idle.
await runtime.stop_when_idle()
INFO:autogen_core:Publishing message of type CodeWritingTask to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingTask published by Unknown
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeWritingTask published by Unknown
INFO:autogen_core:Unhandled message: CodeWritingTask(task='Write a function to find the sum of all even numbers in a list.')
INFO:autogen_core.events:{"prompt_tokens": 101, "completion_tokens": 88, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': 'Thoughts: To find the sum of all even numbers in a list, we can use a list comprehension to filter out the even numbers and then use the `sum()` function to calculate their total. The implementation should handle edge cases like an empty list or a list with no even numbers.\n\nCode:\n```python\ndef sum_of_even_numbers(numbers):\n return sum(num for num in numbers if num % 2 == 0)\n```', 'code': 'def sum_of_even_numbers(numbers):\n return sum(num for num in numbers if num % 2 == 0)'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default
INFO:autogen_core.events:{"prompt_tokens": 163, "completion_tokens": 235, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': "Code review:\ncorrectness: The function correctly identifies and sums all even numbers in the provided list. The use of a generator expression ensures that only even numbers are processed, which is correct.\nefficiency: The function is efficient as it utilizes a generator expression that avoids creating an intermediate list, therefore using less memory. The time complexity is O(n) where n is the number of elements in the input list, which is optimal for this task.\nsafety: The function does not include checks for input types. If a non-iterable or a list containing non-integer types is passed, it could lead to unexpected behavior or errors. It’s advisable to handle such cases.\napproval: REVISE\nsuggested_changes: Consider adding input validation to ensure that 'numbers' is a list and contains only integers. You could raise a ValueError if the input is invalid. Example: 'if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers): raise ValueError('Input must be a list of integers')'. This will make the function more robust.", 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': False}
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default
INFO:autogen_core.events:{"prompt_tokens": 421, "completion_tokens": 119, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': "Thoughts: I appreciate the reviewer's feedback on input validation. Adding type checks ensures that the function can handle unexpected inputs gracefully. I will implement the suggested changes and include checks for both the input type and the elements within the list to confirm that they are integers.\n\nCode:\n```python\ndef sum_of_even_numbers(numbers):\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n raise ValueError('Input must be a list of integers')\n \n return sum(num for num in numbers if num % 2 == 0)\n```", 'code': "def sum_of_even_numbers(numbers):\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n raise ValueError('Input must be a list of integers')\n \n return sum(num for num in numbers if num % 2 == 0)"}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default
INFO:autogen_core.events:{"prompt_tokens": 420, "completion_tokens": 153, "type": "LLMCall"}
INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': 'Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.', 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': True}
INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default
INFO:autogen_core:Publishing message of type CodeWritingResult to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.', 'code': "def sum_of_even_numbers(numbers):\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n raise ValueError('Input must be a list of integers')\n \n return sum(num for num in numbers if num % 2 == 0)", 'review': 'Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.'}
INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingResult published by CoderAgent:default
INFO:autogen_core:Unhandled message: CodeWritingResult(task='Write a function to find the sum of all even numbers in a list.', code="def sum_of_even_numbers(numbers):\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n raise ValueError('Input must be a list of integers')\n \n return sum(num for num in numbers if num % 2 == 0)", review='Code review:\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\napproval: APPROVE\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.')
Code Writing Result:
--------------------------------------------------------------------------------
Task:
Write a function to find the sum of all even numbers in a list.
--------------------------------------------------------------------------------
Code:
def sum_of_even_numbers(numbers):
    if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):
        raise ValueError('Input must be a list of integers')

    return sum(num for num in numbers if num % 2 == 0)
--------------------------------------------------------------------------------
Review:
Code review:
correctness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.
efficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.
safety: The function includes input validation, which enhances safety by preventing incorrect input types. It raises a ValueError for invalid inputs, making the function more robust against unexpected data.
approval: APPROVE
suggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.
--------------------------------------------------------------------------------
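The log messages show the interaction between the coder agent and the reviewer agent. The final output shows the code snippet generated by the coder agent and the critique generated by the reviewer agent.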
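The "Unhandled message" lines in the log arise because every agent on the default topic receives every published message, whether or not it has a handler for that type. The toy bus below imitates that behavior in plain Python. `ToyAgent` and `ToyBus` are hypothetical classes written for this illustration and are not part of AutoGen. Skipping the sender on delivery is modeled on what the log above shows.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class CodeWritingTask:  # simplified copy of the tutorial's message
    task: str


@dataclass
class CodeReviewTask:  # simplified: only the code field is kept
    code: str


class ToyAgent:
    def __init__(self, name: str, handlers: Dict[type, Callable[[object], None]]) -> None:
        self.name = name
        self.handlers = handlers

    def on_message(self, message: object) -> str:
        kind = type(message).__name__
        handler = self.handlers.get(type(message))
        if handler is None:
            # No handler registered for this type: the message is ignored.
            return f"{self.name}: unhandled {kind}"
        handler(message)
        return f"{self.name}: handled {kind}"


class ToyBus:
    def __init__(self) -> None:
        self.subscribers: List[ToyAgent] = []

    def subscribe(self, agent: ToyAgent) -> None:
        self.subscribers.append(agent)

    def publish(self, message: object, sender: Optional[ToyAgent] = None) -> List[str]:
        # Deliver to every subscriber except the sender itself.
        return [a.on_message(message) for a in self.subscribers if a is not sender]


received: List[object] = []
reviewer = ToyAgent("reviewer", {CodeReviewTask: received.append})
coder = ToyAgent("coder", {CodeWritingTask: received.append})
bus = ToyBus()
bus.subscribe(reviewer)
bus.subscribe(coder)

first = bus.publish(CodeWritingTask(task="demo"))
second = bus.publish(CodeReviewTask(code="pass"), sender=coder)
print(first)
print(second)
```

The first publish reaches both agents, and the reviewer reports it as unhandled, just as ReviewerAgent does for CodeWritingTask in the log. The second reaches only the reviewer, because the coder sent it.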