{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Reflection\n", "\n", "Reflection is a design pattern in which an LLM generation is followed by a reflection, which is itself another LLM generation conditioned on the output of the first one. For example, given a task to write code, the first LLM can generate a code snippet, and the second LLM can generate a critique of that code snippet.\n", "\n", "In the context of AutoGen and agents, reflection can be implemented as a pair of agents, where the first agent generates a message and the second agent generates a response to that message. The two agents keep interacting until they reach a stopping condition, such as a maximum number of iterations or an approval from the second agent.\n", "\n", "Let's implement a simple reflection design pattern using AutoGen agents. There will be two agents: a coder agent and a reviewer agent. The coder agent generates a code snippet, and the reviewer agent generates a critique of it.\n", "\n", "## Message Protocol\n", "\n", "Before we define the agents, we need to define the message protocol they use to communicate." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "\n", "\n", "@dataclass\n", "class CodeWritingTask:\n", "    task: str\n", "\n", "\n", "@dataclass\n", "class CodeWritingResult:\n", "    task: str\n", "    code: str\n", "    review: str\n", "\n", "\n", "@dataclass\n", "class CodeReviewTask:\n", "    session_id: str\n", "    code_writing_task: str\n", "    code_writing_scratchpad: str\n", "    code: str\n", "\n", "\n", "@dataclass\n", "class CodeReviewResult:\n", "    review: str\n", "    session_id: str\n", "    approved: bool" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The set of messages above defines the protocol for our example reflection design pattern:\n", "- The application sends a `CodeWritingTask` message to the coder agent.\n", "- The coder agent generates a `CodeReviewTask` message, which is sent to the reviewer agent.\n", "- The reviewer agent generates a `CodeReviewResult` message, which is sent back to the coder agent.\n", "- Depending on the `CodeReviewResult` message: if the code is approved, the coder agent sends a `CodeWritingResult` message back to the application; otherwise, the coder agent revises the code and sends another `CodeReviewTask` message to the reviewer agent, and the process continues.\n", "\n", "We can visualize the message protocol using a data flow diagram:\n", "\n", "![coder-reviewer data flow](coder-reviewer-data-flow.svg)\n", "\n", "## Agents\n", "\n", "Now, let's define the agents for the reflection design pattern." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import re\n", "import uuid\n", "from typing import Dict, List, Union\n", "\n", "from autogen_core.base import MessageContext, TopicId\n", "from autogen_core.components import RoutedAgent, default_subscription, message_handler\n", "from autogen_core.components.models import (\n", "    AssistantMessage,\n", "    ChatCompletionClient,\n", "    LLMMessage,\n", "    SystemMessage,\n", "    UserMessage,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We use the [Broadcast](../framework/message-and-communication.ipynb#broadcast) API to implement the design pattern. The agents implement the pub/sub model. The coder agent subscribes to the `CodeWritingTask` and `CodeReviewResult` messages, and publishes the `CodeReviewTask` and `CodeWritingResult` messages." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@default_subscription\n", "class CoderAgent(RoutedAgent):\n", "    \"\"\"An agent that performs code writing tasks.\"\"\"\n", "\n", "    def __init__(self, model_client: ChatCompletionClient) -> None:\n", "        super().__init__(\"A code writing agent.\")\n", "        self._system_messages: List[LLMMessage] = [\n", "            SystemMessage(\n", "                content=\"\"\"You are a proficient coder. You write code to solve problems.\n", "Work with the reviewer to improve your code.\n", "Always put all finished code in a single Markdown code block.\n", "For example:\n", "```python\n", "def hello_world():\n", "    print(\"Hello, World!\")\n", "```\n", "\n", "Respond using the following format:\n", "\n", "Thoughts: <Your comments>\n", "Code: <Your code>\n", "\"\"\",\n", "            )\n", "        ]\n", "        self._model_client = model_client\n", "        self._session_memory: Dict[str, List[CodeWritingTask | CodeReviewTask | CodeReviewResult]] = {}\n", "\n", "    @message_handler\n", "    async def handle_code_writing_task(self, message: CodeWritingTask, ctx: MessageContext) -> None:\n", "        # Store the messages in a temporary memory for this request only.\n", "        session_id = str(uuid.uuid4())\n", "        self._session_memory.setdefault(session_id, []).append(message)\n", "        # Generate a response using the chat completion API.\n", "        response = await self._model_client.create(\n", "            self._system_messages + [UserMessage(content=message.task, source=self.metadata[\"type\"])],\n", "            cancellation_token=ctx.cancellation_token,\n", "        )\n", "        assert isinstance(response.content, str)\n", "        # Extract the code block from the response.\n", "        code_block = self._extract_code_block(response.content)\n", "        if code_block is None:\n", "            raise ValueError(\"Code block not found.\")\n", "        # Create a code review task.\n", "        code_review_task = CodeReviewTask(\n", "            session_id=session_id,\n", "            code_writing_task=message.task,\n", "            code_writing_scratchpad=response.content,\n", "            code=code_block,\n", "        )\n", "        # Store the code review task in the session memory.\n", "        self._session_memory[session_id].append(code_review_task)\n", "        # Publish a code review task.\n", "        await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", "    @message_handler\n", "    async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n", "        # Store the review result in the session memory.\n", "        self._session_memory[message.session_id].append(message)\n", "        # Obtain the request from previous messages.\n", "        review_request = next(\n", "            m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewTask)\n", "        )\n", "        assert review_request is not None\n", "        # Check if the code is approved.\n", "        if message.approved:\n", "            # Publish the code writing result.\n", "            await self.publish_message(\n", "                CodeWritingResult(\n", "                    code=review_request.code,\n", "                    task=review_request.code_writing_task,\n", "                    review=message.review,\n", "                ),\n", "                topic_id=TopicId(\"default\", self.id.key),\n", "            )\n", "            print(\"Code Writing Result:\")\n", "            print(\"-\" * 80)\n", "            print(f\"Task:\\n{review_request.code_writing_task}\")\n", "            print(\"-\" * 80)\n", "            print(f\"Code:\\n{review_request.code}\")\n", "            print(\"-\" * 80)\n", "            print(f\"Review:\\n{message.review}\")\n", "            print(\"-\" * 80)\n", "        else:\n", "            # Create a list of LLM messages to send to the model.\n", "            messages: List[LLMMessage] = [*self._system_messages]\n", "            for m in self._session_memory[message.session_id]:\n", "                if isinstance(m, CodeReviewResult):\n", "                    messages.append(UserMessage(content=m.review, source=\"Reviewer\"))\n", "                elif isinstance(m, CodeReviewTask):\n", "                    messages.append(AssistantMessage(content=m.code_writing_scratchpad, source=\"Coder\"))\n", "                elif isinstance(m, CodeWritingTask):\n", "                    messages.append(UserMessage(content=m.task, source=\"User\"))\n", "                else:\n", "                    raise ValueError(f\"Unexpected message type: {m}\")\n", "            # Generate a revision using the chat completion API.\n", "            response = await self._model_client.create(messages, cancellation_token=ctx.cancellation_token)\n", "            assert isinstance(response.content, str)\n", "            # Extract the code block from the response.\n", "            code_block = self._extract_code_block(response.content)\n", "            if code_block is None:\n", "                raise ValueError(\"Code block not found.\")\n", "            # Create a new code review task.\n", "            code_review_task = CodeReviewTask(\n", "                session_id=message.session_id,\n", "                code_writing_task=review_request.code_writing_task,\n", "                code_writing_scratchpad=response.content,\n", "                code=code_block,\n", "            )\n", "            # Store the code review task in the session memory.\n", "            self._session_memory[message.session_id].append(code_review_task)\n", "            # Publish a new code review task.\n", "            await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", "    def _extract_code_block(self, markdown_text: str) -> Union[str, None]:\n", "        pattern = r\"```(\\w+)\\n(.*?)\\n```\"\n", "        # Search for the pattern in the markdown text\n", "        match = re.search(pattern, markdown_text, re.DOTALL)\n", "        # Return the code (group 2) if a match is found; group 1 is the language tag\n", "        if match:\n", "            return match.group(2)\n", "        return None" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "A few things to note about `CoderAgent`:\n", "- It uses chain-of-thought prompting in its system message.\n", "- It stores the message history for each `CodeWritingTask` in a dictionary keyed by session ID, so each task has its own history.\n", "- When making an LLM inference request with its model client, it converts the message history into a list of {py:class}`autogen_core.components.models.LLMMessage` objects to pass to the model client.\n", "\n", "The reviewer agent subscribes to the `CodeReviewTask` message and publishes the `CodeReviewResult` message." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"@default_subscription\n", "class ReviewerAgent(RoutedAgent):\n", "    \"\"\"An agent that performs code review tasks.\"\"\"\n", "\n", "    def __init__(self, model_client: ChatCompletionClient) -> None:\n", "        super().__init__(\"A code reviewer agent.\")\n", "        self._system_messages: List[LLMMessage] = [\n", "            SystemMessage(\n", "                content=\"\"\"You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n", "Respond using the following JSON format:\n", "{\n", "    \"correctness\": \"<Your comments>\",\n", "    \"efficiency\": \"<Your comments>\",\n", "    \"safety\": \"<Your comments>\",\n", "    \"approval\": \"<APPROVE or REVISE>\",\n", "    \"suggested_changes\": \"<Your comments>\"\n", "}\n", "\"\"\",\n", "            )\n", "        ]\n", "        self._session_memory: Dict[str, List[CodeReviewTask | CodeReviewResult]] = {}\n", "        self._model_client = model_client\n", "\n", "    @message_handler\n", "    async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None:\n", "        # Format the prompt for the code review.\n", "        # Gather the previous feedback if available.\n", "        previous_feedback = \"\"\n", "        if message.session_id in self._session_memory:\n", "            previous_review = next(\n", "                (m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewResult)),\n", "                None,\n", "            )\n", "            if previous_review is not None:\n", "                previous_feedback = previous_review.review\n", "        # Store the messages in a temporary memory for this request only.\n", "        self._session_memory.setdefault(message.session_id, []).append(message)\n", "        prompt = f\"\"\"The problem statement is: {message.code_writing_task}\n", "The code is:\n", "```\n", "{message.code}\n", "```\n", "\n", "Previous feedback:\n", "{previous_feedback}\n", "\n", "Please review the code. If previous feedback was provided, see if it was addressed.\n", "\"\"\"\n", "        # Generate a response using the chat completion API.\n", "        response = await self._model_client.create(\n", "            self._system_messages + [UserMessage(content=prompt, source=self.metadata[\"type\"])],\n", "            cancellation_token=ctx.cancellation_token,\n", "            json_output=True,\n", "        )\n", "        assert isinstance(response.content, str)\n", "        # TODO: use structured generation library e.g. guidance to ensure the response is in the expected format.\n", "        # Parse the response JSON.\n", "        review = json.loads(response.content)\n", "        # Construct the review text.\n", "        review_text = \"Code review:\\n\" + \"\\n\".join([f\"{k}: {v}\" for k, v in review.items()])\n", "        approved = review[\"approval\"].lower().strip() == \"approve\"\n", "        result = CodeReviewResult(\n", "            review=review_text,\n", "            session_id=message.session_id,\n", "            approved=approved,\n", "        )\n", "        # Store the review result in the session memory.\n", "        self._session_memory[message.session_id].append(result)\n", "        # Publish the review result.\n", "        await self.publish_message(result, topic_id=TopicId(\"default\", self.id.key))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The `ReviewerAgent` uses JSON mode when making LLM inference requests, and also uses chain-of-thought prompting in its system message.\n", "\n", "## Logging\n", "\n", "Turn on logging to see the messages exchanged between the agents." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import logging\n", "\n", "logging.basicConfig(level=logging.WARNING)\n", "logging.getLogger(\"autogen_core\").setLevel(logging.DEBUG)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Running the Design Pattern\n", "\n", "Let's test the design pattern with a coding task.\n", "Since all the agents are decorated with the {py:meth}`~autogen_core.components.default_subscription` class decorator, they are automatically subscribed to the default topic when they are created.\n", "We publish a `CodeWritingTask` message to the default topic to start the reflection process." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:autogen_core:Publishing message of type CodeWritingTask
to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingTask published by Unknown\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeWritingTask published by Unknown\n", "INFO:autogen_core:Unhandled message: CodeWritingTask(task='Write a function to find the sum of all even numbers in a list.')\n", "INFO:autogen_core.events:{\"prompt_tokens\": 101, \"completion_tokens\": 88, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': 'Thoughts: To find the sum of all even numbers in a list, we can use a list comprehension to filter out the even numbers and then use the `sum()` function to calculate their total. The implementation should handle edge cases like an empty list or a list with no even numbers.\\n\\nCode:\\n```python\\ndef sum_of_even_numbers(numbers):\\n return sum(num for num in numbers if num % 2 == 0)\\n```', 'code': 'def sum_of_even_numbers(numbers):\\n return sum(num for num in numbers if num % 2 == 0)'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 163, \"completion_tokens\": 235, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': \"Code review:\\ncorrectness: The function correctly identifies and sums all even numbers in the provided list. 
The use of a generator expression ensures that only even numbers are processed, which is correct.\\nefficiency: The function is efficient as it utilizes a generator expression that avoids creating an intermediate list, therefore using less memory. The time complexity is O(n) where n is the number of elements in the input list, which is optimal for this task.\\nsafety: The function does not include checks for input types. If a non-iterable or a list containing non-integer types is passed, it could lead to unexpected behavior or errors. It’s advisable to handle such cases.\\napproval: REVISE\\nsuggested_changes: Consider adding input validation to ensure that 'numbers' is a list and contains only integers. You could raise a ValueError if the input is invalid. Example: 'if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers): raise ValueError('Input must be a list of integers')'. This will make the function more robust.\", 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': False}\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 421, \"completion_tokens\": 119, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': \"Thoughts: I appreciate the reviewer's feedback on input validation. Adding type checks ensures that the function can handle unexpected inputs gracefully. 
I will implement the suggested changes and include checks for both the input type and the elements within the list to confirm that they are integers.\\n\\nCode:\\n```python\\ndef sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\\n```\", 'code': \"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\"}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 420, \"completion_tokens\": 153, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': 'Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.', 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': True}\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default\n", "INFO:autogen_core:Publishing message of type CodeWritingResult to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.', 'code': \"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\", 'review': 'Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingResult published by CoderAgent:default\n", "INFO:autogen_core:Unhandled message: CodeWritingResult(task='Write a function to find the sum of all even numbers in a list.', code=\"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\", review='Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Code Writing Result:\n", "--------------------------------------------------------------------------------\n", "Task:\n", "Write a function to find the sum of all even numbers in a list.\n", "--------------------------------------------------------------------------------\n", "Code:\n", "def sum_of_even_numbers(numbers):\n", " if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n", " raise ValueError('Input must be a list of integers')\n", " \n", " return sum(num for num in numbers if num % 2 == 0)\n", "--------------------------------------------------------------------------------\n", "Review:\n", "Code review:\n", "correctness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\n", "efficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\n", "safety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\n", "approval: APPROVE\n", "suggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.\n", "--------------------------------------------------------------------------------\n" ] } ], "source": [ "from autogen_core.application import SingleThreadedAgentRuntime\n", "from autogen_core.components import DefaultTopicId\n", "from autogen_ext.models import OpenAIChatCompletionClient\n", "\n", "runtime = SingleThreadedAgentRuntime()\n", "await ReviewerAgent.register(\n", "    runtime, \"ReviewerAgent\", lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"))\n", ")\n", "await CoderAgent.register(\n", "    runtime, \"CoderAgent\", lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"))\n", ")\n", "runtime.start()\n", "await runtime.publish_message(\n", "    message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n", "    topic_id=DefaultTopicId(),\n", ")\n", "\n", "# Keep processing messages until idle.\n", "await runtime.stop_when_idle()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The log messages show the interaction between the coder agent and the reviewer agent.\n", "The final output shows the code snippet generated by the coder agent and the critique generated by the reviewer agent." ] } ],
"metadata": { "kernelspec": { "display_name": "agnext", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 }