mirror of
https://github.com/HKUDS/AutoAgent
synced 2026-04-21 15:47:56 +00:00
673 lines
30 KiB
Python
673 lines
30 KiB
Python
# Standard library imports
|
|
import copy
|
|
import json
|
|
from collections import defaultdict
|
|
from typing import List, Callable, Union
|
|
from datetime import datetime
|
|
# Remaining imports: stdlib, package-local modules, and third-party libraries
|
|
import os
|
|
from .util import function_to_json, debug_print, merge_chunk, pretty_print_messages
|
|
from .types import (
|
|
Agent,
|
|
AgentFunction,
|
|
Message,
|
|
ChatCompletionMessageToolCall,
|
|
Function,
|
|
Response,
|
|
Result,
|
|
)
|
|
from litellm import completion, acompletion
|
|
from pathlib import Path
|
|
from .logger import MetaChainLogger, LoggerManager
|
|
from httpx import RemoteProtocolError, ConnectError
|
|
from litellm.exceptions import APIError
|
|
from tenacity import (
|
|
retry,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
retry_if_exception_type
|
|
)
|
|
from openai import AsyncOpenAI
|
|
import litellm
|
|
import inspect
|
|
from constant import MC_MODE, FN_CALL, API_BASE_URL, NOT_SUPPORT_SENDER, ADD_USER, NON_FN_CALL
|
|
from autoagent.fn_call_converter import convert_tools_to_description, convert_non_fncall_messages_to_fncall_messages, SYSTEM_PROMPT_SUFFIX_TEMPLATE, convert_fn_messages_to_non_fn_messages, interleave_user_into_messages
|
|
from litellm.types.utils import Message as litellmMessage
|
|
# litellm.set_verbose=True
|
|
# client = AsyncOpenAI()
|
|
def should_retry_error(exception):
    """Decide whether a failed completion call should be retried.

    This function is used as the ``retry=`` predicate of a tenacity ``@retry``
    decorator. tenacity calls such predicates with a ``RetryCallState`` rather
    than with the raised exception, so the state is unwrapped first: a call
    that has not failed is never retried, and for a failed call the decision
    is made on the real exception object. A bare exception is still accepted,
    so direct callers keep working.

    Args:
        exception: Either an exception instance or a tenacity retry state
            (anything exposing ``outcome`` and ``attempt_number``).

    Returns:
        True when the error is one of the known transient error types or its
        message contains one of the known transient-error markers.
    """
    # Unwrap tenacity's retry state. Matching against the state's repr would
    # also match the text of a *successful* result and retry it by mistake.
    if hasattr(exception, "outcome") and hasattr(exception, "attempt_number"):
        outcome = exception.outcome
        if outcome is None or not outcome.failed:
            return False
        exception = outcome.exception()

    if MC_MODE is False:
        print(f"Caught exception: {type(exception).__name__} - {str(exception)}")

    # Match on the known transient exception types first.
    if isinstance(exception, (APIError, RemoteProtocolError, ConnectError)):
        return True

    # Otherwise fall back to matching on the error message.
    error_msg = str(exception).lower()
    retryable_markers = (
        "connection error",
        "server disconnected",
        "eof occurred",
        "timeout",
        "event loop is closed",  # event-loop errors
        "anthropicexception",  # Anthropic-related errors
    )
    return any(marker in error_msg for marker in retryable_markers)
|
|
# Reserved tool-parameter name: it is removed from the tool schemas sent to the
# model and filled in with the live context variables when a tool is invoked.
__CTX_VARS_NAME__ = "context_variables"
|
|
# Module-wide logger; MetaChain.__init__ uses it whenever it is truthy.
logger = LoggerManager.get_logger()
|
|
|
|
def adapt_tools_for_gemini(tools):
    """Adapt tool definitions for Gemini models.

    Ensures that every schema node of type ``"object"`` inside a tool's
    ``parameters`` has a non-empty ``properties`` mapping, inserting a
    placeholder ``dummy`` string property where it is missing or empty. The
    schema is walked recursively through ``properties`` values and array
    ``items``, so objects at any nesting depth are covered.

    Args:
        tools: List of tool definitions (each with a ``"function"`` entry), or
            ``None``.

    Returns:
        A new list of deep-copied, adapted tools; ``None`` if ``tools`` is
        ``None``. The input is never modified.
    """
    if tools is None:
        return None

    def _fill_empty_objects(schema):
        """Give ``schema`` and all nested object schemas non-empty properties."""
        if not isinstance(schema, dict):
            return
        if schema.get("type") == "object" and not schema.get("properties"):
            schema["properties"] = {
                "dummy": {
                    "type": "string",
                    "description": "Dummy property for Gemini compatibility",
                }
            }
        properties = schema.get("properties")
        if isinstance(properties, dict):
            for child in properties.values():
                _fill_empty_objects(child)
        # Arrays of objects carry their element schema under "items".
        _fill_empty_objects(schema.get("items"))

    adapted_tools = []
    for tool in tools:
        # Work on a copy so the caller's tool definitions stay untouched.
        adapted_tool = copy.deepcopy(tool)
        if "parameters" in adapted_tool["function"]:
            _fill_empty_objects(adapted_tool["function"]["parameters"])
        adapted_tools.append(adapted_tool)
    return adapted_tools
|
|
|
|
class MetaChain:
|
|
def __init__(self, log_path: Union[str, None, MetaChainLogger] = None):
    """Pick the logger this instance writes to.

    Args:
        log_path: Either a ready ``MetaChainLogger`` to reuse, or a value
            forwarded to ``MetaChainLogger(log_path=...)``. It is only
            consulted when the module-level ``logger`` is falsy.
    """
    # The module-level logger wins over anything passed in.
    if logger:
        self.logger = logger
        return
    # A logger instance passed by the caller is used as-is.
    if isinstance(log_path, MetaChainLogger):
        self.logger = log_path
        return
    self.logger = MetaChainLogger(log_path=log_path)
|
|
# if self.logger.log_path is None: self.logger.info("[Warning] Not specific log path, so log will not be saved", "...", title="Log Path", color="light_cyan3")
|
|
# else: self.logger.info("Log file is saved to", self.logger.log_path, "...", title="Log Path", color="light_cyan3")
|
|
# @retry(
|
|
# stop=stop_after_attempt(4),
|
|
# wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
# retry=should_retry_error,
|
|
# before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
|
|
# )
|
|
def get_chat_completion(
    self,
    agent: Agent,
    history: List,
    context_variables: dict,
    model_override: str,
    stream: bool,
    debug: bool,
) -> Message:
    """Build the request for ``agent`` and call ``litellm.completion``.

    The prompt is the agent's instructions (as a system message), followed by
    the agent's examples (if any) and ``history``. Tool schemas are derived
    from ``agent.functions`` with the reserved context-variables parameter
    hidden from the model.

    Two modes exist, selected by the module constant ``FN_CALL``:

    * function-calling mode sends ``tools``/``tool_choice`` to the model;
    * otherwise the tool descriptions are appended as text to the last
      message, and the plain-text reply is converted back into a message
      that may carry ``tool_calls``.

    The dict messages in ``history`` are never modified: ``sender`` stripping
    and the tool-description suffix are applied to copies, so the caller's
    history does not lose ``sender`` tags or accumulate one suffix per turn.

    Args:
        agent: Agent providing instructions, examples, functions, model and
            tool settings.
        history: Conversation so far (list of message dicts).
        context_variables: Values handed to callable instructions/examples;
            missing keys read as ``""``.
        model_override: Model name used instead of ``agent.model`` when truthy.
        stream: Forwarded to litellm as ``stream``.
        debug: Unused here; kept for signature compatibility.

    Returns:
        The litellm completion response; callers read
        ``response.choices[0].message``.

    Raises:
        AssertionError: If the model does not support function calling in
            function-calling mode, or ``agent.tool_choice`` is not
            ``"required"`` in non-function-calling mode.
    """

    def _without_sender(message):
        # Dict messages are copied so the caller's history keeps its tags;
        # any other message object is handled in place, as before.
        if isinstance(message, dict):
            return {key: value for key, value in message.items() if key != "sender"}
        if "sender" in message:
            del message["sender"]
        return message

    context_variables = defaultdict(str, context_variables)
    instructions = (
        agent.instructions(context_variables)
        if callable(agent.instructions)
        else agent.instructions
    )
    if agent.examples:
        examples = agent.examples(context_variables) if callable(agent.examples) else agent.examples
        history = examples + history

    # A fresh list: replacing entries below does not touch ``history``.
    messages = [{"role": "system", "content": instructions}] + history

    tools = [function_to_json(f) for f in agent.functions]
    # Hide context_variables from the model.
    for tool in tools:
        params = tool["function"]["parameters"]
        params["properties"].pop(__CTX_VARS_NAME__, None)
        if __CTX_VARS_NAME__ in params["required"]:
            params["required"].remove(__CTX_VARS_NAME__)
    create_model = model_override or agent.model

    if "gemini" in create_model.lower():
        tools = adapt_tools_for_gemini(tools)

    # Some models reject the extra ``sender`` key on messages.
    no_sender_mode = any(marker in create_model for marker in NOT_SUPPORT_SENDER)

    if FN_CALL:
        assert litellm.supports_function_calling(model = create_model) == True, f"Model {create_model} does not support function calling, please set `FN_CALL=False` to use non-function calling mode"
        if no_sender_mode:
            messages = [_without_sender(message) for message in messages]
        create_params = {
            "model": create_model,
            "messages": messages,
            "tools": tools or None,
            "tool_choice": agent.tool_choice,
            "stream": stream,
        }
        if tools and create_params['model'].startswith("gpt"):
            create_params["parallel_tool_calls"] = agent.parallel_tool_calls
        completion_response = completion(**create_params)
    else:
        assert agent.tool_choice == "required", f"Non-function calling mode MUST use tool_choice = 'required' rather than {agent.tool_choice}"
        tools_description = convert_tools_to_description(tools)
        tool_prompt = "\n[IMPORTANT] You MUST use the tools provided to complete the task.\n" + SYSTEM_PROMPT_SUFFIX_TEMPLATE.format(description=tools_description)

        # Attach the tool prompt to a copy of the last message only.
        last_message = copy.copy(messages[-1])
        last_content = last_message["content"]
        if isinstance(last_content, list):
            # Multimodal content is a list of parts: add the prompt as text.
            last_message["content"] = last_content + [{"type": "text", "text": tool_prompt}]
        else:
            last_message["content"] = (last_content or "") + tool_prompt
        messages[-1] = last_message

        if no_sender_mode:
            messages = [_without_sender(message) for message in messages]
        if NON_FN_CALL:
            messages = convert_fn_messages_to_non_fn_messages(messages)
        if ADD_USER and messages[-1]["role"] != "user":
            messages = interleave_user_into_messages(messages)
        create_params = {
            "model": create_model,
            "messages": messages,
            "stream": stream,
            "base_url": API_BASE_URL,
        }
        completion_response = completion(**create_params)

        # Turn the plain-text reply back into a message with tool calls.
        last_message = [{"role": "assistant", "content": completion_response.choices[0].message.content}]
        converted_message = convert_non_fncall_messages_to_fncall_messages(last_message, tools)
        if "tool_calls" in converted_message[0]:
            converted_tool_calls = [ChatCompletionMessageToolCall(**tool_call) for tool_call in converted_message[0]["tool_calls"]]
        else:
            converted_tool_calls = None
        completion_response.choices[0].message = litellmMessage(content = converted_message[0]["content"], role = "assistant", tool_calls = converted_tool_calls)

    return completion_response
|
|
|
|
def handle_function_result(self, result, debug) -> Result:
    """Normalize whatever a tool returned into a ``Result``.

    Args:
        result: Raw return value of a tool function.
        debug: Unused here; kept for signature compatibility.

    Returns:
        ``result`` itself when it already is a ``Result``; a ``Result``
        carrying the agent and a JSON ``{"assistant": <agent name>}`` value
        when it is an ``Agent``; otherwise a ``Result`` whose value is
        ``str(result)``.

    Raises:
        TypeError: If the value cannot be converted to a string ``Result``.
    """
    if isinstance(result, Result):
        return result

    if isinstance(result, Agent):
        # Returning an agent from a tool is how a hand-off is expressed.
        handoff_value = json.dumps({"assistant": result.name})
        return Result(value=handoff_value, agent=result)

    try:
        return Result(value=str(result))
    except Exception as e:
        error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
        self.logger.info(error_message, title="Handle Function Result Error", color="red")
        raise TypeError(error_message)
|
|
|
|
def handle_tool_calls(
    self,
    tool_calls: List[ChatCompletionMessageToolCall],
    functions: List[AgentFunction],
    context_variables: dict,
    debug: bool,
    handle_mm_func: Callable[[str, str], str] = None,
) -> Response:
    """Execute the requested tool calls and collect their results.

    For every tool call one ``tool`` message is appended to the response.
    Calls that cannot be run (unknown tool name, or arguments that are not a
    valid JSON object) produce an error ``tool`` message and are skipped, so
    one bad call does not abort the whole run. Empty or missing arguments
    are treated as ``{}``.

    Args:
        tool_calls: Tool calls requested by the model; ``None`` is treated as
            no calls.
        functions: Callables available to the active agent, looked up by
            ``__name__``.
        context_variables: Passed to any function whose signature has a
            parameter named by ``__CTX_VARS_NAME__``.
        debug: Forwarded to ``handle_function_result``.
        handle_mm_func: Called with ``(tool_name, raw_arguments)`` to build
            the text that accompanies an image returned by a tool.

    Returns:
        A ``Response`` with the generated messages, the context-variable
        updates from the results, and the agent to switch to (if any result
        carried one).

    Raises:
        AssertionError: If a tool returns an image and ``handle_mm_func`` is
            not provided.
    """
    function_map = {f.__name__: f for f in functions}
    partial_response = Response(
        messages=[], agent=None, context_variables={})

    for tool_call in tool_calls or []:
        name = tool_call.function.name
        # Handle the missing-tool case: report it and skip to the next call.
        if name not in function_map:
            self.logger.info(f"Tool {name} not found in function map. You are recommended to use `run_tool` to run this tool.", title="Tool Call Error", color="red")
            partial_response.messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": name,
                    "content": f"Error: Tool {name} not found. You are recommended to use `run_tool` to run this tool.",
                }
            )
            continue

        # Model-produced arguments may be empty, malformed, or not an object.
        try:
            args = json.loads(tool_call.function.arguments or "{}")
        except (TypeError, ValueError) as exc:
            args, args_error = None, str(exc)
        else:
            args_error = None if isinstance(args, dict) else "arguments must be a JSON object"
        if args_error is not None:
            self.logger.info(f"Invalid arguments for tool {name}: {args_error}", title="Tool Call Error", color="red")
            partial_response.messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": name,
                    "content": f"Error: Invalid arguments for tool {name}: {args_error}",
                }
            )
            continue

        func = function_map[name]
        # Pass context_variables to agent functions that declare the parameter.
        if __CTX_VARS_NAME__ in inspect.signature(func).parameters:
            args[__CTX_VARS_NAME__] = context_variables
        raw_result = func(**args)

        result: Result = self.handle_function_result(raw_result, debug)

        partial_response.messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": name,
                "content": result.value,
            }
        )
        self.logger.pretty_print_messages(partial_response.messages[-1])
        if result.image:
            assert handle_mm_func, f"handle_mm_func is not provided, but an image is returned by tool call {name}({tool_call.function.arguments})"
            # The image is delivered in a follow-up user message.
            partial_response.messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type":"text", "text":handle_mm_func(name, tool_call.function.arguments)},
                        {
                            "type":"image_url",
                            "image_url":{
                                "url":f"data:image/png;base64,{result.image}"
                            }
                        }
                    ]
                }
            )

        partial_response.context_variables.update(result.context_variables)
        if result.agent:
            partial_response.agent = result.agent

    return partial_response
|
|
|
|
def run_and_stream(
    self,
    agent: Agent,
    messages: List,
    context_variables: dict = {},
    model_override: str = None,
    debug: bool = False,
    max_turns: int = float("inf"),
    execute_tools: bool = True,
):
    """Run the agent loop, yielding streamed deltas as they arrive.

    For every model turn the generator yields ``{"delim": "start"}``, each
    streamed delta dict (tagged with the active agent's name as ``sender``
    when its role is ``"assistant"``), then ``{"delim": "end"}``. After the
    loop finishes it yields ``{"response": Response(...)}`` holding the
    messages produced during this run.

    Args:
        agent: Agent that starts the conversation.
        messages: Prior conversation; deep-copied, never modified.
        context_variables: Initial context variables; deep-copied, so the
            shared default dict is never mutated.
        model_override: Model name used instead of the agent's own model.
        debug: Enables ``debug_print`` output.
        max_turns: Upper bound on the number of messages added to the history.
        execute_tools: When False, stop after the first model reply instead
            of executing requested tools.
    """
    active_agent = agent
    context_variables = copy.deepcopy(context_variables)
    history = copy.deepcopy(messages)
    init_len = len(messages)

    while len(history) - init_len < max_turns:

        # Accumulator the streamed chunks are merged into. It is tagged with
        # the agent that actually produces this turn, which differs from
        # ``agent`` after a hand-off.
        message = {
            "content": "",
            "sender": active_agent.name,
            "role": "assistant",
            "function_call": None,
            "tool_calls": defaultdict(
                lambda: {
                    "function": {"arguments": "", "name": ""},
                    "id": "",
                    "type": "",
                }
            ),
        }

        # Get a completion for the current history and agent.
        completion_stream = self.get_chat_completion(
            agent=active_agent,
            history=history,
            context_variables=context_variables,
            model_override=model_override,
            stream=True,
            debug=debug,
        )

        yield {"delim": "start"}
        for chunk in completion_stream:
            delta = json.loads(chunk.choices[0].delta.json())
            if delta.get("role") == "assistant":
                delta["sender"] = active_agent.name
            yield delta
            # role/sender are already set on the accumulator.
            delta.pop("role", None)
            delta.pop("sender", None)
            merge_chunk(message, delta)
        yield {"delim": "end"}

        message["tool_calls"] = list(
            message.get("tool_calls", {}).values())
        if not message["tool_calls"]:
            message["tool_calls"] = None
        debug_print(debug, "Received completion:", message)
        history.append(message)

        if not message["tool_calls"] or not execute_tools:
            debug_print(debug, "Ending turn.")
            break

        # Convert the accumulated tool-call dicts to objects.
        tool_calls = []
        for tool_call in message["tool_calls"]:
            function = Function(
                arguments=tool_call["function"]["arguments"],
                name=tool_call["function"]["name"],
            )
            tool_call_object = ChatCompletionMessageToolCall(
                id=tool_call["id"], function=function, type=tool_call["type"]
            )
            tool_calls.append(tool_call_object)

        # Handle function calls, update context_variables, switch agents.
        # handle_mm_func is forwarded as in ``run`` so that tools returning
        # images work in streaming mode too.
        partial_response = self.handle_tool_calls(
            tool_calls, active_agent.functions, context_variables, debug, handle_mm_func=active_agent.handle_mm_func
        )
        history.extend(partial_response.messages)
        context_variables.update(partial_response.context_variables)
        if partial_response.agent:
            active_agent = partial_response.agent

    yield {
        "response": Response(
            messages=history[init_len:],
            agent=active_agent,
            context_variables=context_variables,
        )
    }
|
|
|
|
def run(
    self,
    agent: Agent,
    messages: List,
    context_variables: dict = {},
    model_override: str = None,
    stream: bool = False,
    debug: bool = True,
    max_turns: int = float("inf"),
    execute_tools: bool = True,
) -> Response:
    """Run the agent loop synchronously until a stop condition is met.

    Each turn asks the active agent for a completion, records the reply in
    the history, then either stops or executes the requested tools (which may
    switch the active agent).

    Stop conditions:

    * ``execute_tools`` is False: stop right after the first reply, without
      running any tool.
    * Entry agent's ``tool_choice`` is not ``"required"``: stop when the
      entry agent replies without tool calls.
    * Entry agent's ``tool_choice`` is ``"required"``: stop when the reply
      has no tool calls, or when its first tool call is ``case_resolved`` /
      ``case_not_resolved`` (those calls are executed before stopping).

    Args:
        agent: Agent that starts the conversation.
        messages: Prior conversation; deep-copied, never modified. The last
            message's ``content`` is logged as the task.
        context_variables: Initial context variables; shallow-copied.
        model_override: Model name used instead of the agent's own model.
        stream: When True, delegates to ``run_and_stream`` and returns its
            generator instead of a ``Response``.
        debug: Forwarded to the completion and tool-handling calls.
        max_turns: Upper bound on the number of messages added to the history.
        execute_tools: When False, requested tools are not executed.

    Returns:
        A ``Response`` with the messages produced during this run, the last
        active agent and the final context variables.
    """
    if stream:
        return self.run_and_stream(
            agent=agent,
            messages=messages,
            context_variables=context_variables,
            model_override=model_override,
            debug=debug,
            max_turns=max_turns,
            execute_tools=execute_tools,
        )
    active_agent = agent
    enter_agent = agent
    context_variables = copy.copy(context_variables)
    history = copy.deepcopy(messages)
    init_len = len(messages)

    self.logger.info("Receiveing the task:", history[-1]['content'], title="Receive Task", color="green")

    # Tool names that finish the run, mapped to the log line they produce.
    terminal_tools = {
        "case_resolved": "Ending turn with case resolved.",
        "case_not_resolved": "Ending turn with case not resolved.",
    }

    while len(history) - init_len < max_turns and active_agent:

        # Get a completion for the current history and agent.
        completion_response = self.get_chat_completion(
            agent=active_agent,
            history=history,
            context_variables=context_variables,
            model_override=model_override,
            stream=stream,
            debug=debug,
        )
        message: Message = completion_response.choices[0].message
        message.sender = active_agent.name
        self.logger.pretty_print_messages(message)
        history.append(
            json.loads(message.model_dump_json())
        )  # store a plain dict rather than the library's message object

        # With tool execution disabled, the reply ends the run and no tool
        # may be invoked.
        if not execute_tools:
            self.logger.info("Ending turn.", title="End Turn", color="red")
            break

        if enter_agent.tool_choice != "required":
            if not message.tool_calls and active_agent.name == enter_agent.name:
                self.logger.info("Ending turn.", title="End Turn", color="red")
                break
        elif not message.tool_calls:
            self.logger.info("Ending turn with no tool calls.", title="End Turn", color="red")
            break
        else:
            end_note = terminal_tools.get(message.tool_calls[0].function.name)
            if end_note is not None:
                self.logger.info(end_note, title="End Turn", color="red")
                partial_response = self.handle_tool_calls(
                    message.tool_calls, active_agent.functions, context_variables, debug, handle_mm_func=active_agent.handle_mm_func
                )
                history.extend(partial_response.messages)
                context_variables.update(partial_response.context_variables)
                break

        # Handle function calls, update context_variables, switch agents.
        if message.tool_calls:
            partial_response = self.handle_tool_calls(
                message.tool_calls, active_agent.functions, context_variables, debug, handle_mm_func=active_agent.handle_mm_func
            )
        else:
            # NOTE(review): reached when a non-entry agent replies without
            # tool calls; the reply object is appended to the history a
            # second time here - confirm this is intended.
            partial_response = Response(messages=[message])
        history.extend(partial_response.messages)
        context_variables.update(partial_response.context_variables)
        if partial_response.agent:
            active_agent = partial_response.agent

    return Response(
        messages=history[init_len:],
        agent=active_agent,
        context_variables=context_variables,
    )
|
|
# NOTE(review): tenacity invokes the ``retry=`` callable with a RetryCallState,
# not with the raised exception - confirm should_retry_error accounts for that.
@retry(
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=10, max=180),
    retry=should_retry_error,
    before_sleep=lambda retry_state: print(f"Retrying... (attempt {retry_state.attempt_number})")
)
async def get_chat_completion_async(
    self,
    agent: Agent,
    history: List,
    context_variables: dict,
    model_override: str,
    stream: bool,
    debug: bool,
) -> Message:
    """Async counterpart of ``get_chat_completion`` using ``litellm.acompletion``.

    The request is built exactly as in the synchronous method: system
    instructions plus examples plus ``history``, tool schemas with the
    reserved context-variables parameter hidden, Gemini schema adaptation,
    and, in non-function-calling mode, the ``NON_FN_CALL`` / ``ADD_USER``
    message conversions. A plain-text reply without any tool call yields a
    message with ``tool_calls=None`` instead of raising ``KeyError``.

    The dict messages in ``history`` are never modified: ``sender`` stripping
    and the tool-description suffix are applied to copies.

    Args:
        agent: Agent providing instructions, examples, functions, model and
            tool settings.
        history: Conversation so far (list of message dicts).
        context_variables: Values handed to callable instructions/examples;
            missing keys read as ``""``.
        model_override: Model name used instead of ``agent.model`` when truthy.
        stream: Forwarded to litellm as ``stream``.
        debug: Unused here; kept for signature compatibility.

    Returns:
        The litellm completion response; callers read
        ``response.choices[0].message``.

    Raises:
        AssertionError: If the model does not support function calling in
            function-calling mode, or ``agent.tool_choice`` is not
            ``"required"`` in non-function-calling mode.
    """

    def _without_sender(message):
        # Dict messages are copied so the caller's history keeps its tags;
        # any other message object is handled in place, as before.
        if isinstance(message, dict):
            return {key: value for key, value in message.items() if key != "sender"}
        if "sender" in message:
            del message["sender"]
        return message

    context_variables = defaultdict(str, context_variables)
    instructions = (
        agent.instructions(context_variables)
        if callable(agent.instructions)
        else agent.instructions
    )
    if agent.examples:
        examples = agent.examples(context_variables) if callable(agent.examples) else agent.examples
        history = examples + history

    # A fresh list: replacing entries below does not touch ``history``.
    messages = [{"role": "system", "content": instructions}] + history

    tools = [function_to_json(f) for f in agent.functions]
    # Hide context_variables from the model.
    for tool in tools:
        params = tool["function"]["parameters"]
        params["properties"].pop(__CTX_VARS_NAME__, None)
        if __CTX_VARS_NAME__ in params["required"]:
            params["required"].remove(__CTX_VARS_NAME__)
    create_model = model_override or agent.model

    if "gemini" in create_model.lower():
        tools = adapt_tools_for_gemini(tools)

    # Some models reject the extra ``sender`` key on messages.
    no_sender_mode = any(marker in create_model for marker in NOT_SUPPORT_SENDER)

    if FN_CALL:
        assert litellm.supports_function_calling(model = create_model) == True, f"Model {create_model} does not support function calling, please set `FN_CALL=False` to use non-function calling mode"
        if no_sender_mode:
            messages = [_without_sender(message) for message in messages]
        create_params = {
            "model": create_model,
            "messages": messages,
            "tools": tools or None,
            "tool_choice": agent.tool_choice,
            "stream": stream,
        }
        if tools and create_params['model'].startswith("gpt"):
            create_params["parallel_tool_calls"] = agent.parallel_tool_calls
        completion_response = await acompletion(**create_params)
    else:
        assert agent.tool_choice == "required", f"Non-function calling mode MUST use tool_choice = 'required' rather than {agent.tool_choice}"
        tools_description = convert_tools_to_description(tools)
        tool_prompt = "\n[IMPORTANT] You MUST use the tools provided to complete the task.\n" + SYSTEM_PROMPT_SUFFIX_TEMPLATE.format(description=tools_description)

        # Attach the tool prompt to a copy of the last message only.
        last_message = copy.copy(messages[-1])
        last_content = last_message["content"]
        if isinstance(last_content, list):
            # Multimodal content is a list of parts: add the prompt as text.
            last_message["content"] = last_content + [{"type": "text", "text": tool_prompt}]
        else:
            last_message["content"] = (last_content or "") + tool_prompt
        messages[-1] = last_message

        if no_sender_mode:
            messages = [_without_sender(message) for message in messages]
        if NON_FN_CALL:
            messages = convert_fn_messages_to_non_fn_messages(messages)
        if ADD_USER and messages[-1]["role"] != "user":
            messages = interleave_user_into_messages(messages)
        create_params = {
            "model": create_model,
            "messages": messages,
            "stream": stream,
            "base_url": API_BASE_URL,
        }
        completion_response = await acompletion(**create_params)

        # Turn the plain-text reply back into a message with tool calls.
        last_message = [{"role": "assistant", "content": completion_response.choices[0].message.content}]
        converted_message = convert_non_fncall_messages_to_fncall_messages(last_message, tools)
        if "tool_calls" in converted_message[0]:
            converted_tool_calls = [ChatCompletionMessageToolCall(**tool_call) for tool_call in converted_message[0]["tool_calls"]]
        else:
            converted_tool_calls = None
        completion_response.choices[0].message = litellmMessage(content = converted_message[0]["content"], role = "assistant", tool_calls = converted_tool_calls)

    return completion_response
|
|
|
|
async def run_async(
    self,
    agent: Agent,
    messages: List,
    context_variables: dict = {},
    model_override: str = None,
    stream: bool = False,
    debug: bool = True,
    max_turns: int = float("inf"),
    execute_tools: bool = True,
) -> Response:
    """Run the agent loop with asynchronous model calls.

    Each turn awaits a completion for the active agent, records the reply in
    the history, then either stops or executes the requested tools (tool
    execution itself is synchronous) and possibly switches the active agent.

    Stop conditions:

    * ``execute_tools`` is False: stop right after the first reply, without
      running any tool.
    * Entry agent's ``tool_choice`` is not ``"required"``: stop when the
      entry agent replies without tool calls.
    * Entry agent's ``tool_choice`` is ``"required"``: stop when the reply
      has no tool calls, or when its first tool call is ``case_resolved`` /
      ``case_not_resolved`` (those calls are executed before stopping).

    Args:
        agent: Agent that starts the conversation.
        messages: Prior conversation; deep-copied, never modified. The last
            message's ``content`` is logged as the task.
        context_variables: Initial context variables; shallow-copied.
        model_override: Model name used instead of the agent's own model.
        stream: Must be False; streaming is not supported here.
        debug: Forwarded to the completion and tool-handling calls.
        max_turns: Upper bound on the number of messages added to the history.
        execute_tools: When False, requested tools are not executed.

    Returns:
        A ``Response`` with the messages produced during this run, the last
        active agent and the final context variables.

    Raises:
        AssertionError: If ``stream`` is truthy.
    """
    assert stream == False, "Async run does not support stream"
    active_agent = agent
    enter_agent = agent
    context_variables = copy.copy(context_variables)
    history = copy.deepcopy(messages)
    init_len = len(messages)

    self.logger.info("Receiveing the task:", history[-1]['content'], title="Receive Task", color="green")

    # Tool names that finish the run, mapped to the log line they produce.
    terminal_tools = {
        "case_resolved": "Ending turn with case resolved.",
        "case_not_resolved": "Ending turn with case not resolved.",
    }

    while len(history) - init_len < max_turns and active_agent:

        # Get a completion for the current history and agent.
        completion_response = await self.get_chat_completion_async(
            agent=active_agent,
            history=history,
            context_variables=context_variables,
            model_override=model_override,
            stream=stream,
            debug=debug,
        )
        message: Message = completion_response.choices[0].message
        message.sender = active_agent.name
        self.logger.pretty_print_messages(message)
        history.append(
            json.loads(message.model_dump_json())
        )  # store a plain dict rather than the library's message object

        # With tool execution disabled, the reply ends the run and no tool
        # may be invoked.
        if not execute_tools:
            self.logger.info("Ending turn.", title="End Turn", color="red")
            break

        if enter_agent.tool_choice != "required":
            if not message.tool_calls and active_agent.name == enter_agent.name:
                self.logger.info("Ending turn.", title="End Turn", color="red")
                break
        elif not message.tool_calls:
            self.logger.info("Ending turn with no tool calls.", title="End Turn", color="red")
            break
        else:
            end_note = terminal_tools.get(message.tool_calls[0].function.name)
            if end_note is not None:
                self.logger.info(end_note, title="End Turn", color="red")
                partial_response = self.handle_tool_calls(
                    message.tool_calls, active_agent.functions, context_variables, debug, handle_mm_func=active_agent.handle_mm_func
                )
                history.extend(partial_response.messages)
                context_variables.update(partial_response.context_variables)
                break

        # Handle function calls, update context_variables, switch agents.
        if message.tool_calls:
            partial_response = self.handle_tool_calls(
                message.tool_calls, active_agent.functions, context_variables, debug, handle_mm_func=active_agent.handle_mm_func
            )
        else:
            # NOTE(review): reached when a non-entry agent replies without
            # tool calls; the reply object is appended to the history a
            # second time here - confirm this is intended.
            partial_response = Response(messages=[message])
        history.extend(partial_response.messages)
        context_variables.update(partial_response.context_variables)
        if partial_response.agent:
            active_agent = partial_response.agent

    return Response(
        messages=history[init_len:],
        agent=active_agent,
        context_variables=context_variables,
    )
|