diff --git a/.github/workflows/docker-image-pull.yml b/.github/workflows/docker-image-pull.yml
index ca64365..2c3d34a 100644
--- a/.github/workflows/docker-image-pull.yml
+++ b/.github/workflows/docker-image-pull.yml
@@ -11,8 +11,8 @@ jobs:
architecture: [amd64, arm64]
os: [linux]
service:
- - name: runtime:0.1.0
- - name: muagent:0.1.0
+ - name: runtime:0.1.1
+ - name: muagent:0.1.1
- name: ekgfrontend:0.1.0
steps:
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 04e0c50..3dcd8ba 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -12,7 +12,7 @@ jobs:
- name: runtime
context: ./runtime
dockerfile: ./runtime/Dockerfile.no-package
- tag: ghcr.io/codefuse-ai/runtime:0.1.0
+ tag: ghcr.io/codefuse-ai/runtime:0.1.1
tag_latest: ghcr.io/codefuse-ai/runtime:latest
- name: ekgfrontend
context: .
@@ -22,7 +22,7 @@ jobs:
- name: ekgservice
context: .
dockerfile: ./Dockerfile_gh
- tag: ghcr.io/codefuse-ai/muagent:0.1.0
+ tag: ghcr.io/codefuse-ai/muagent:0.1.1
tag_latest: ghcr.io/codefuse-ai/muagent:latest
steps:
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 84cbafd..ed26951 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -190,7 +190,7 @@ services:
context: .
dockerfile: Dockerfile
container_name: ekgservice
- image: muagent:0.1.0
+ image: muagent:0.1.1
environment:
USER: root
TZ: "${TZ}"
diff --git a/docker_pull_images.sh b/docker_pull_images.sh
index 82998c1..85bd44a 100644
--- a/docker_pull_images.sh
+++ b/docker_pull_images.sh
@@ -17,11 +17,11 @@ docker pull redis/redis-stack:7.4.0-v0
docker pull ollama/ollama:0.3.6
# pull images from github ghcr.io by nju
-docker pull ghcr.nju.edu.cn/runtime:0.1.0
-docker pull ghcr.nju.edu.cn/muagent:0.1.0
+docker pull ghcr.nju.edu.cn/runtime:0.1.1
+docker pull ghcr.nju.edu.cn/muagent:0.1.1
docker pull ghcr.nju.edu.cn/ekgfrontend:0.1.0
# # pull images from github ghcr.io
-# docker pull ghcr.io/runtime:0.1.0
-# docker pull ghcr.io/muagent:0.1.0
+# docker pull ghcr.io/runtime:0.1.1
+# docker pull ghcr.io/muagent:0.1.1
# docker pull ghcr.io/ekgfrontend:0.1.0
diff --git a/examples/ekg_examples/start.py b/examples/ekg_examples/start.py
index 838ca87..0e1ed81 100644
--- a/examples/ekg_examples/start.py
+++ b/examples/ekg_examples/start.py
@@ -37,6 +37,7 @@
import test_config
from muagent.schemas.db import *
+from muagent.schemas.apis.ekg_api_schema import LLMFCRequest
from muagent.db_handler import *
from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
@@ -46,7 +47,8 @@
from pydantic import BaseModel
-
+from muagent.schemas.models import ModelConfig
+from muagent.models import get_model
cur_dir = os.path.dirname(__file__)
@@ -92,56 +94,75 @@ def update_params(self, **kwargs):
def _llm_type(self, *args):
return ""
-
- def predict(self, prompt: str, stop = None) -> str:
- return self._call(prompt, stop)
-
- def _call(self, prompt: str,
- stop = None) -> str:
+
+ def _get_model(self):
"""_call
"""
- return_str = ""
- stop = stop or self.stop
-
- if self.model_type == "ollama":
- stream = ollama.chat(
- model=self.model_name,
- messages=[{'role': 'user', 'content': prompt}],
- stream=True,
- )
- answer = ""
- for chunk in stream:
- answer += chunk['message']['content']
-
- return answer
- elif self.model_type == "openai":
+ if self.model_type in [
+ "ollama", "qwen", "openai", "lingyiwanwu",
+ "kimi", "moonshot",
+ ]:
from muagent.llm_models.openai_model import getChatModelFromConfig
llm_config = LLMConfig(
model_name=self.model_name,
- model_engine="openai",
+ model_engine=self.model_type,
api_key=self.api_key,
api_base_url=self.url,
temperature=self.temperature,
stop=self.stop
)
model = getChatModelFromConfig(llm_config)
- return model.predict(prompt, stop=self.stop)
- elif self.model_type in ["lingyiwanwu", "kimi", "moonshot", "qwen"]:
- from muagent.llm_models.openai_model import getChatModelFromConfig
- llm_config = LLMConfig(
+ else:
+ model_config = ModelConfig(
+ model_type=self.model_type,
model_name=self.model_name,
- model_engine=self.model_type,
api_key=self.api_key,
- api_base_url=self.url,
+ api_url=self.url,
temperature=self.temperature,
- stop=self.stop
)
- model = getChatModelFromConfig(llm_config)
- return model.predict(prompt, stop=self.stop)
- else:
- pass
+ model = get_model(model_config)
+ return model
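+    # Hedged sketch of the dispatch above (the enclosing class name and the
+    # constructor arguments are assumptions for illustration): engine-style
+    # types such as "qwen" go through LLMConfig/getChatModelFromConfig, while
+    # anything else is treated as a new-style model_type and built via
+    # ModelConfig/get_model, e.g.
+    #
+    #   llm = CustomLLM(model_type="qwen_chat",
+    #                   model_name="qwen2.5-72b-instruct",
+    #                   api_key="", url="")
+    #   llm._get_model().predict("hello")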
+
+ def predict(self, prompt: str, stop = None) -> str:
+ return self._call(prompt, stop)
- return return_str
+ def fc(self, request: LLMFCRequest) -> str:
+ """_function_call
+ """
+ if self.model_type not in [
+ "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen"
+ ]:
+ return f"{self.model_type} not in valid model range"
+
+ model = self._get_model()
+ return model.fc(
+ messages=request.messages,
+ tools=request.tools,
+ tool_choice=request.tool_choice,
+ parallel_tool_calls=request.parallel_tool_calls,
+ )
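+    # Hedged sketch of the function-calling path (the message/tool payloads are
+    # placeholders, not real schema instances):
+    #
+    #   request = LLMFCRequest(messages=[...], tools=[...],
+    #                          tool_choice=None, parallel_tool_calls=False)
+    #   result = llm.fc(request)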
+
+    def _call(self, prompt: str,
+              stop = None) -> str:
+        """Generate a completion for `prompt` with the configured model."""
+        stop = stop or self.stop
+        if self.model_type not in [
+            "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen",
+            "dashscope_chat", "moonshot_chat", "ollama_chat",
+            "openai_chat", "qwen_chat", "yi_chat",
+            "dashscope_text_embedding", "ollama_embedding",
+            "openai_embedding", "qwen_text_embedding",
+        ]:
+            return f"{self.model_type} not in valid model range"
+
+        model = self._get_model()
+        return model.predict(prompt, stop=stop)
class CustomEmbeddings(Embeddings):
@@ -185,6 +206,17 @@ def _get_sentence_emb(self, sentence: str) -> dict:
)
text2vector_dict = get_embedding("openai", [sentence], embed_config=embed_config)
return text2vector_dict[sentence]
+ elif self.embedding_type in [
+ "dashscope_text_embedding", "ollama_embedding", "openai_embedding", "qwen_text_embedding"
+ ]:
+ model_config = ModelConfig(
+ model_type=self.embedding_type,
+ model_name=self.model_name,
+ api_key=self.api_key,
+ api_url=self.url,
+ )
+ model = get_model(model_config)
+ return model.embed_query(sentence)
else:
pass
@@ -280,6 +312,7 @@ def embed_query(self, text: str) -> List[float]:
llm_config=llm_config,
tb_config=tb_config,
gb_config=gb_config,
+ initialize_space=True,
clear_history_data=clear_history_data
)
diff --git a/examples/muagent_examples/docchat_example.py b/examples/muagent_examples/docchat_example.py
index 1ba7ed4..89174cd 100644
--- a/examples/muagent_examples/docchat_example.py
+++ b/examples/muagent_examples/docchat_example.py
@@ -60,7 +60,8 @@
# create your knowledge base
from muagent.service.kb_api import create_kb, upload_files2kb
from muagent.utils.server_utils import run_async
-from muagent.orm import create_tables
+# from muagent.orm import create_tables
+from muagent.db_handler import create_tables
# use to test, don't create some directory
diff --git a/examples/test_config.py.example b/examples/test_config.py.example
index efc3603..ac03016 100644
--- a/examples/test_config.py.example
+++ b/examples/test_config.py.example
@@ -1,6 +1,8 @@
import os, openai, base64
from loguru import logger
+os.environ["DM_llm_name"] = 'Qwen2_72B_Instruct_OpsGPT' #or gpt_4
+
# 兜底大模型配置
OPENAI_API_BASE = "https://api.openai.com/v1"
os.environ["API_BASE_URL"] = OPENAI_API_BASE
@@ -19,6 +21,78 @@ os.environ["gpt4-llm_temperature"] = "0.0"
+MODEL_CONFIGS = {
+ # old llm config
+ "default": {
+ "model_name": "gpt-3.5-turbo",
+ "model_engine": "qwen",
+ "temperature": "0",
+ "api_key": "",
+ "api_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ },
+ "codefuser":{
+ "model_name": "gpt-4",
+ "model_engine": "openai",
+ "temperature": "0",
+ "api_key": "",
+ "api_base_url": OPENAI_API_BASE,
+ },
+ # new llm config
+ "dashscope_chat": {
+ "model_type": "dashscope_chat",
+ "model_name": "qwen2.5-72b-instruct" ,
+ "api_key": "",
+ },
+ "moonshot_chat": {
+ "model_type": "moonshot_chat",
+ "model_name": "moonshot-v1-8k" ,
+ "api_key": "",
+ },
+ "ollama_chat": {
+ "model_type": "ollama_chat",
+ "model_name": "qwen2.5-0.5b",
+ "api_key": "",
+ },
+ "openai_chat": {
+ "model_type": "openai_chat",
+ "model_name": "gpt-4",
+ "api_key": "",
+ },
+ "qwen_chat": {
+ "model_type": "qwen_chat",
+ "model_name": "qwen2.5-72b-instruct",
+ "api_key": "",
+ },
+ "yi_chat": {
+ "model_type": "yi_chat",
+ "model_name": "yi-lightning" ,
+ "api_key": "",
+ },
+ # embedding configs
+ "dashscope_text_embedding": {
+ "model_type": "dashscope_text_embedding",
+ "model_name": "text-embedding-v3",
+ "api_key": "",
+ },
+ "ollama_embedding": {
+ "model_type": "ollama_embedding",
+ "model_name": "qwen2.5-0.5b",
+ "api_key": "",
+ },
+ "openai_embedding": {
+ "model_type": "openai_embedding",
+ "model_name": "text-embedding-ada-002",
+ "api_key": "",
+ },
+ "qwen_text_embedding": {
+ "model_type": "dashscope_text_embedding",
+ "model_name": "text-embedding-v3",
+ "api_key": "",
+ },
+}
+
+os.environ["MODEL_CONFIGS"] = json.dumps(MODEL_CONFIGS)
+
#### NebulaHandler ####
os.environ['nb_host'] = 'graphd'
os.environ['nb_port'] = '9669'
@@ -41,8 +115,36 @@ os.environ["tb_index_name"] = "ekg_migration_new"
os.environ['tb_definition_value'] = 'message_test_new'
os.environ['tb_expire_time'] = '604800' #86400*7
-# clear history data in tb and gb
-os.environ['clear_history_data'] = 'True'
+
+#################
+## DB_CONFIGS ##
+#################
+DB_CONFIGS = {
+ "gb_config": {
+ "gb_type": "NebulaHandler",
+ "extra_kwargs": {
+ 'host':'graphd',
+ 'port': '9669',
+ 'username': os.environ['nb_username'],
+ 'password': os.environ['nb_password'],
+ 'space': "client"
+ }
+ },
+ "tb_config": {
+ "tb_type": 'TBaseHandler',
+ "index_name": "opsgptkg",
+ "host": 'redis-stack',
+ "port": '6379',
+ "username": os.environ['tb_username'],
+ "password": os.environ['tb_password'],
+ "extra_kwargs": {
+ "definition_value": "opsgptkg",
+ "memory_definition_value": "opsgptkg_message"
+ }
+ }
+}
+os.environ["DB_CONFIGS"] = json.dumps(DB_CONFIGS)
+
########################################
diff --git a/muagent/__init__.py b/muagent/__init__.py
index 67c87c2..a079b65 100644
--- a/muagent/__init__.py
+++ b/muagent/__init__.py
@@ -1,7 +1,11 @@
-# encoding: utf-8
-'''
-@author: 温进
-@file: __init__.py.py
-@time: 2023/11/9 下午4:01
-@desc:
-'''
\ No newline at end of file
+from .ekg_project import EKG, get_ekg_project_config_from_env
+from .project_manager import get_project_config_from_env
+from .models import get_model
+from .agents import get_agent
+from .tools import get_tool
+
+__all__ = [
+ "EKG", "get_model", "get_agent", "get_tool",
+ "get_ekg_project_config_from_env",
+ "get_project_config_from_env"
+]
\ No newline at end of file
diff --git a/muagent/agents/__init__.py b/muagent/agents/__init__.py
new file mode 100644
index 0000000..271b61f
--- /dev/null
+++ b/muagent/agents/__init__.py
@@ -0,0 +1,30 @@
+from .base_agent import BaseAgent
+from .single_agent import SingleAgent
+from .react_agent import ReactAgent
+from .task_agent import TaskAgent
+from .group_agent import GroupAgent
+from .user_agent import UserAgent
+from .functioncall_agent import FunctioncallAgent
+from ..schemas import AgentConfig
+
+__all__ = [
+ "BaseAgent",
+ "SingleAgent",
+ "ReactAgent",
+ "TaskAgent",
+ "GroupAgent",
+ "UserAgent",
+ "FunctioncallAgent"
+]
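+
+# Hedged wiring sketch (mirrors the examples in the agent docstrings; the
+# config values are illustrative, not library defaults):
+#
+#   os.environ["AGENT_CONFIGS"] = json.dumps({
+#       "codefuse_simpler": {
+#           "agent_type": "SingleAgent",
+#           "agent_name": "codefuse_simpler",
+#           "llm_config_name": "qwen_chat",
+#       }
+#   })
+#   project_config = get_project_config_from_env()
+#   agent = BaseAgent.init_from_project_config("codefuse_simpler", project_config)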
+
+
+def get_agent(agent_config: AgentConfig) -> BaseAgent:
+ """Get the agent by agent config
+
+ Args:
+ agent_config (`AgentConfig`): The agent config
+
+ Returns:
+ `BaseAgent`: The specific agent
+ """
+ return BaseAgent.init_from_project_config(agent_config)
\ No newline at end of file
diff --git a/muagent/agents/agent_util.py b/muagent/agents/agent_util.py
new file mode 100644
index 0000000..a5561ae
--- /dev/null
+++ b/muagent/agents/agent_util.py
@@ -0,0 +1,202 @@
+import re, uuid, os
+from typing import (
+ Union,
+ Tuple,
+ List
+)
+from loguru import logger
+
+from ..schemas import Memory, Message
+from ..schemas.common import ActionStatus, LogVerboseEnum
+from ..tools import get_tool
+from ..sandbox import NBClientBox
+
+from muagent.base_configs.env_config import KB_ROOT_PATH
+
+class MessageUtil:
+ """Utility class for processing messages and executing code or tools based on message content."""
+
+ def __init__(
+ self,
+ workdir_path: str = KB_ROOT_PATH,
+ log_verbose: str = "0",
+ **kwargs
+ ) -> None:
+ """Initialize the MessageUtil with the specified working directory and log verbosity.
+
+ Args:
+ workdir_path (str): Path to the working directory where files may be saved.
+ log_verbose (str): Verbosity level for logging.
+ **kwargs: Additional keyword arguments for future extensions.
+ """
+ self.codebox = NBClientBox(do_code_exe=True) # Initialize code execution box
+
+ self.workdir_path = workdir_path # Set the working directory path
+ self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose # Configure logging verbosity
+
+ def step_router(
+ self,
+ msg: Message,
+ session_index: str = "",
+ **kwargs
+ ) -> Tuple[Message, ...]:
+ """Route a message to the appropriate step for processing based on its action status.
+
+ Args:
+ msg (Message): The input message that needs processing.
+ session_index (str): The session identifier for managing the conversation.
+ **kwargs: Additional parameters for processing.
+
+ Returns:
+ Tuple[Message, ...]: The processed message and any observation message.
+ """
+ session_index = msg.session_index or session_index or str(uuid.uuid4())
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"message.action_status: {msg.action_status}")
+
+ observation_msg = None
+
+ # Determine the action to take based on the message's action status
+ if msg.action_status == ActionStatus.CODE_EXECUTING:
+ msg, observation_msg = self.code_step(msg, session_index)
+ elif msg.action_status == ActionStatus.TOOL_USING:
+ msg, observation_msg = self.tool_step(msg, session_index, **kwargs)
+ elif msg.action_status == ActionStatus.CODING2FILE:
+ self.save_code2file(msg, self.workdir_path)
+ # Handle other action statuses as needed (currently no operations for these)
+ elif msg.action_status == ActionStatus.CODE_RETRIEVAL:
+ pass
+ elif msg.action_status == ActionStatus.CODING:
+ pass
+
+ return msg, observation_msg
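+    # Hedged routing sketch: a message whose action_status is TOOL_USING is
+    # sent to tool_step with the tool whitelist passed through kwargs, e.g.
+    #   new_msg, obs_msg = MessageUtil().step_router(msg, tools=["MetricsQuery"])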
+
+    def code_step(self, msg: Message, session_index: str) -> Tuple[Message, Message]:
+ """Execute code contained in the message.
+
+ Args:
+ msg (Message): The message containing code to be executed.
+ session_index (str): The session identifier for managing the conversation.
+
+ Returns:
+ Tuple[Message, Message]: The processed message and an observation message regarding code execution.
+ """
+ # Execute the code using the codebox and capture the result
+ code_answer = self.codebox.chat(
+ '```python\n{}```'.format(msg.spec_parsed_content.get("code_content", ""))
+ )
+
+ # Prepare a response message based on code execution result
+ code_prompt = (
+ f"The return error after executing the above code is {code_answer.code_exe_response},need to recover.\n"
+ if code_answer.code_exe_type == "error" else
+ f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
+ )
+
+ # Create an observation message for logging code execution outcome
+ observation_msg = Message(
+ session_index=session_index,
+ role_name="function",
+ role_type="observation",
+ content="",
+ step_content="",
+ input_text=msg.spec_parsed_content.get("code_content", ""),
+ )
+
+ uid = str(uuid.uuid1()) # Generate a unique identifier for related content
+ if code_answer.code_exe_type == "image/png":
+ # If the code execution produces an image, log the result and update the message
+ msg.global_kwargs[uid] = code_answer.code_exe_response
+ msg.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
+ msg.parsed_contents.append({"Observation": f"The return figure name is {uid} after executing the above code.\n"})
+ observation_msg.update_content(f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n")
+ observation_msg.update_parsed_content({"Observation": f"The return figure name is {uid} after executing the above code.\n"})
+ else:
+ # Log the standard execution result
+            msg.step_content += f"\n**Observation:** {code_prompt}\n"
+ observation_msg.update_content(code_prompt)
+ observation_msg.update_parsed_content({"Observation": f"{code_prompt}\n"})
+
+ # Log the observations at the defined verbosity level
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}")
+
+ return msg, observation_msg
+
+ def tool_step(
+ self,
+ msg: Message,
+ session_index: str,
+ **kwargs
+    ) -> Tuple[Message, Message]:
+ """Execute a tool based on parameters in the message.
+
+ Args:
+ msg (Message): The message that specifies the tool to be executed.
+ session_index (str): The session identifier for managing the conversation.
+ **kwargs: Additional parameters for processing, including available tools.
+
+ Returns:
+ Tuple[Message, ...]:
+ The processed message and an observation message regarding the tool execution.
+ """
+        no_tool_msg = "\n**Observation:** there is no tool that can be executed.\n"  # Message for missing tool
+        tool_names = kwargs.get("tools") or []  # Retrieve available tool names
+ extra_params = kwargs.get("extra_params", {})
+ tool_param = msg.spec_parsed_content.get("tool_param", {}) # Parameters for the tool execution
+ tool_param.update(extra_params)
+ tool_name = msg.spec_parsed_content.get("tool_name", "") # Name of the tool to execute
+
+ # Create a message to log the tool execution result
+ observation_msg = Message(
+ session_index=session_index,
+ role_name="function",
+ role_type="observation",
+ input_text=str(tool_param),
+ )
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"message: {msg.action_status}, {tool_param}")
+
+ if tool_name not in tool_names:
+ msg.step_content += f"\n{no_tool_msg}"
+ observation_msg.update_content(no_tool_msg)
+ observation_msg.update_parsed_content({"Observation": no_tool_msg})
+ else:
+ # Execute the specified tool and capture the result
+ tool = get_tool(tool_name)
+ tool_res = tool.run(**tool_param)
+ msg.step_content += f"\n**Observation:** {tool_res}.\n"
+ msg.parsed_contents.append({"Observation": f"{tool_res}.\n"})
+ observation_msg.update_content(f"**Observation:** {tool_res}.\n")
+ observation_msg.update_parsed_content({"Observation": f"{tool_res}.\n"})
+
+ # Log the observations at the defined verbosity level
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}")
+
+ return msg, observation_msg
+
+ def save_code2file(self, msg: Message, project_dir="./"):
+ """Save the code from the message to a specified file.
+
+ Args:
+ msg (Message): The message containing the code to be saved.
+ project_dir (str): Directory path where the code file will be saved.
+ """
+ filename = msg.parsed_content.get("SaveFileName") # Retrieve filename from message content
+ code = msg.spec_parsed_content.get("code") # Extract code content from the message
+
+ # Replace HTML entities in the code
+        for k, v in {"&gt;": ">", "&ge;": ">=", "&lt;": "<", "&le;": "<="}.items():
+ code = code.replace(k, v)
+
+ project_dir_path = os.path.join(self.workdir_path, project_dir) # Construct project directory path
+ file_path = os.path.join(project_dir_path, filename) # Full path for the output code file
+
+ # Create directories if they don't exist
+ if not os.path.exists(file_path):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+ # Write the code to the file
+ with open(file_path, "w") as f:
+ f.write(code)
diff --git a/muagent/agents/base_agent.py b/muagent/agents/base_agent.py
new file mode 100644
index 0000000..47de07c
--- /dev/null
+++ b/muagent/agents/base_agent.py
@@ -0,0 +1,504 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+ List,
+ Union,
+ Generator,
+ Any,
+ Type,
+ Optional,
+ Literal
+)
+import copy
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from ..schemas.models import ModelConfig
+from ..schemas.models import LLMConfig as TempLLMConfig
+from ..memory_manager import BaseMemoryManager
+from ..prompt_manager import BasePromptManager
+from ..models import ModelWrapperBase, get_model
+
+from .agent_util import MessageUtil
+from muagent.connector.schema import LogVerboseEnum
+from muagent.llm_models import getChatModelFromConfig
+
+
+class _AgentWapperBase(ABCMeta):
+ """A meta class to replace the tool wrapper's run function with
+ a wrapper that handles errors gracefully.
+ """
+
+ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+ if "__call__" in attrs:
+ attrs["__call__"] = attrs["__call__"]
+ return super().__new__(mcs, name, bases, attrs)
+
+ def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+ # Initialize class-level registries for storing agent classes
+ if not hasattr(cls, "_registry"):
+ cls._registry = {} # Registry of agent class names
+ cls._type_registry = {} # Registry of agent class type names
+ else:
+ # Register the current class in the registry
+ cls._registry[name] = cls
+ cls._type_registry[cls.agent_type] = cls
+ super().__init__(name, bases, attrs)
+
+
+class BaseAgent(metaclass=_AgentWapperBase):
+ """Base class for agents, providing initialization and interaction methods.
+
+ You can define your custom agent for your agent work, such as
+ .. code-block:: python
+
+        from muagent.agents import BaseAgent
+
+ class SingleAgent(BaseAgent):
+ """"""
+ agent_type: str = "SingleAgent"
+ """"""
+ agent_id: str
+ """"""
+ def __init__(
+ self,
+ agent_name: str = "codefuse_simpler",
+ system_prompt: str = "",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ ):
+
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''agent response from multi-message'''
+ session_index = query.session_index or session_index
+ # insert query into memory
+ ...
+ # transform query into output_message.input_text
+ ...
+ # get memory from self or memory_manager
+ ...
+ # generate prompt by prompt manager
+ ...
+ # predict
+ ...
+            # update information
+            ...
+            # parse the llm's content into the message
+ ...
+ # todo: action step
+ ...
+ # end
+ ...
+ # update self_memory and memory pool
+ ...
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ tools: List[str] = [],
+ session_index: str = "default"
+
+ ) -> None:
+ pass
+ """
+
+ agent_type: str = "BaseAgent"
+ """Defines the type of the agent (default is BaseAgent)."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_baser",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ ):
+ # Configure logging verbosity
+ self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
+
+ # Initialize agent properties
+ self.agent_name = agent_name
+ self.system_prompt = system_prompt
+ self.input_template = input_template
+ self.output_template = output_template
+ self.prompt = prompt
+ self.agent_desc = agent_desc
+ self.agents = agents
+ self.tools = tools
+ self.agent_config = agent_config
+ self.prompt_config = prompt_config
+ self.model_config = model_config
+ self.project_config = project_config
+ #
+ self.memory: Memory = Memory()
+ self.message_util = MessageUtil()
+
+ # Initialize agent from configuration data
+ self._init_from_configs()
+
+ def _init_from_configs(self):
+ '''Initialize agent's configuration from provided parameters.'''
+ if not self.agent_name:
+ raise ValueError(
+ f"Init a agent must have a agent name."
+ )
+ # Load configurations
+ self._init_agent_config()
+ self._init_model_config()
+ self._init_prompt_config()
+
+ def _init_agent_config(self):
+ '''Initialize agent configuration (AgentConfig).'''
+ # Load agent configuration based on the agent name and project config
+ if self.agent_name and self.project_config and self.project_config.agent_configs:
+ tmp_agent_config = self.project_config.agent_configs.get(self.agent_name)
+ self.agent_config = self.agent_config or tmp_agent_config
+
+ if self.agent_config and isinstance(self.agent_config, AgentConfig):
+ # Set agent properties from the configuration
+ self.agent_name = self.agent_config.agent_name
+ self.system_prompt = self.system_prompt or self.agent_config.system_prompt
+ self.input_template = self.input_template or self.agent_config.input_template
+ self.output_template = self.output_template or self.agent_config.output_template
+ self.prompt = self.prompt or self.agent_config.prompt
+ self.agent_desc = self.agent_desc or self.agent_config.agent_desc
+ self.tools = self.tools or self.agent_config.tools or self.tools
+ self.agents = self.agents or self.agent_config.agents
+ self._llm_config_name = self.agent_config.llm_config_name
+ self._em_config_name = self.agent_config.em_config_name
+ self._prompt_config_name = self.agent_config.prompt_config_name
+
+ def _init_model_config(self):
+ '''Initialize model configuration (ModelConfig).'''
+ # Check if model_config was provided
+ if self.model_config:
+ pass
+ # Load model configuration from project config if not provided
+ elif self.agent_name and self.project_config and self.project_config.model_configs:
+ if self._llm_config_name in self.project_config.model_configs:
+ self.model_config = self.project_config.model_configs[self._llm_config_name]
+ elif "default_chat" in self.project_config.model_configs:
+ self.model_config = self.project_config.model_configs["default_chat"]
+ else:
+ raise ValueError(
+ f"While init a model, project_config must have model configs. "
+ f"However, there is something wrong in agent_name: {self.agent_name} "
+ f"agent_config: {self.project_config.model_configs} "
+ )
+ else:
+ raise ValueError(
+ f"While init a model, it must have model config. "
+ f"However, there is something wrong in agent_name: {self.agent_name} "
+ f"agent_config: {self.project_config} "
+ )
+
+ def _init_prompt_config(self):
+ '''Initialize prompt configuration (PromptConfig).'''
+ # Load prompt configuration based on the agent's name and project config
+ if self.agent_name and self.project_config and self.project_config.prompt_configs:
+ self.prompt_config = self.project_config.prompt_configs.get(
+ self.agent_name, PromptConfig()
+ )
+ self._init_prompt_manager()
+ else:
+ self.prompt_config = PromptConfig() # Fallback to default prompt config
+ self._init_prompt_manager()
+
+ def _init_prompt_manager(self):
+ '''Initialize prompt manager from prompt configurations.'''
+ self.prompt_manager = BasePromptManager.from_config(
+ system_prompt=self.system_prompt,
+ input_template=self.input_template,
+ output_template=self.output_template,
+ prompt=self.prompt,
+ prompt_config=self.prompt_config,
+ )
+
+ def copy_config(self) -> ProjectConfig:
+ '''Create a copy of the current agent's configuration for use in a project.'''
+ return ProjectConfig(
+ agent_configs={self.agent_config.config_name: self.agent_config} if self.agent_config else {},
+ prompt_configs={self.prompt_config.config_name: self.prompt_config} if self.prompt_config else {},
+ model_configs={self.model_config.config_name: self.model_config} if self.model_config else {},
+ )
+
+ @classmethod
+ def init_from_project_config(cls, agent_name: str, project_config: ProjectConfig) -> 'BaseAgent':
+ '''Create a new instance of the agent from project configuration.'''
+ agent_config = project_config.agent_configs[agent_name]
+ agent_type = agent_config.agent_type
+ model_config = (
+ project_config.model_configs[agent_config.llm_config_name]
+ if agent_config.llm_config_name
+ else project_config.model_configs["default_chat"]
+ )
+ prompt_config = (
+ project_config.prompt_configs[agent_config.prompt_config_name]
+ if agent_config.prompt_config_name
+ else PromptConfig()
+ )
+ return cls.get_wrapper(agent_type)(
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config
+ )
+
+ @classmethod
+ def get_wrapper(cls, agent_type: str) -> Type['BaseAgent']:
+ '''Retrieve the appropriate agent class wrapper based on the agent type.
+
+ Args:
+ agent_type (str):
+ A string that specifies the type of agent for which a wrapper
+ class is requested. This string is used to look up the
+ appropriate agent class from the registered agent type registry.
+
+ Returns:
+ Type['BaseAgent']:
+ The method returns the appropriate subclass of BaseAgent based on
+ the provided agent_type. If the agent_type is found in the
+ class's _type_registry or _registry, it returns the corresponding
+ class. If not found, it raises a KeyError.
+ '''
+ if agent_type in cls._type_registry:
+ return cls._type_registry[agent_type]
+ elif agent_type in cls._registry:
+ return cls._registry[agent_type]
+ else:
+ raise KeyError(
+ f"Agent Library is missing "
+ f"{agent_type}, please check your agent type"
+ )
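+    # e.g. (hedged): BaseAgent.get_wrapper("ReactAgent") returns the ReactAgent
+    # class, because the metaclass registers every subclass under both its
+    # class name and its agent_type.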
+
+ def step(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default",
+ **kwargs
+ ) -> Optional[Message]:
+ '''Process a query and return the agent's response.
+
+ Args:
+ query (Message):
+ An instance of the Message class containing the
+ input query for the agent.
+ memory_manager (Optional[BaseMemoryManager]):
+ An optional memory manager instance for managing message history.
+ session_index (str, default="default"):
+ A string representing the session index for message tracking and management.
+ kwargs: Additional keyword arguments for extended functionality.
+
+ Returns:
+ Optional[Message]:
+ The final response from the agent as an instance of the Message class,
+ or None if no response is available.
+ '''
+ session_index = query.session_index or session_index
+ message = None
+ # Retrieve the final message from the step_stream generator
+ for message in self.step_stream(
+ query, memory_manager, session_index, **kwargs
+ ):
+ pass
+ return message
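+    # Hedged usage sketch (Message fields follow the examples in the subclass
+    # docstrings):
+    #   reply = agent.step(Message(role_name="human", role_type="user",
+    #                              content="hello"))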
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses over multiple messages.
+
+ Args:
+ query (Message):
+ An instance of the Message class containing the
+ input query for the agent.
+ memory_manager (Optional[BaseMemoryManager]):
+ An optional memory manager instance for managing message history.
+ session_index (str, default="default"):
+ A string representing the session index for message tracking and management.
+
+ Returns:
+ Generator[Message, None, None]:
+ A generator that yields multiple Message instances as responses to the input query.
+ '''
+ raise NotImplementedError(
+ f"Agent Wrapper [{type(self).__name__}]"
+ f" is missing the required `step_stream`"
+ f" method.",
+ )
+
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ session_index: str = "default",
+ **kwargs
+ ) -> None:
+ """Pre-print this agent's prompt format.
+
+ Args:
+ query (Message):
+ An instance of the Message class containing the
+ input query for the agent.
+ memory_manager (Optional[BaseMemoryManager]):
+ An optional memory manager instance for managing message history.
+ session_index (str, default="default"):
+ A string representing the session index for message tracking and management.
+ """
+ session_index = query.session_index or session_index
+ # Generate the output message before proceeding with the agent action
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # Insert query into history memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # Retrieve memory for the current session
+ memory = self.get_memory(session_index)
+ prompt = self.prompt_manager.pre_print(query=query, memory=memory, **kwargs)
+
+ # Displaying the formatted prompt for the agent
+ title = f"<<<<{self.agent_name}'s prompt>>>>"
+ print("#"*len(title) + f"\n{title}\n" + "#"*len(title) + f"\n\n{prompt}\n\n")
+
+ def inherit_extrainfo(self, input: Message):
+ """Incorporate additional information from the last message into the new message."""
+ output_message = Message(
+ role_name=self.agent_name,
+ role_type="assistant",
+ session_index=input.session_index,
+ )
+ output_message.update_input(input)
+ output_message.global_kwargs = copy.deepcopy(input.global_kwargs) # Preserve global args
+ return output_message
+
+ def registry_actions(self, actions):
+ '''Register actions related to the LLM model.'''
+ self.action_list = actions
+
+ def start_action_step(self, message: Message) -> Message:
+ '''Perform actions before predicting the response from the agent.'''
+ # (To be implemented) Additional actions can be done here
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''Perform actions after the agent has predicted a response.'''
+ # (To be implemented) Additional actions can be done here
+ return message
+
+ def update_memory_manager(
+ self,
+ message: Message,
+ memory_manager: Optional[BaseMemoryManager] = None,
+ ):
+ """Update the memory manager with the latest message."""
+ if memory_manager:
+ memory_manager.append(message, self.agent_name)
+
+ def init_history(self, memory: Memory = None) -> Memory:
+ """Initialize message history."""
+ return Memory(messages=[])
+
+ def update_history(self, message: Message):
+ """Update the agent's internal history with a new message."""
+ self.memory.append(message)
+
+ def append_history(self, message: Message):
+ """Append a new message to the agent's history."""
+ self.memory.append(message)
+
+ def clear_history(self):
+ """Clear the agent's memory history."""
+ self.memory.clear()
+ self.memory = self.init_history()
+
+ def get_memory(
+ self,
+ session_index: str,
+ memory_manager: Optional[BaseMemoryManager] = None,
+ ) -> Memory:
+ """Retrieve the agent's memory for a given session index."""
+ if memory_manager:
+ return memory_manager.get_memory_pool(session_index=session_index)
+ return self.memory
+
+ def memory_to_format_messages(
+ self,
+        attributes: dict[str, Union[Any, List[Any]]] = {},
+ filter_type: Optional[Literal['select', 'filter']] = None,
+ *,
+ return_all: bool = True,
+ content_key: str = "response",
+ with_tag: bool = False,
+ format_type: Literal['raw', 'tuple', 'dict', 'str']='raw',
+ logic: Literal['or', 'and'] = 'and'
+ ) -> List:
+ """Format the stored memory into specific message formats based on parameters."""
+        kwargs = locals()
+        kwargs.pop("self")
+ return self.memory.to_format_messages(**kwargs)
+
+ def _get_model(self) -> ModelWrapperBase:
+ """Retrieve the model wrapper based on the model configuration."""
+ if isinstance(self.model_config, ModelConfig):
+ return get_model(self.model_config)
+ elif isinstance(self.model_config, TempLLMConfig):
+ return getChatModelFromConfig(self.model_config)
\ No newline at end of file
diff --git a/muagent/agents/functioncall_agent.py b/muagent/agents/functioncall_agent.py
new file mode 100644
index 0000000..b1c614c
--- /dev/null
+++ b/muagent/agents/functioncall_agent.py
@@ -0,0 +1,237 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+)
+
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+
+funtioncall_output_template = '''#### RESPONSE OUTPUT FORMAT
+**Thoughts:** According to the previous context, plan the approach for using the tool effectively.
+
+**Action Status:** stoped, tool_using or code_executing
+Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
+Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
+
+**Action:**
+
+If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
+```json
+{
+ "tool_name": "$TOOL_NAME",
+ "tool_params": $args
+}
+```
+
+If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
+```text
+The final response or instructions to the user question.
+```
+'''
+
+
+class FunctioncallAgent(BaseAgent):
+ """FunctioncallAgent class that extends the BaseAgent class for
+ function calling.
+
+ FunctioncallAgent Examples:
+ .. code-block:: python
+ from muagent.schemas import Message, Memory
+ from muagent.agents import FunctioncallAgent
+ from muagent import get_project_config_from_env
+
+
+    # log level: print the prompt and the llm prediction
+ os.environ["log_verbose"] = "0"
+
+ AGENT_CONFIGS = {
+ "codefuse_function_caller": {
+ "config_name": "codefuse_function_caller",
+ "agent_type": "FunctioncallAgent",
+ "agent_name": "codefuse_function_caller",
+ "llm_config_name": "qwener"
+ }
+ }
+ os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+ project_config = get_project_config_from_env()
+ tools = ["KSigmaDetector", "MetricsQuery"]
+ agent = FunctioncallAgent(
+ agent_name="codefuse_function_caller",
+ project_config=project_config,
+ tools=tools
+ )
+
+ query_content = "帮我查询下127.0.0.1这个服务器的在10点的数据"
+ query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+ )
+ # agent.pre_print(query)
+ output_message = agent.step(query)
+ print("### intput ###\n", output_message.input_text)
+ print("### content ###\n", output_message.content)
+ print("### step content ###\n", output_message.step_content)
+ """
+
+ agent_type: str = "FunctioncallAgent"
+ """The type of the agent, which is defined as 'FunctioncallAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_function_caller",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = funtioncall_output_template,
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ ):
+
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template or funtioncall_output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default",
+ memory: Optional[Memory] = None,
+ **kwargs
+ ) -> Generator[Message, None, None]:
+        '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # transform query into output_message.input_text
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # get memory from self or memory_manager
+ memory = memory or self.get_memory(session_index)
+
+ # generate prompt by prompt manager
+ prompt = self.prompt_manager.generate_prompt(
+ query=output_message, memory=memory, tools=self.tools
+ )
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ # predict
+ model = self._get_model()
+ content = model.predict(prompt)
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+        # update information
+        output_message.update_content(content)
+
+        # parse the llm's content into the message
+ output_message = self.prompt_manager.parser(output_message)
+
+ # todo: action step
+ output_message, observation_message = self.message_util.step_router(
+ output_message,
+ session_index=session_index,
+ tools=self.tools,
+ **kwargs
+ )
+ # end
+ output_message = self.end_action_step(output_message)
+
+ # update self_memory and memory pool
+ self.append_history(output_message)
+ self.update_memory_manager(output_message, memory_manager)
+ if observation_message:
+ self.append_history(observation_message)
+ self.update_memory_manager(observation_message, memory_manager)
+
+ yield output_message
+
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ tools: List[str] = [],
+ session_index: str = "default"
+
+ ) -> None:
+ """pre print this agent prompt format"""
+ session_index = query.session_index or session_index
+ #
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ prompt = self.prompt_manager.pre_print(query=query, memory=memory, tools=tools or self.tools)
+
+ title = f"<<<<{self.agent_name}'s prompt>>>>"
+ print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
+
+ def start_action_step(self, message: Message) -> Message:
+ '''do action before agent predict '''
+ # action_json = self.start_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''do action after agent predict '''
+ # action_json = self.end_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
\ No newline at end of file
diff --git a/muagent/agents/group_agent.py b/muagent/agents/group_agent.py
new file mode 100644
index 0000000..0493b1b
--- /dev/null
+++ b/muagent/agents/group_agent.py
@@ -0,0 +1,227 @@
+from pydantic import BaseModel
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+)
+
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+
+group_output_template = """#### RESPONSE OUTPUT FORMAT
+**Thoughts:** think step by step about the reason why you select one role
+
+**Role:** Select one role from agent names. No other information.
+"""
+
+group_output_template_zh = """#### 响应输出格式
+**思考:** 一步一步思考你选择一个角色的原因
+
+**角色:** 从代理名称中选择一个角色。不要包含其他信息。
+"""
+
+class GroupAgent(BaseAgent):
+ """GroupAgent class that extends the BaseAgent class for
+    managing a team of agents to complete a task.
+
+ GroupAgent Examples:
+ .. code-block:: python
+ from muagent.tools import TOOL_SETS
+ from muagent.schemas import Message
+ from muagent.agents import BaseAgent
+ from muagent.project_manager import get_project_config_from_env
+
+
+ tools = list(TOOL_SETS)
+ tools = ["KSigmaDetector", "MetricsQuery"]
+ role_prompt = "you are a helpful assistant!"
+
+ AGENT_CONFIGS = {
+ "grouper": {
+ "agent_type": "GroupAgent",
+ "agent_name": "grouper",
+ "agents": ["codefuse_reacter_1", "codefuse_reacter_2"]
+ },
+ "codefuse_reacter_1": {
+ "agent_type": "ReactAgent",
+ "agent_name": "codefuse_reacter_1",
+ "tools": tools,
+ },
+ "codefuse_reacter_2": {
+ "agent_type": "ReactAgent",
+ "agent_name": "codefuse_reacter_2",
+ "tools": tools,
+ }
+ }
+ os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+    # log level: print the prompt and the llm prediction
+ os.environ["log_verbose"] = "0"
+
+ #
+ project_config = get_project_config_from_env()
+ agent = BaseAgent.init_from_project_config(
+ "grouper", project_config
+ )
+
+ query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+ query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+ )
+ # agent.pre_print(query)
+ output_message = agent.step(query)
+ print("input:", output_message.input_text)
+ print("content:", output_message.content)
+ print("step_content:", output_message.step_content)
+ """
+
+ agent_type: str = "GroupAgent"
+ """The type of the agent, which is defined as 'GroupAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_grouper",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ **kwargs,
+ ):
+
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=group_output_template,
+ prompt="",
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config = project_config,
+ log_verbose=log_verbose,
+ **kwargs,
+ )
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # transform query into output_message.input_text
+ select_message = self.inherit_extrainfo(query)
+ select_message = self.start_action_step(select_message)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ # generate prompt by prompt manager
+ agents = [self.get_agent_by_name(agent_name) for agent_name in self.agents]
+ agent_descs = [agent.agent_desc or agent.system_prompt for agent in agents]
+ prompt = self.prompt_manager.generate_prompt(
+ query=select_message, memory=memory,
+ tools=self.tools, agent_names=self.agents, agent_descs=agent_descs,
+ )
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ # predict
+ model = self._get_model()
+ content = model.predict(prompt)
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+        # update information
+        select_message.update_content(content)
+        # parse the llm's content into the message
+ select_message = self.prompt_manager.parser(select_message)
+
+ output_message = None
+ if select_message.parsed_content.get("Role", "") in self.agents:
+ agent_name = select_message.parsed_content.get("Role", "")
+ agent = self.get_agent_by_name(agent_name)
+
+ # update self_memory
+ self.append_history(select_message)
+ self.update_memory_manager(select_message, memory_manager)
+
+            # pass the parsed content, except the selected role, on to the next agent
+ logger.debug(f"{select_message.parsed_content}")
+ select_message.parsed_content.update(
+ {k:v for k,v in select_message.parsed_content.items() if k!="Role"}
+ )
+ logger.debug(f"{select_message.parsed_content}")
+
+ # only query to next agent
+ query_bak = self.inherit_extrainfo(query)
+ for output_message in agent.step_stream(query_bak, memory_manager, session_index):
+ yield output_message or select_message
+
+ #
+ output_message = self.end_action_step(output_message)
+
+ select_message.update_content(output_message.step_content)
+ select_message.update_parsed_content(output_message.parsed_content)
+ select_message.update_spec_parsed_content(output_message.spec_parsed_content)
+
+ # update memory pool
+ self.append_history(output_message)
+ self.update_memory_manager(select_message, memory_manager)
+ yield select_message
+
+ def get_agent_by_name(self, agent_name: str) -> BaseAgent:
+ """new a agent by agent name and project config"""
+ return self.init_from_project_config(agent_name, self.project_config)
+
+ def start_action_step(self, message: Message) -> Message:
+ '''Perform any required actions before predicting the response of the agent.'''
+ # action_json = self.start_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''Perform any required actions after the agent has predicted the response.'''
+ # action_json = self.end_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
\ No newline at end of file
diff --git a/muagent/agents/react_agent.py b/muagent/agents/react_agent.py
new file mode 100644
index 0000000..5d5db14
--- /dev/null
+++ b/muagent/agents/react_agent.py
@@ -0,0 +1,284 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+)
+import copy
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..schemas.common import ActionStatus
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+
+react_output_template = '''#### RESPONSE OUTPUT FORMAT
+**Thoughts:** According to the previous observations, plan the approach for using the tool effectively.
+
+**Action Status:** stoped, tool_using or code_executing
+Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
+Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
+Use 'code_executing' when the current step requires writing and executing code.
+
+**Action:**
+
+If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
+```json
+{
+ "tool_name": "$TOOL_NAME",
+ "tool_params": "$INPUT"
+}
+```
+
+If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
+```python
+Write your running code here
+```
+
+If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
+```text
+The final response or instructions to the user question.
+```
+
+**Observation:** Check the results and effects of the executed action.
+
+... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
+
+**Thoughts:** Conclude the final response to the user question.
+
+**Action Status:** stoped
+
+**Action:** The final answer or guidance to the user question.
+'''
+
+
+class ReactAgent(BaseAgent):
+ """ReactAgent class that extends the BaseAgent class for completing task by reacting.
+
+ ReactAgent Examples:
+ .. code-block:: python
+ from muagent.tools import TOOL_SETS
+ from muagent.schemas import Message
+ from muagent.agents import BaseAgent
+ from muagent import get_project_config_from_env
+
+    # log level: print the prompt and the llm prediction
+ os.environ["log_verbose"] = "0"
+
+ tools = list(TOOL_SETS)
+ tools = ["KSigmaDetector", "MetricsQuery"]
+ role_prompt = "you are a helpful assistant!"
+
+ AGENT_CONFIGS = {
+ "reacter": {
+ "system_prompt": role_prompt,
+ "agent_type": "ReactAgent",
+ "agent_name": "reacter",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+ }
+ os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+ #
+ project_config = get_project_config_from_env()
+ agent = BaseAgent.init_from_project_config(
+ "reacter", project_config
+ )
+
+ query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+ query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+ )
+ # agent.pre_print(query)
+ output_message = agent.step(query)
+ print("### intput ### ", output_message.input_text)
+ print("### content ### ", output_message.content)
+ print("### step content ### ", output_message.step_content)
+ """
+
+ agent_type: str = "ReactAgent"
+ """The type of the agent, which is defined as 'ReactAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_reacter",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = react_output_template,
+ prompt: Optional[str] = None,
+ stop: str = '**Observation:**',
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ chat_turn: int = 3,
+ log_verbose: str = "0",
+ ):
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template or react_output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+ #
+ self.stop = stop
+ self.chat_turn = chat_turn
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+ step_nums = copy.deepcopy(self.chat_turn)
+ react_memory = Memory(messages=[])
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # transform query into output_message.input_text
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ idx = 0
+ while step_nums > 0:
+ output_message.content = output_message.step_content
+ prompt = self.prompt_manager.generate_prompt(
+ query=output_message,
+ memory=memory,
+ react_memory=react_memory,
+ tools=self.tools
+ )
+
+ try:
+ model = self._get_model()
+ content = model.predict(prompt, self.stop)
+ except Exception as e:
+ logger.error(f"error : {e}, prompt: {prompt}")
+ raise RuntimeError(f"error : {e}, prompt: {prompt}")
+
+ output_message.content = content
+ output_message.step_content += f"\n{content}"
+ yield output_message
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name}, {idx} iteration prompt: {prompt}")
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name}, {idx} iteration step_run: {content}")
+
+ output_message = self.prompt_manager.parser(output_message)
+ # when get finished signal can stop early
+ if (output_message.action_status == ActionStatus.FINISHED or
+ output_message.action_status == ActionStatus.STOPPED):
+ output_message.spec_parsed_contents.append(output_message.spec_parsed_content)
+ break
+ # according the output to choose one action for code_content or tool_content
+ output_message, observation_message = self.message_util.step_router(
+ output_message,
+ session_index=session_index,
+ tools=self.tools,
+ )
+
+ # only record content
+ react_message = copy.deepcopy(output_message)
+ react_memory.append(react_message)
+ if observation_message:
+ react_memory.append(observation_message)
+ output_message.update_parsed_content(observation_message.parsed_content)
+ output_message.update_spec_parsed_content(observation_message.parsed_content)
+ idx += 1
+ step_nums -= 1
+ yield output_message
+
+ # end
+ output_message = self.end_action_step(output_message)
+
+ # update self_memory and memory pool
+ self.append_history(output_message)
+ self.update_memory_manager(output_message, memory_manager)
+ yield output_message
+
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ tools: List[str] = [],
+ session_index: str = "default"
+
+ ) -> None:
+ """pre print this agent prompt format"""
+ session_index = query.session_index or session_index
+ react_memory = Memory(messages=[])
+ #
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ prompt = self.prompt_manager.pre_print(
+ query=query,
+ memory=memory,
+ tools=tools or self.tools,
+ react_memory=react_memory
+ )
+
+ title = f"<<<<{self.agent_name}'s prompt>>>>"
+ print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
+
+ def start_action_step(self, message: Message) -> Message:
+ '''Perform any required actions before predicting the response of the agent.'''
+ # action_json = self.start_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''Perform any required actions after the agent has predicted the response.'''
+ # action_json = self.end_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
\ No newline at end of file
diff --git a/muagent/agents/single_agent.py b/muagent/agents/single_agent.py
new file mode 100644
index 0000000..741253a
--- /dev/null
+++ b/muagent/agents/single_agent.py
@@ -0,0 +1,205 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+)
+
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+class SingleAgent(BaseAgent):
+ """SingleAgent class that extends the BaseAgent class for simple single-agent tasks.
+
+    SingleAgent Examples:
+ .. code-block:: python
+ from muagent.schemas import Message, Memory
+ from muagent.agents import BaseAgent
+ from muagent import get_project_config_from_env
+
+ tools = list(TOOL_SETS)
+ tools = ["KSigmaDetector", "MetricsQuery"]
+ AGENT_CONFIGS = {
+ "codefuse_simpler": {
+ "agent_type": "SingleAgent",
+ "agent_name": "codefuse_simpler",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+ }
+ os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+ project_config = get_project_config_from_env()
+ agent = BaseAgent.init_from_project_config(
+ "codefuse_simpler", project_config
+ )
+
+ query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+ query = Message(
+ role_name="human",
+ role_type="user",
+ input_text=query_content,
+ )
+ # base_agent.pre_print(query)
+ output_message = agent.step(query)
+ print("### intput ###", output_message.input_text)
+ print("### content ###", output_message.content)
+ print("### step content ###", output_message.step_content)
+ """
+
+ agent_type: str = "SingleAgent"
+ """The type of the agent, which is defined as 'SingleAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_simpler",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ ):
+
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+
+ # Insert the received query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # Create an output message containing inherited information from the input query
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # Retrieve memory for the current session, either from self or the memory manager
+ memory = self.get_memory(session_index)
+
+ # Generate a prompt using the prompt manager
+ prompt = self.prompt_manager.generate_prompt(
+ query=output_message, memory=memory, tools=self.tools
+ )
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ # Predict the content using the agent's model
+ model = self._get_model()
+ content = model.predict(prompt)
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+ # Update the output message with the predicted content
+ output_message.update_content(content)
+
+ # Parse the output content into a structured message format
+ output_message = self.prompt_manager.parser(output_message)
+
+ # Process any actions or observations required based on the output message
+ output_message, observation_message = self.message_util.step_router(
+ output_message,
+ session_index=session_index,
+ tools=self.tools,
+ )
+
+ # Wrap up any action steps
+ output_message = self.end_action_step(output_message)
+
+ # Update memory with the output message and any observations
+ self.append_history(output_message)
+ self.update_memory_manager(output_message, memory_manager)
+ if observation_message:
+ self.append_history(observation_message)
+ self.update_memory_manager(observation_message, memory_manager)
+
+ yield output_message # Yield the constructed output message
+
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ tools: List[str] = [],
+ session_index: str = "default"
+ ) -> None:
+ """Prepare and print the prompt format for this agent based on the input query."""
+ session_index = query.session_index or session_index
+ # Prepare an output message with inherited information
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # Insert query into memory for later reference
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # Get the current memory for the session
+ memory = self.get_memory(session_index)
+
+ # Generate and format the prompt
+ prompt = self.prompt_manager.pre_print(query=query, memory=memory, tools=tools or self.tools)
+
+ # Display the prompt for this agent
+ title = f"<<<<{self.agent_name}'s prompt>>>>"
+ print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
+
+ def start_action_step(self, message: Message) -> Message:
+ '''Perform any required actions before predicting the response of the agent.'''
+ # action_json = self.start_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''Perform any required actions after the agent has predicted the response.'''
+ # action_json = self.end_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
\ No newline at end of file
diff --git a/muagent/agents/task_agent.py b/muagent/agents/task_agent.py
new file mode 100644
index 0000000..2aa24da
--- /dev/null
+++ b/muagent/agents/task_agent.py
@@ -0,0 +1,291 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+ Tuple,
+)
+import copy
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+from ..base_configs.prompts import PLAN_EXECUTOR_PROMPT
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+
+
+executor_output_template = '''#### RESPONSE OUTPUT FORMAT
+**Thoughts:** Considering the session records and task records, decide whether the current step requires the use of a tool or code_executing.
+Display only the thought process necessary for the current step of solving the problem.
+
+**Action Status:** stopped, tool_using or code_executing
+Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
+Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
+Use 'code_executing' when the current step requires writing and executing code.
+
+**Action:**
+
+If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
+```json
+{
+ "tool_name": "$TOOL_NAME",
+ "tool_params": "$INPUT"
+}
+```
+
+If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
+```python
+Write your running code here
+```
+
+If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
+```text
+The final response or instructions to the user question.
+```'''
+
+
+class TaskAgent(BaseAgent):
+ """TaskAgent class that extends the BaseAgent class for delegaing query into multi task.
+
+ TaskAgent Examples:
+ .. code-block:: python
+
+ from muagent.schemas import Message
+ from muagent.agents import BaseAgent
+ from muagent import get_project_config_from_env
+
+ tools = list(TOOL_SETS)
+ tools = ["KSigmaDetector", "MetricsQuery"]
+ role_prompt = "you are a helpful assistant!"
+
+ AGENT_CONFIGS = {
+ "tasker": {
+ "system_prompt": role_prompt,
+ "agent_type": "TaskAgent",
+ "agent_name": "tasker",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+ }
+ os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+ #
+ project_config = get_project_config_from_env()
+ agent = BaseAgent.init_from_project_config(
+ "tasker", project_config
+ )
+
+ query_content = "先帮我获取下127.0.0.1这个服务器在10点的数,然后在帮我判断下数据是否存在异常"
+ query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+ )
+ # agent.pre_print(query)
+ output_message = agent.step(query)
+ print("### intput ###\n", output_message.input_text)
+ print("### content ###\n", output_message.content)
+ print("### step content ###\n", output_message.step_content)
+ """
+
+ agent_type: str = "TaskAgent"
+ """The type of the agent, which is defined as 'TaskAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_tasker",
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = executor_output_template,
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ do_all_task: bool = True,
+ log_verbose: str = "0",
+ ):
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template or executor_output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+ #
+ self.do_all_task = do_all_task
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # transform query into output_message.input_text
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ # generate prompt by prompt manager
+ input_text = query.content or output_message.input_text
+ prompt = PLAN_EXECUTOR_PROMPT.format(
+ **{"content": input_text.replace("*", "")}
+ )
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ model = self._get_model()
+ content = model.predict(prompt)
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+ plan_message = Message(
+ session_index=session_index,
+ role_name="plan_extracter",
+ role_type="assistant",
+ content=content,
+ global_kwargs=query.global_kwargs
+ )
+ plan_message = self.prompt_manager.parser(plan_message)
+        # parse the input query into plans and plan_step
+ plan_step = int(plan_message.parsed_content.get("PLAN_STEP", 0))
+ plans = plan_message.parsed_content.get("PLAN", [input_text])
+
+ if self.do_all_task:
+ # run all tasks step by step
+ for idx, task_content in enumerate(plans[plan_step:]):
+ for output_message in self._execute_line(
+ task_content, output_message, plan_step+idx, session_index
+ ):
+ yield output_message
+        else:
+            # only run the current plan step
+            task_content = plans[plan_step]
+            for output_message in self._execute_line(
+                task_content, output_message, plan_step, session_index
+            ):
+                pass
+
+ # end
+ output_message = self.end_action_step(output_message)
+
+ # update self_memory and memory pool
+ self.append_history(output_message)
+ self.update_memory_manager(output_message, memory_manager)
+ yield output_message
+
+ def _execute_line(
+ self,
+ task_content: str,
+ output_message: Message,
+ plan_step,
+ session_index
+ ) -> Generator[Tuple[Message, Memory], None, None]:
+        '''Execute a single task step along the plan.'''
+ query = copy.deepcopy(output_message)
+ query.parsed_content = {"CURRENT_STEP": task_content}
+ query = self.start_action_step(query)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ for output_message in self._run_stream(
+ query, output_message, memory, session_index
+ ):
+ yield output_message
+ output_message.update_spec_parsed_content(
+ {**output_message.spec_parsed_content, **{"PLAN_STEP": plan_step}}
+ )
+ yield output_message
+
+ def _run_stream(
+ self,
+ query: Message,
+ output_message: Message,
+ memory: Memory,
+ session_index: str
+ ) -> Generator[Tuple[Message, Memory], None, None]:
+        '''Run the LLM prediction on the generated prompt.'''
+ prompt = self.prompt_manager.generate_prompt(
+ query=query,
+ memory=memory,
+ tools=self.tools,
+ )
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ model = self._get_model()
+ content = model.predict(prompt)
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+ output_message.update_content(content)
+ output_message = self.prompt_manager.parser(output_message)
+        # route the output to the matching action (code_content or tool_content)
+ output_message, observation_message = self.message_util.step_router(
+ output_message, session_index=session_index,
+ tools=self.tools
+ )
+ react_message = copy.deepcopy(output_message)
+ self.append_history(react_message)
+ # task_memory.append(react_message)
+ if observation_message:
+ # task_memory.append(observation_message)
+ self.append_history(observation_message)
+ output_message.update_parsed_content(observation_message.parsed_content)
+ output_message.update_spec_parsed_content(observation_message.parsed_content)
+ yield output_message
+
+ def start_action_step(self, message: Message) -> Message:
+ '''Perform any required actions before predicting the response of the agent.'''
+ # action_json = self.start_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
+
+ def end_action_step(self, message: Message) -> Message:
+ '''Perform any required actions after the agent has predicted the response.'''
+ # action_json = self.end_action()
+ # message["customed_kargs"]["xx"] = action_json
+ return message
\ No newline at end of file
diff --git a/muagent/agents/user_agent.py b/muagent/agents/user_agent.py
new file mode 100644
index 0000000..3393d33
--- /dev/null
+++ b/muagent/agents/user_agent.py
@@ -0,0 +1,140 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+ List,
+ Union,
+ Generator,
+ Optional,
+)
+
+from loguru import logger
+
+from ..schemas import (
+ Message,
+ Memory,
+ PromptConfig,
+ AgentConfig,
+ ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+class UserAgent(BaseAgent):
+ """UserAgent class that extends the BaseAgent class for simulating user' response."""
+
+ agent_type: str = "UserAgent"
+ """The type of the agent, which is defined as 'UserAgent'."""
+
+ agent_id: str
+ """Unique identifier for the agent."""
+
+ def __init__(
+ self,
+ agent_name: str = "codefuse_user",
+ system_prompt: str = "",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ agent_desc: str = "",
+ *,
+ agent_config: Optional[AgentConfig] = None,
+ model_config: Optional[ModelConfig] = None,
+ prompt_config: Optional[PromptConfig] = PromptConfig(),
+ project_config: Optional[ProjectConfig] = None,
+ #
+ log_verbose: str = "0",
+ ):
+
+ super().__init__(
+ agent_name=agent_name,
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ agents=agents,
+ tools=tools,
+ agent_desc=agent_desc,
+ agent_config=agent_config,
+ model_config=model_config,
+ prompt_config=prompt_config,
+ project_config=project_config,
+ log_verbose=log_verbose
+ )
+
+ def step_stream(
+ self,
+ query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None,
+ session_index: str = "default"
+ ) -> Generator[Message, None, None]:
+ '''Stream the agent's responses based on an input multi-message query.'''
+
+ session_index = query.session_index or session_index
+
+ # insert query into memory
+ self.append_history(query)
+ self.update_memory_manager(query, memory_manager)
+
+ # transform query into output_message.input_text
+ output_message = self.inherit_extrainfo(query)
+ output_message = self.start_action_step(output_message)
+
+ # get memory from self or memory_manager
+ memory = self.get_memory(session_index)
+
+ # generate prompt by prompt manager
+ prompt = self.prompt_manager.generate_prompt(
+ query=output_message, memory=memory, tools=self.tools
+ )
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+ # predict
+ content = input("please answer: \n")
+
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"{self.agent_name} content: {content}")
+
+        # update information
+ output_message.update_content(content)
+
+        # parse the llm's content into a structured message
+ output_message = self.prompt_manager.parser(output_message)
+
+ # todo: action step
+ output_message, observation_message = self.message_util.step_router(
+ output_message,
+ session_index=session_index,
+ tools=self.tools,
+ )
+ # end
+ output_message = self.end_action_step(output_message)
+
+ # update self_memory and memory pool
+ self.append_history(output_message)
+ self.update_memory_manager(output_message, memory_manager)
+ if observation_message:
+ self.append_history(observation_message)
+ self.update_memory_manager(observation_message, memory_manager)
+
+ yield output_message
+
+ def pre_print(
+ self,
+ query: Message,
+ memory_manager: BaseMemoryManager=None,
+ tools: List[str] = [],
+ session_index: str = "default"
+
+ ) -> None:
+ """pre print this agent prompt format"""
+ title = f"<<<<{self.agent_name}'s prompt>>>>"
+ print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{query.content}\n\n")
diff --git a/muagent/agents/util.py b/muagent/agents/util.py
new file mode 100644
index 0000000..e69de29
diff --git a/muagent/base_configs/env_config.py b/muagent/base_configs/env_config.py
index 4b78388..765944c 100644
--- a/muagent/base_configs/env_config.py
+++ b/muagent/base_configs/env_config.py
@@ -12,19 +12,19 @@
# SOURCE_PATH = os.environ.get("SOURCE_PATH", None) or os.path.join(executable_path, "sources")
# 知识库默认存储路径
-KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "knowledge_base")
+KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "data/knowledge_base")
# 代码库默认存储路径
-CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "code_base")
+CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "data/code_base")
# # nltk 模型存储路径
# NLTK_DATA_PATH = os.environ.get("NLTK_DATA_PATH", None) or os.path.join(executable_path, "nltk_data")
# 代码存储路径
-JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "jupyter_work")
+JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "data/jupyter_work")
-# WEB_CRAWL存储路径
-WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
+# # WEB_CRAWL存储路径
+# WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
# NEBULA_DATA存储路径
NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
@@ -32,7 +32,7 @@
# CHROMA 存储路径
CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
-for _path in [LOG_PATH, KB_ROOT_PATH, CB_ROOT_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
+for _path in [LOG_PATH, KB_ROOT_PATH, CB_ROOT_PATH, JUPYTER_WORK_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
if not os.path.exists(_path) and int(os.environ.get("do_create_dir", True)):
os.makedirs(_path, exist_ok=True)
@@ -83,8 +83,8 @@
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
# Mac 可能存在无法使用normalized_L2的问题,因此调整SCORE_THRESHOLD至 0~1100
-FAISS_NORMALIZE_L2 = True if system_name in ["Linux", "Windows"] else False
-SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100
+FAISS_NORMALIZE_L2 = True if system_name in ["Darwin", "Linux", "Windows"] else False
+SCORE_THRESHOLD = 1 if system_name in ["Darwin", "Linux", "Windows"] else 1100
# 搜索引擎匹配结题数量
SEARCH_ENGINE_TOP_K = os.environ.get("SEARCH_ENGINE_TOP_K") or 5
diff --git a/muagent/base_configs/prompts/functioncall_template_prompt.py b/muagent/base_configs/prompts/functioncall_template_prompt.py
new file mode 100644
index 0000000..812e481
--- /dev/null
+++ b/muagent/base_configs/prompts/functioncall_template_prompt.py
@@ -0,0 +1,36 @@
+
+FUNCTION_CALL_PROMPT_en = """You have access to the following functions:
+
+{tool_desc}
+
+To call a function, please respond with JSON for a function call.
+
+Respond in the format [{"name": function name, "arguments": dictionary of argument name and its value}].
+"""
+
+FC_AUTO_PROMPT_en = """
+The function can be called zero or multiple times, according to your needs.
+"""
+
+
+FC_REQUIRED_PROMPT_en = """
+You must call at least one function.
+"""
+
+FC_PARALLEL_PROMPT_en = """
+The function can be called in parallel.
+"""
+
+
+FC_RESPONSE_PROMPT_en = """## Response Ouput
+Response the function calls by formatting the in JSON. The format should be:
+
+```json
+[
+{
+ "name": function name,
+ "arguments": dictionary of argument name and its value
+}
+]
+```
+"""
\ No newline at end of file
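A sketch of how these fragments are meant to be assembled and how a compliant reply can be decoded back into tool calls. It mirrors what `CustomLLMModel.fc` does later in this patch; the tool description and model reply are made up for illustration:

```python
import json

from muagent.base_configs.prompts.functioncall_template_prompt import (
    FUNCTION_CALL_PROMPT_en,
    FC_AUTO_PROMPT_en,
    FC_RESPONSE_PROMPT_en,
)

# str.replace is used because str.format would trip over the literal JSON braces in the template.
tool_desc = '{"name": "MetricsQuery", "parameters": {"ip": "string", "time": "string"}}'
prompt = "\n".join([
    FUNCTION_CALL_PROMPT_en.replace("{tool_desc}", tool_desc),
    FC_AUTO_PROMPT_en,       # tool_choice == "auto"
    FC_RESPONSE_PROMPT_en,   # response-format reminder
])

# Decode a compliant reply back into tool calls.
reply = '```json\n[{"name": "MetricsQuery", "arguments": {"ip": "127.0.0.1", "time": "10:00"}}]\n```'
payload = reply.split("```json", 1)[1].rsplit("```", 1)[0]
calls = json.loads(payload)   # -> [{'name': 'MetricsQuery', 'arguments': {...}}]
```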
diff --git a/muagent/base_configs/prompts/intention_template_prompt.py b/muagent/base_configs/prompts/intention_template_prompt.py
index 4e171fd..64d2c2b 100644
--- a/muagent/base_configs/prompts/intention_template_prompt.py
+++ b/muagent/base_configs/prompts/intention_template_prompt.py
@@ -74,6 +74,7 @@
## 输出格式
最相关意图对应的数字(第一个意图对应数字1){extra}。
+其他任何内容都是不允许的。
{example}
## 用户询问
"""
@@ -141,16 +142,8 @@ def get_intention_prompt(
name='整体计划查询', tag='allPlan'
)
INTENTION_NEXTSTEP = IntentionInfo(
- description='用户询问某个问题或方案中某一个特定步骤。通常会提及“下一步”、“具体操作”等。',
- name='某一步任务查询', tag='nextStep'
-)
-INTENTION_SEVERALSTEPS = IntentionInfo(
- description='用户询问某个问题或方案中其中某几个步骤。',
- name='某几步任务查询', tag='severalSteps'
-)
-INTENTION_BACKGROUND = IntentionInfo(
- description='用户询问某个问题或方案的背景知识,规则以及流程介绍等。',
- name='背景查询', tag='background'
+ description='用户询问某个问题或方案的特定步骤,通常会提及“下一步”、“具体操作”等。',
+ name='下一步任务查询', tag='nextStep'
)
INTENTION_CHAT = IntentionInfo(
description='用户询问的内容与当前的技术问题或解决方案无关,更多是出于兴趣或社交性质的交流。',
@@ -179,16 +172,13 @@ def get_intention_prompt(
}
)
-INTENTIONS_CONSULT_WHICH = (INTENTION_ALLPLAN, INTENTION_NEXTSTEP, INTENTION_SEVERALSTEPS, INTENTION_BACKGROUND, INTENTION_CHAT)
+INTENTIONS_CONSULT_WHICH = (INTENTION_ALLPLAN, INTENTION_NEXTSTEP, INTENTION_CHAT)
CONSULT_WHICH_PROMPT = get_intention_prompt(
intentions=INTENTIONS_CONSULT_WHICH,
examples={
'如何组织一次活动?': INTENTION_ALLPLAN,
'系统升级的整个流程是怎样的?': INTENTION_ALLPLAN,
'为什么我没有收到红包?请告诉我方案': INTENTION_ALLPLAN,
- '如果我想学习一门新语言,第一步我需要先做些什么?': INTENTION_NEXTSTEP,
- '项目开发中代码开发完成后需要经过哪几步测试才能发布到生产呢?': INTENTION_SEVERALSTEPS,
- '请问下狼人杀游戏中猎人的主要职责是什么?': INTENTION_BACKGROUND,
'听说你们采用了新工具,能讲讲它的特点吗?': INTENTION_CHAT
}
)
diff --git a/muagent/connector/memory/hierarchical_memory_manager.py b/muagent/connector/memory/hierarchical_memory_manager.py
index 180dc35..7ccac84 100644
--- a/muagent/connector/memory/hierarchical_memory_manager.py
+++ b/muagent/connector/memory/hierarchical_memory_manager.py
@@ -15,7 +15,8 @@
from muagent.connector.memory_manager import BaseMemoryManager
from muagent.llm_models import *
from muagent.base_configs.env_config import KB_ROOT_PATH
-from muagent.orm import table_init
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
from muagent.utils.common_utils import *
diff --git a/muagent/connector/memory_manager.py b/muagent/connector/memory_manager.py
index e536466..0c7dedf 100644
--- a/muagent/connector/memory_manager.py
+++ b/muagent/connector/memory_manager.py
@@ -19,7 +19,8 @@
from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
from muagent.retrieval.utils import load_embeddings_from_path
from muagent.utils.common_utils import *
-from muagent.orm import table_init
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
from muagent.base_configs.env_config import KB_ROOT_PATH
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
# from configs.model_config import embedding_model_dict
diff --git a/muagent/connector/schema/general_schema.py b/muagent/connector/schema/general_schema.py
index d87c373..350f72c 100644
--- a/muagent/connector/schema/general_schema.py
+++ b/muagent/connector/schema/general_schema.py
@@ -26,7 +26,7 @@ class ActionStatus(Enum):
def __eq__(self, other):
if isinstance(other, str):
- return self.value.lower() == other.lower()
+ return self.value.strip().lower() == other.strip().lower()
return super().__eq__(other)
@@ -198,6 +198,10 @@ def __le__(self, other):
@classmethod
def ge(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
return enum_value <= other
+
+ @classmethod
+ def le(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
+ return enum_value <= other
class Task(BaseModel):
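A small sketch of what the relaxed `ActionStatus.__eq__` buys; it assumes `ActionStatus.FINISHED` carries the value `"finished"` and that the class is re-exported from `muagent.connector.schema` like `LogVerboseEnum`, neither of which is shown in this diff:

```python
from muagent.connector.schema import ActionStatus  # assumed export path

# Whitespace and case coming back from the LLM no longer break the status check.
assert ActionStatus.FINISHED == " Finished "
```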
diff --git a/muagent/connector/utils.py b/muagent/connector/utils.py
index ea2377e..68ff47c 100644
--- a/muagent/connector/utils.py
+++ b/muagent/connector/utils.py
@@ -1,3 +1,6 @@
+from typing import (
+ Dict,
+)
import re, copy, json
from loguru import logger
@@ -70,8 +73,9 @@ def parse_section_to_dict(text, section_name):
def parse_text_to_dict(text):
- # Define a regular expression pattern to capture the key and value
- main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
+ """through a regular expression pattern to capture the key and value"""
+ # main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
+ main_pattern = r'\*\*([^*]+):\*\*\s*(.*?)(?=\*\*([^*]+):\*\*|$)'
list_pattern = r'```python\n(.*?)```'
plan_pattern = r'(\[\s*.*?\s*\])'
@@ -79,7 +83,10 @@ def parse_text_to_dict(text):
main_matches = re.findall(main_pattern, text, re.DOTALL)
# Convert main matches to a dictionary
- parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
+ parsed_dict = {
+ v[0].strip(): v[1].strip()
+ for v in main_matches
+ }
for k, v in parsed_dict.items():
for pattern in [list_pattern, plan_pattern]:
@@ -94,12 +101,13 @@ def parse_text_to_dict(text):
return parsed_dict
-def parse_dict_to_dict(parsed_dict) -> dict:
+def parse_dict_to_dict(parsed_dict: Dict) -> Dict:
+ """through a regular expression pattern to decode ```python/json/java``` into fragment"""
code_pattern = r'```python\n(.*?)```'
tool_pattern = r'```json\n(.*?)```'
java_pattern = r'```java\n(.*?)```'
- pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
+ pattern_dict = {"python": code_pattern, "json": tool_pattern, "java": java_pattern}
spec_parsed_dict = copy.deepcopy(parsed_dict)
for key, pattern in pattern_dict.items():
for k, text in parsed_dict.items():
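A rough sketch of what the tightened `main_pattern` extracts; the sample output and the result in the comments are inferred from the regex above, not taken from the source:

```python
from muagent.connector.utils import parse_text_to_dict

llm_output = (
    "**Thoughts:** I need to query the metrics first.\n"
    "**Action Status:** tool_using\n"
    "**Action:**\n"
    '```json\n{"tool_name": "MetricsQuery", "tool_params": "127.0.0.1 10:00"}\n```'
)
parsed = parse_text_to_dict(llm_output)
# Roughly:
# {'Thoughts': 'I need to query the metrics first.',
#  'Action Status': 'tool_using',
#  'Action': '```json\n{"tool_name": "MetricsQuery", "tool_params": "127.0.0.1 10:00"}\n```'}
```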
diff --git a/muagent/db_handler/__init__.py b/muagent/db_handler/__init__.py
index aebb417..f475467 100644
--- a/muagent/db_handler/__init__.py
+++ b/muagent/db_handler/__init__.py
@@ -8,10 +8,28 @@
from .graph_db_handler import NebulaHandler, NetworkxHandler, AliYunSLSHandler, GeaBaseHandler, GBHandler
from .vector_db_handler import LocalFaissHandler, TbaseHandler, ChromaHandler
-
+from .db import _engine, Base
__all__ = [
"GBHandler", "NebulaHandler", "NetworkxHandler", "GeaBaseHandler",
"ChromaHandler", "TbaseHandler", "LocalFaissHandler",
"AliYunSLSHandler"
-]
\ No newline at end of file
+]
+
+
+def create_tables():
+ Base.metadata.create_all(bind=_engine)
+
+def reset_tables():
+ Base.metadata.drop_all(bind=_engine)
+ create_tables()
+
+
+def check_tables_exist(table_name) -> bool:
+ table_exist = _engine.dialect.has_table(_engine.connect(), table_name, schema=None)
+ return table_exist
+
+def table_init():
+    if (not check_tables_exist("knowledge_base")) or (not check_tables_exist("knowledge_file")) or \
+        (not check_tables_exist("code_base")):
+        create_tables()
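These helpers replace the old `muagent.orm` entry points (see the import swaps in the memory managers above); a minimal sketch of the intended call site:

```python
from muagent.db_handler import table_init, check_tables_exist

# Creates the knowledge_base / knowledge_file / code_base tables if any are missing.
table_init()
assert check_tables_exist("knowledge_base")
```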
diff --git a/muagent/orm/db.py b/muagent/db_handler/db.py
similarity index 100%
rename from muagent/orm/db.py
rename to muagent/db_handler/db.py
diff --git a/muagent/db_handler/graph_db_handler/nebula_handler.py b/muagent/db_handler/graph_db_handler/nebula_handler.py
index 9d94180..eba6e95 100644
--- a/muagent/db_handler/graph_db_handler/nebula_handler.py
+++ b/muagent/db_handler/graph_db_handler/nebula_handler.py
@@ -57,7 +57,6 @@ def __init__(self,gb_config : GBConfig = None):
self.connection_pool = ConnectionPool()
if gb_config == None:
-
self.connection_pool.init([('graphd', '9669')], config)
self.username = '' or 'root'
self.nb_pw = '' or 'nebula'
@@ -116,7 +115,7 @@ def execute_cypher(self, cypher: str, space_name: str = '',ignore_log: bool = Fa
if ignore_log == False:
if resp.is_succeeded():
- #logger.info(f"Successfully executed Cypher query: {cypher}")
+ # logger.info(f"Successfully executed Cypher query: {cypher}")
pass
@@ -165,18 +164,20 @@ def execute_cypher_return_status(self, cypher: str, space_name: str = '', format
errorMessage=resp.error_msg(),
errorCode=resp.error_code(),
)
-
def add_hosts(self, hostname, port):
+ while not self.is_host_connected(hostname, port):
with self.connection_pool.session_context(self.username, self.nb_pw) as session:
cypher = f'ADD HOSTS "{hostname}":{port}'
resp = session.execute(cypher)
- return resp
+            print('Adding NebulaGraph storage host, waiting 20 seconds')
+ time.sleep(20)
+ return
def close_connection(self):
self.connection_pool.close()
- def create_space(self, space_name: str, vid_type: str = 'FIXED_STRING(32)', comment: str = ''):
+ def create_space(self, space_name: str, vid_type: str = 'FIXED_STRING(1024)', comment: str = ''):
'''
create space
@param space_name: cannot startwith number
@@ -277,6 +278,20 @@ def show_edge_type(self):
resp = self.execute_cypher(cypher, self.space_name)
return resp
+ def is_host_connected(self, hostname, port):
+        # query SHOW HOSTS to check the host's connection status
+ with self.connection_pool.session_context(self.username, self.nb_pw) as session:
+ cypher = 'SHOW HOSTS'
+ resp = session.execute(cypher)
+
+ resp = resp.as_primitive()
+        # assume each returned row contains 'Host', 'Port' and 'Status' fields
+ for i in resp:
+ if hostname==i['Host'] and port == i['Port'] and i["Status"] =="ONLINE":
+ return True
+
+ return False
+
def delete_edge_type(self, edge_type_name: str):
cypher = f'DROP EDGE {edge_type_name}'
return self.execute_cypher(cypher, self.space_name)
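`add_hosts` now blocks until the storage host reports ONLINE instead of returning the raw response; a usage sketch, with an illustrative hostname/port for a typical docker-compose NebulaGraph deployment:

```python
from muagent.db_handler import NebulaHandler

nb = NebulaHandler(gb_config=None)   # falls back to the local graphd defaults
nb.add_hosts("storaged0", 9779)      # polls SHOW HOSTS every 20s until the host is ONLINE
nb.create_space("ekg_space")         # vid_type now defaults to FIXED_STRING(1024)
```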
diff --git a/muagent/db_handler/vector_db_handler/local_faiss_handler.py b/muagent/db_handler/vector_db_handler/local_faiss_handler.py
index f18605c..e43d429 100644
--- a/muagent/db_handler/vector_db_handler/local_faiss_handler.py
+++ b/muagent/db_handler/vector_db_handler/local_faiss_handler.py
@@ -1,11 +1,15 @@
from loguru import logger
-from typing import List
+from typing import List, Union
from functools import lru_cache
import os, shutil
from langchain.embeddings.base import Embeddings
from langchain_community.docstore.document import Document
+
+from muagent.models import get_model
+from muagent.schemas.models import ModelConfig
+
from muagent.utils.server_utils import torch_gc
from muagent.retrieval.base_service import SupportedVSType
from muagent.retrieval.faiss_m import FAISS
@@ -23,17 +27,20 @@ class LocalFaissHandler:
def __init__(
self,
- embed_config: EmbedConfig,
+ embed_config: Union[EmbedConfig, ModelConfig],
vb_config: VBConfig = None
):
self.vb_config = vb_config
self.embed_config = embed_config
- self.embeddings = load_embeddings_from_path(
- self.embed_config.embed_model_path,
- self.embed_config.model_device,
- self.embed_config.langchain_embeddings
- )
+ if isinstance(self.embed_config, ModelConfig):
+ self.embeddings = get_model(self.embed_config)
+ else:
+ self.embeddings = load_embeddings_from_path(
+ self.embed_config.embed_model_path,
+ self.embed_config.model_device,
+ self.embed_config.langchain_embeddings
+ )
# INIT
self.search_index: FAISS = None
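`LocalFaissHandler` now accepts either the legacy `EmbedConfig` or the new `ModelConfig`, routing the latter through `get_model`. A sketch of both paths; the constructor fields used below are assumptions, not the verified schemas:

```python
from muagent.db_handler import LocalFaissHandler
from muagent.llm_models.llm_config import EmbedConfig
from muagent.schemas.models import ModelConfig

# Field names are assumed for illustration only.
legacy_cfg = EmbedConfig(embed_model_path="text2vec-base-chinese", model_device="cpu")
model_cfg = ModelConfig(
    model_type="embedding",
    model_name="text-embedding-ada-002",
    api_key="EMPTY",
    api_url="http://localhost:8080/v1",
)

faiss_from_legacy = LocalFaissHandler(legacy_cfg)  # load_embeddings_from_path(...)
faiss_from_model = LocalFaissHandler(model_cfg)    # get_model(model_cfg)
```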
diff --git a/muagent/ekg_project.py b/muagent/ekg_project.py
new file mode 100644
index 0000000..d5b3949
--- /dev/null
+++ b/muagent/ekg_project.py
@@ -0,0 +1,659 @@
+from typing import (
+ Union,
+ Sequence,
+ Literal,
+ Mapping,
+ Optional,
+ Dict,
+ List
+)
+from pydantic import BaseModel
+import os
+import json
+from loguru import logger
+import concurrent.futures
+import time
+import random
+
+from .llm_models import LLMConfig, EmbedConfig
+from .schemas.db import TBConfig, GBConfig
+from .schemas.models import ModelConfig
+from .schemas import EKGProjectConfig, Message, Memory, AgentConfig, PromptConfig
+from .schemas.common import GNode, GEdge
+from .db_handler import *
+from .agents import FunctioncallAgent
+
+from .service.utils import decode_biznodes, encode_biznodes
+
+# from .connector.schema import Memory, Message
+from .connector.memory_manager import TbaseMemoryManager
+from .service.ekg_construct.ekg_construct_base import EKGConstructService
+from .service.ekg_inference import IntentionRouter
+from .service.ekg_reasoning.src.graph_search.graph_search_main import main as reasoning
+
+
+class LingSiResponse(BaseModel):
+    '''LingSi's output, i.e. the algorithm's input.
+ The following is an example:
+
+ .. code-block:: python
+
+ from xx import LingSiResponse
+ ls_resp = LingSiResponse(
+ observation={'content': '一起来玩谁是卧底'},
+ sessionId='default_sessionId',
+ scene="UNDERCOVER",
+ )
+
+ ls_resp = LingSiResponse(
+ observation={'toolResponse': '我的单词是一种工业品'},
+ currentNodeId='剧本杀/谁是卧底/智能交互/开始新一轮的讨论'
+ sessionId='default_sessionId',
+ scene="UNDERCOVER",
+ type='reactExecution'
+ )
+ '''
+ sessionId: str
+ """The session index"""
+
+ currentNodeId: Optional[str] = None
+ """The last node index, the first is null"""
+
+ type: Optional[str] = None
+ """The last execute type, the first is null"""
+
+ agentName:Optional[str]=None
+ """The agent name from last node output, the first is null"""
+
+ scene: Literal["UNDERCOVER", "WEREWOLF" , "NEXA" ] = "NEXA"
+ """The scene type of this task."""
+
+ observation: Optional[Union[str,Dict]] # jsonstr
+ '''last observation from last node
+ .. code-block:: python
+ observation: Literal["content", "tool_response"]
+ '''
+
+ userAnswer: Optional[str]=None
+ """no use"""
+
+ startRootNodeId: Optional[str] = ''
+ """The default team root id"""
+
+ intentionData: Optional[Union[List,str] ] = None
+ """equal query, only once at first"""
+
+ startFromRoot: Literal['True', 'false', 'true', 'False'] = 'True'
+ """"""
+
+ intentionRule: Optional[Union[List,str]]= ["nlp"]
+ """no use"""
+
+
+class QuestionContent(BaseModel):
+ '''
+ {'question': '请玩家根据当前情况发言', 'candidate': None }
+ '''
+ question:str
+ candidate:Optional[str]=None
+
+class QuestionDescription(BaseModel):
+ '''
+ {'questionType': 'essayQuestion',
+ 'questionContent': {'question': '请玩家根据当前情况发言','candidate': None }}
+ '''
+ questionType: Literal["essayQuestion", "multipleChoice"] = "essayQuestion"
+ questionContent: QuestionContent
+
+class ToolPlanOneStep(BaseModel):
+ '''
+ tool_plan_one_step = {'toolDescription': '请用户回答',
+ 'currentNodeId': nodeId,
+ 'memory': None,
+ 'type': 'userProblem',
+ 'questionDescription': {'questionType': 'essayQuestion',
+ 'questionContent': {'question': '请玩家根据当前情况发言',
+ 'candidate': None }}}
+ '''
+ currentNodeId: Optional[str] = None
+ """from last node index"""
+
+ toolDescription:str
+ """the input for functioncalling"""
+
+ currentNodeInfo: Optional[str] = None
+ """equal agent name"""
+
+ memory: Optional[str] = None
+ """memory"""
+
+ questionDescription: Optional[QuestionDescription]=None
+ """反问的过程"""
+
+ type: Optional[Literal["onlyTool", "userProblem", "reactExecution"]] = None
+ """request type"""
+
+
+class ResToLingsi(BaseModel):
+    '''LingSi's input, i.e. the algorithm's output.
+ The following is an example:
+
+ .. code-block:: python
+
+ from xx import ResToLingsi
+ resp_to_ls = ResToLingsi(
+ sessionId = "default_sessionId",
+ type="onlyTool",
+ summary=None,
+            toolPlan=[ToolPlanOneStep(
+                toolDescription="agent_李静",
+                currentNodeId='26921eb05153216c5a1f585f9d318c77%%@@#agent_李静',
+                currentNodeInfo='agent_李静',
+                memory="",
+                questionDescription=None,
+                type="reactExecution",
+            )],
+            userInteraction='开始新一轮的讨论\n**主持人:**\n各位玩家请注意,现在所有玩家均存活,我们将按照座位顺序进行发言。发言顺序为1号李静、2号张伟、3号人类玩家、4号王鹏。现在,请1号李静开始发言。',
+            intentionRecognitionSituation=None,
+        )
+ '''
+ sessionId: str
+ """session index from last node output"""
+
+ toolPlan:Optional[List[ToolPlanOneStep]] = None
+ """"""
+
+ userInteraction:Optional[str]=None
+ """if userInteraction, yield"""
+
+ summary: Optional[str] = None
+ """if summary, end, yield"""
+
+ type: Optional[str] = None
+ """no use"""
+
+ intentionRecognitionSituation: Optional[str]=None
+ """no use"""
+
+
+def get_ekg_project_config_from_env(
+ model_configs: Optional[Dict[str, Union[LLMConfig, ModelConfig]]] = None,
+ embed_configs: Optional[Dict[str, Union[EmbedConfig, ModelConfig]]] = None,
+ db_configs: Optional[Mapping[str, Union[GBConfig, TBConfig]]] = None,
+ agent_configs: Optional[Mapping[str, AgentConfig]] = None,
+ prompt_configs: Optional[Mapping[str,PromptConfig]] = None,
+) -> EKGProjectConfig:
+ """"""
+ project_configs = {
+ "model_configs": {},
+ "embed_configs": {},
+ "db_configs": {},
+ "agent_configs": {},
+ "prompt_configs": {},
+ }
+ #
+ db_config_name_to_class = {
+ "gb_config": GBConfig,
+ "tb_config": TBConfig,
+ }
+ # init model configs
+ if model_configs:
+ for k, v in model_configs.items():
+ if isinstance(v, LLMConfig) or isinstance(v, ModelConfig):
+ project_configs["model_configs"][k] = v
+ else:
+ try:
+ project_configs["model_configs"][k] = ModelConfig(**v)
+ except:
+ project_configs["model_configs"][k] = LLMConfig(**v)
+ elif "model_configs".upper() in os.environ:
+ _model_configs = json.loads(os.environ["model_configs".upper()])
+ for k, v in _model_configs.items():
+ try:
+ project_configs["model_configs"][k] = ModelConfig(**v)
+ except:
+ project_configs["model_configs"][k] = LLMConfig(**v)
+
+ chat_list = [_type for _type in project_configs["model_configs"].keys() if "chat" in _type]
+ embedding_list = [_type for _type in project_configs["model_configs"].keys() if "embedding" in _type]
+ if chat_list:
+ model_type = random.choice(chat_list)
+ default_model_config = project_configs["model_configs"][model_type]
+ project_configs["model_configs"]["default_chat"] = default_model_config
+ os.environ["DEFAULT_MODEL_TYPE"] = model_type
+ os.environ["DEFAULT_MODEL_NAME"] = default_model_config.model_name
+ os.environ["DEFAULT_API_KEY"] = default_model_config.api_key or ""
+ os.environ["DEFAULT_API_URL"] = default_model_config.api_url or ""
+
+ if embedding_list:
+ model_type = random.choice(embedding_list)
+ default_model_config = project_configs["model_configs"][model_type]
+ project_configs["model_configs"]["default_embed"] = default_model_config
+
+ # init embedding configs
+ if embed_configs:
+ for k, v in embed_configs.items():
+ if isinstance(v, EmbedConfig) or isinstance(v, ModelConfig):
+ project_configs["embed_configs"][k] = v
+ else:
+ try:
+ project_configs["embed_configs"][k] = EmbedConfig(**v)
+ except:
+ project_configs["embed_configs"][k] = ModelConfig(**v)
+ elif "embed_configs".upper() in os.environ:
+ embed_configs = json.loads(os.environ["embed_configs".upper()])
+ for k, v in embed_configs.items():
+ if isinstance(v, EmbedConfig) or isinstance(v, ModelConfig):
+ project_configs["embed_configs"][k] = v
+ else:
+ try:
+ project_configs["embed_configs"][k] = EmbedConfig(**v)
+ except:
+ project_configs["embed_configs"][k] = ModelConfig(**v)
+
+ # init db configs
+ db_configs = db_configs or json.loads(os.environ["DB_CONFIGS"])
+ for k in ["tb_config", "gb_config"]:
+ if db_configs and k not in db_configs:
+ raise KeyError(
+ f"EKG must have {k}. "
+ f"please check your env config or input."
+ )
+ else:
+ project_configs["db_configs"][k] = db_config_name_to_class[k](
+ **db_configs[k])
+
+ # init agent configs
+ if "AGENT_CONFIGS" in os.environ:
+ agent_configs = agent_configs or json.loads(os.environ["AGENT_CONFIGS"])
+ agent_configs = {
+ kk: AgentConfig(**vv)
+ for kk, vv in agent_configs.items()
+ }
+ project_configs["agent_configs"] = agent_configs
+ else:
+ logger.warning(
+ f"Cant't init any AGENT_CONFIGS in this env."
+ )
+
+ # init prompt configs
+ if "PROMPT_CONFIGS" in os.environ:
+ prompt_configs = prompt_configs or json.loads(os.environ["PROMPT_CONFIGS"])
+ prompt_configs = {
+ kk: PromptConfig(**vv)
+ for kk, vv in prompt_configs.items()
+ }
+ project_configs["prompt_configs"] = prompt_configs
+ else:
+ logger.warning(
+ f"Cant't init any AGENT_CONFIGS in this env."
+ )
+
+
+ return EKGProjectConfig(**project_configs)
+
+
+class EKG:
+ """Class to represent and manage the EKG project."""
+
+ def __init__(
+ self,
+ tb_config: Optional[TBConfig] = None,
+ gb_config: Optional[GBConfig] = None,
+ embed_config: Union[ModelConfig, EmbedConfig] = None,
+ llm_config: Union[ModelConfig, LLMConfig] = None,
+ project_config: EKGProjectConfig = None,
+ agents: List[str] = [],
+ tools: List[str] = [],
+ *,
+ initialize_space = True
+ ):
+
+ # Initialize various configuration settings for the EKG project.
+ self.tb_config = tb_config
+ self.gb_config = gb_config
+ self.embed_config = embed_config
+ self.llm_config = llm_config
+ self.project_config = project_config
+ self.agents = agents
+ self.tools = tools
+
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+ self.futures = []
+ # Set whether to initialize space
+ self.initialize_space = initialize_space
+ self.init_from_project()
+
+ @classmethod
+ def from_project(cls, project_config: EKGProjectConfig, initialize_space=False) -> 'EKG':
+ """Create an instance of EKG from a project configuration."""
+ return cls(project_config=project_config, initialize_space=initialize_space)
+
+ def init_from_project(self):
+ """Initialize settings from the provided project configuration."""
+
+        # Set up the Tbase (tb) configuration
+ if self.project_config and self.project_config.db_configs:
+ self.tb_config = self.tb_config or \
+ self.project_config.db_configs.get("tb_config")
+ elif self.tb_config:
+ pass
+ else:
+ raise KeyError(
+ f"EKG Project must have 'tb_config' in "
+ f"db_configs"
+ )
+
+        # Set up the graph database (gb) configuration
+ if self.project_config and self.project_config.db_configs:
+ self.gb_config = self.gb_config or \
+ self.project_config.db_configs.get("gb_config")
+ elif self.gb_config:
+ pass
+ else:
+ raise KeyError(
+ f"EKG Project must have 'gb_config' in "
+ f"db_configs"
+ )
+
+ # Setup embedding configuration
+ if self.project_config and self.project_config.embed_configs:
+ if "default" not in self.project_config.embed_configs:
+ raise KeyError(
+ f"EKG Project must have key=default in "
+ f"embed_configs"
+ )
+ self.embed_config = self.project_config.embed_configs.get("default")
+
+ # Setup LLM configuration and environment variables
+ if self.project_config and self.project_config.model_configs:
+ # if "default_chat" not in self.project_config.llm_configs:
+ # raise KeyError(
+ # f"EKG Project must have key=default in "
+ # f"llm_configs"
+ # )
+
+ # os.environ["API_BASE_URL"] = self.project_config.llm_configs["default"].api_base_url
+ # os.environ["OPENAI_API_KEY"] = self.project_config.llm_configs["default"].api_key
+ # os.environ["model_name"] = self.project_config.llm_configs["default"].model_name
+ # os.environ["model_engine"] = self.project_config.llm_configs["default"].model_engine
+ # os.environ["llm_temperature"] = self.project_config.llm_configs["default"].temperature
+ self.llm_config = self.project_config.model_configs.get("default_chat")
+ self.llm_config = LLMConfig(
+ model_name=os.environ["model_name"],
+ model_engine=os.environ["model_engine"],
+ api_key=os.environ["OPENAI_API_KEY"],
+ api_base_url=os.environ["API_BASE_URL"],
+ )
+
+ # Ensure 'codefuser' config exists
+ if "codefuser" not in self.project_config.model_configs:
+ raise KeyError(
+ f"EKG Project must have key=codefuser in "
+ f"llm_configs"
+ )
+
+ os.environ["gpt4-API_BASE_URL"] = self.project_config.model_configs["codefuser"].api_base_url
+ os.environ["gpt4-OPENAI_API_KEY"] = self.project_config.model_configs["codefuser"].api_key
+ os.environ["gpt4-model_name"] = self.project_config.model_configs["codefuser"].model_name
+ os.environ["gpt4-model_engine"] = self.project_config.model_configs["codefuser"].model_engine
+ os.environ["gpt4-llm_temperature"] = self.project_config.model_configs["codefuser"].temperature
+
+ self._init_ekg_construt_service() # Initialize the EKG construction service
+ self._init_memory_manager() # Initialize the memory manager
+ self._init_intention_router() # Initialize the intention router
+
+ def _init_ekg_construt_service(self):
+ """Initialize the service responsible for building the EKG graph."""
+ self.ekg_construct_service = EKGConstructService(
+ embed_config=self.embed_config,
+ llm_config=self.llm_config,
+ tb_config=self.tb_config,
+ gb_config=self.gb_config,
+ initialize_space=self.initialize_space
+ )
+
+ def _init_memory_manager(self):
+ """Initialize the memory manager with the appropriate configuration."""
+ tb = TbaseHandler(
+ self.tb_config,
+ self.tb_config.index_name,
+ definition_value=self.tb_config.extra_kwargs.get(
+ "memory_definition_value")
+ )
+
+ self.memory_manager = TbaseMemoryManager(
+ unique_name="EKG",
+ embed_config=self.embed_config,
+ llm_config=self.llm_config,
+ tbase_handler=tb, # Use the Tbase handler for database management
+ use_vector=False
+ )
+
+ def _init_intention_router(self):
+ """Initialize the routing mechanism for intentions within the EKG project."""
+ self.intention_router = IntentionRouter(
+ self.ekg_construct_service.model,
+ self.ekg_construct_service.gb,
+ self.ekg_construct_service.tb,
+ self.embed_config
+ )
+
+ def __call__(self):
+ """Call method for EKG class instance (to be implemented)."""
+ pass
+
+ def add_node(
+ self,
+ node: Union[Dict, GNode],
+ *,
+ teamid: str = "default",
+ ) -> None:
+ """Add a node to the EKG graph."""
+ gnode = GNode(**node) if isinstance(node, Dict) else node
+ gnodes, _ = decode_biznodes([gnode]) # Decode the business nodes
+ self.ekg_construct_service.add_nodes(gnodes, teamid) # Add nodes to the construct service
+
+ def add_edge(
+ self,
+ start_id: str,
+ end_id: str,
+ *,
+ teamid: str = "",
+ ) -> None:
+ """Add an edge between two nodes in the EKG graph."""
+ start_node = self.ekg_construct_service.get_node_by_id(start_id)
+ end_node = self.ekg_construct_service.get_node_by_id(end_id)
+
+ # If both start and end nodes exist, create an edge
+ if start_node and end_node:
+ edge = {
+ "start_id": start_id,
+ "end_id": end_id,
+ "type": f"{start_node.type}_route_{end_node.type}",
+ "attributes": {}
+ }
+ edges = [GEdge(**edge)] # Create an edge object
+ self.ekg_construct_service.add_edges(edges, teamid) # Add edges to the construct service
+
+ def run(
+ self,
+ query: str,
+ scene: str = "NEXA",
+ rootid: str = "ekg_team_default",
+ ):
+ """Run the EKG processing with the provided query and scene."""
+ import uuid
+ sessionId = str(uuid.uuid4()).replace("-", "") # Generate a unique session ID
+ request = LingSiResponse(
+ observation={"content": query},
+ intentionData=query,
+ startRootNodeId=rootid,
+ sessionId=sessionId,
+ scene=scene,
+ )
+ logger.error(query)
+
+ summary = "" # Initialize summary variable
+ history_done = []
+ while True:
+ # Wait for the first completed future object
+ done, not_done = concurrent.futures.wait(
+ self.futures, return_when=concurrent.futures.FIRST_COMPLETED
+ )
+ history_done.extend(done)
+ for future in done:
+ self.futures.remove(future)
+
+ if history_done:
+ # for future in done:
+ future = history_done.pop(0)
+ try:
+ result = future.result() # Retrieve the result of the completed task
+ # logger.error(f"Task completed: {result}")
+ # self.futures.remove(future) # Remove completed task from the futures list
+
+ # Assemble the new request with the result data
+ request = LingSiResponse(
+ observation={"toolResponse": result.get("toolResponse")},
+ currentNodeId=result.get("currentNodeId"),
+ type=result.get("type"),
+ agentName=result.get("agentName"),
+ startRootNodeId=rootid,
+ sessionId=sessionId,
+ scene=scene,
+ )
+ except Exception as e:
+ logger.error(f"Task generated an exception: {e}")
+
+ # Perform inference using the request
+ if request:
+ # logger.error(f"{request}")
+ result = reasoning(
+ request.dict(),
+ self.memory_manager,
+ self.ekg_construct_service.gb,
+ self.intention_router,
+ self.llm_config
+ )
+ # logger.error(f"{result}")
+ # Yield user interaction if present
+ if result.get("userInteraction"):
+ print(result["userInteraction"])
+ yield result["userInteraction"]
+
+ summary = summary or result.get("summary") # Update summary if empty
+
+ # If a summary is available, yield it and break the loop
+ if summary:
+ print(summary)
+ yield summary
+ break
+
+ # Update tasks in the pool based on the result
+ user_tasks = []
+ toolPlans = result.get("toolPlan", []) or []
+ for toolplan in toolPlans:
+                    # if the tool description contains "关键信息" (key information), return " " directly
+ if "关键信息" in toolplan["toolDescription"]:
+ self.futures.append(
+ self.executor.submit(
+ self.empty_function,
+ **{
+ "content": " ",
+ "toolResponse": " ",
+ "type": toolplan["type"],
+ "currentNodeId": toolplan["currentNodeId"],
+ "agentName": toolplan.get("currentNodeInfo"),
+ }
+ )
+ )
+ continue
+
+ # if toolplan
+ if toolplan["type"] == "userProblem":
+ user_tasks.append(toolplan)
+
+ if toolplan["type"] in ["onlyTool", "reactExecution"]:
+ # Submit function call to the executor for execution
+ future = self.executor.submit(
+ self.function_call,
+ **{
+ "content": toolplan["toolDescription"],
+ "type": toolplan["type"],
+ "currentNodeId": toolplan["currentNodeId"],
+ "agentName": toolplan.get("currentNodeInfo"),
+ "memory": toolplan.get("memory"),
+ "toolDescription": toolplan.get("toolDescription")
+ }
+ )
+ self.futures.append(future)
+
+ # Process user tasks and gather user input
+ for user_task in user_tasks:
+ questionType = user_task.get(
+ "questionDescription").get("questionType")
+ user_query = user_task.get(
+ "questionDescription").get(
+ "questionContent").get(
+ "question")
+ if user_query is None or questionType is None:
+ continue
+
+ print(user_query)
+ yield user_query # Yield the user query for input
+ user_answer = input() # Get user input
+
+ # Submit the user answer as a future task
+ self.futures.append(
+ self.executor.submit(
+ self.empty_function,
+ **{
+ "content": user_answer,
+ "toolResponse": user_answer,
+ "type": user_task["type"],
+ "currentNodeId": user_task["currentNodeId"],
+ "agentName": user_task["currentNodeInfo"]
+ }
+ )
+ )
+
+ if not self.futures: # If there are no tasks, pause briefly before checking again
+ time.sleep(0.1)
+
+ # Reset request to prepare for the next inference loop
+ request = None
+ toolPlans = None
+ result = None
+
+ def empty_function(self, **kwargs) -> Dict:
+ """Return the input parameters as is (placeholder function)."""
+ return kwargs
+
+ def function_call(self, **kwargs) -> Dict:
+ """Perform a single step function call and return the result."""
+
+ function_caller = FunctioncallAgent(
+ agent_name="codefuse_function_caller", # Set the agent name
+ project_config=self.project_config, # Provide the project configuration
+ tools=self.tools # Provide the tools available for use
+ )
+
+ query = Message(
+ role_type="user",
+ content=f"帮我选择匹配的工具并进行执行,工具描述为'{kwargs['content']}'"
+ )
+ for msg in function_caller.step_stream(query, extra_params={"memory": kwargs.get("memory")}):
+ pass # Process the stream, if any
+
+ observation = ""
+ # Extract the observation from the processed messages
+ if msg.parsed_contents:
+ observation = msg.parsed_contents[-1].get("Observation", "")
+ result = {
+ "toolResponse": observation,
+ "currentNodeId": kwargs.get("currentNodeId"),
+ "type": kwargs.get("type"),
+ "agentName": kwargs.get("agentName"),
+ }
+ return result # Return the result of the function call
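An end-to-end sketch of driving the new `EKG` wrapper. `run()` is a generator that yields user-facing questions (answered via `input()`) and finally the summary; the query string is taken from the `LingSiResponse` docstring above:

```python
from muagent.ekg_project import EKG, get_ekg_project_config_from_env

project_config = get_ekg_project_config_from_env()
ekg = EKG(project_config=project_config, initialize_space=True)

for output in ekg.run("一起来玩谁是卧底", scene="UNDERCOVER"):
    # Intermediate yields are questions for the user; the final yield is the summary.
    print(output)
```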
diff --git a/muagent/httpapis/ekg_construct/api.py b/muagent/httpapis/ekg_construct/api.py
index 9dd8f8e..7565bbc 100644
--- a/muagent/httpapis/ekg_construct/api.py
+++ b/muagent/httpapis/ekg_construct/api.py
@@ -1,9 +1,7 @@
from fastapi import FastAPI
from typing import Dict
-import asyncio
import uvicorn
from loguru import logger
-import tqdm
import ollama
import json
import os
@@ -126,6 +124,34 @@ async def update_llm_params(request: LLMRequest):
answer=answer
)
+    # ~/functioncall/chat
+ @app.post("/functioncall/chat", response_model=LLMFCResponse)
+ async def fc_chat(request: LLMFCRequest):
+        # function-call prediction logic
+ errorMessage = "ok"
+ successCode = True
+ choices = []
+ try:
+ model_names = [i["name"] for i in ollama.list()["models"]]
+ if llm.model_type=="ollama" and llm.model_name not in model_names:
+ errorMessage = f"{llm.model_name} not in ollama.list {model_names}. " \
+ f"please request llm/ollama/pull for downloading the ollama model"
+ successCode = False
+ else:
+ fc_output = llm.fc(request)
+ choices = fc_output.choices
+ except Exception as e:
+ logger.exception(e)
+ errorMessage = str(e)
+ successCode = False
+
+ logger.info(f"choices.type: {type(choices)}")
+ logger.info(f"choices {choices}")
+ return LLMFCResponse(
+ successCode=successCode, errorMessage=errorMessage,
+ choices=choices
+ )
+
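+    # Illustrative response shape for POST /functioncall/chat (fields follow the
+    # LLMFCResponse/Choice schemas used above; values are placeholders):
+    #   {"successCode": true, "errorMessage": "ok",
+    #    "choices": [{"finish_reason": "tool_calls", "index": 0,
+    #                 "message": {"role": "assistant", "content": null,
+    #                             "tool_calls": [{"type": "function",
+    #                                             "function": {"name": "...", "arguments": {}}}]}}]}
+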
# ~/embeddings/params
@app.get("/embeddings/params", response_model=EmbeddingsParamsResponse)
async def embedding_params():
diff --git a/muagent/llm_models/llm_shemas.py b/muagent/llm_models/llm_shemas.py
new file mode 100644
index 0000000..45f54e9
--- /dev/null
+++ b/muagent/llm_models/llm_shemas.py
@@ -0,0 +1,50 @@
+from pydantic import BaseModel, Field
+from typing import List, Dict, Optional, Union
+from enum import Enum
+
+
+
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+
+class FunctionCallData(BaseModel):
+ name: str
+ arguments: Union[str, dict]
+
+
+class ToolCall(BaseModel):
+ id: Optional[Union[str, int]] = None
+ type: str = "function"
+ function: FunctionCallData
+
+
+class LLMOuputMessage(BaseModel):
+ content: Optional[str] = None
+ role: str
+ tool_calls: List[ToolCall] = []
+
+
+class Choice(BaseModel):
+ finish_reason: str
+ index: int = 0
+ message: LLMOuputMessage
+
+
+class UsageData(BaseModel):
+ completion_tokens: int
+ prompt_tokens: int
+ total_token: int
+
+
+class LLMResponse(BaseModel):
+ choices: List[Choice]
+ created: int = 0
+ id: str
+ model: str
+ object: str
+ usage: Optional[UsageData] = None
+
+
+
diff --git a/muagent/llm_models/openai_model.py b/muagent/llm_models/openai_model.py
index c56b988..ba4ee14 100644
--- a/muagent/llm_models/openai_model.py
+++ b/muagent/llm_models/openai_model.py
@@ -1,5 +1,6 @@
import os
-from typing import Union, Optional, List
+import re
+from typing import Union, Optional, List, Dict, Literal
from loguru import logger
from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -8,7 +9,20 @@
from langchain.llms.base import LLM
from .llm_config import LLMConfig
+from .llm_shemas import *
+
+try:
+ import ollama
+except:
+ pass
+
# from configs.model_config import (llm_model_dict, LLM_MODEL)
+def replacePrompt(prompt: str, keys: Optional[List[str]] = None):
+    """Escape literal braces in `prompt`, then restore the named format keys."""
+    keys = keys or []
+    prompt = prompt.replace("{", "{{").replace("}", "}}")
+    for key in keys:
+        prompt = prompt.replace(f"{{{{{key}}}}}", f"{{{key}}}")
+    return prompt
+
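+# Illustrative behaviour of replacePrompt (not executed anywhere by itself):
+#   replacePrompt("use {tool_desc} on {data}", keys=["tool_desc"])
+#   -> "use {tool_desc} on {{data}}"
+# i.e. only the listed keys stay usable with str.format; all other braces are escaped.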
class CustomLLMModel:
@@ -32,6 +46,100 @@ def batch(self, prompts: str,
stop: Optional[List[str]] = None):
return [self(prompt, stop) for prompt in prompts]
+ def fc(
+ self,
+ messages: List[ChatMessage],
+ tools: List[Union[str, object]] = [],
+ system_prompt: Optional[str] = None,
+ tool_choice: Optional[Literal["auto", "required"]] = "auto",
+ parallel_tool_calls: Optional[bool] = None,
+ stop: Optional[List[str]] = None,
+ **kwargs
+ ) -> LLMResponse:
+        '''Emulate an OpenAI-style function call: build a prompt from the
+        function-call templates, run the model, and parse the JSON tool calls
+        into an LLMResponse.
+        '''
+ from muagent.base_configs.prompts.functioncall_template_prompt import (
+ FUNCTION_CALL_PROMPT_en,
+ FC_AUTO_PROMPT_en,
+ FC_REQUIRED_PROMPT_en,
+ FC_PARALLEL_PROMPT_en,
+ FC_RESPONSE_PROMPT_en
+ )
+
+ use_tools = len(tools) > 0
+
+ prompts = []
+ if use_tools:
+ prompts.append(FUNCTION_CALL_PROMPT_en)
+
+ if system_prompt:
+ prompts.append(system_prompt)
+
+        if use_tools and tool_choice == "auto":
+            prompts.append(FC_AUTO_PROMPT_en)
+        elif use_tools and tool_choice == "required":
+            prompts.append(FC_REQUIRED_PROMPT_en)
+
+ if use_tools and parallel_tool_calls:
+ prompts.append(FC_PARALLEL_PROMPT_en)
+
+ prompts.append("you are a helpful assistant to respond user's question:\n## Question Input\n{content}")
+
+ if use_tools:
+ prompts.append(FC_RESPONSE_PROMPT_en)
+
+ system_prompt = "\n".join(prompts)
+        # Concatenate the message contents into a single prompt body
+        content = "\n\n".join([f"{i.content}" for i in messages])
+ if use_tools:
+ system_prompt = replacePrompt(system_prompt, keys=["content", "tool_desc"])
+ prompt = system_prompt.format(content=content, tool_desc="\n".join(tools))
+ else:
+ system_prompt = replacePrompt(system_prompt, keys=["content"])
+ prompt = system_prompt.format(content=content)
+
+ llm_output = self.predict(prompt)
+
+ # logger.info(f"prompt: {prompt}")
+ # logger.info(f"llm_output: {llm_output}")
+        # Parse the LLM function-call output (a JSON object or list,
+        # optionally wrapped in a ```json fenced block)
+        import json  # local import; harmless if json is already imported at module level
+        if "```json" in llm_output:
+            match_value = re.search(r'```json\n(.*?)```', llm_output, re.DOTALL)
+            raw_output = match_value.group(1).strip() if match_value else llm_output.strip()
+        else:
+            raw_output = llm_output.strip()
+
+        try:
+            function_call_output = json.loads(raw_output)
+        except Exception:
+            function_call_output = eval(raw_output)
+
+ function_call_output = function_call_output if isinstance(function_call_output, list) \
+ else [function_call_output]
+ #
+ fc_response = LLMResponse(
+ choices=[Choice(
+ finish_reason="tool_calls",
+ message=LLMOuputMessage(
+ content=None,
+ role="assistant",
+ tool_calls=[
+ ToolCall(
+ function=FunctionCallData(
+ name=fco["name"],
+ arguments=fco["arguments"],
+ )
+ )
+ for fco in function_call_output
+ ],
+ )
+ )],
+ id="",
+ model="",
+ object="chat.completion",
+ usage=None
+ )
+ return fc_response
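+
+    # Illustrative model output that `fc` can parse (a hypothetical tool call):
+    #   ```json
+    #   [{"name": "get_weather", "arguments": {"city": "Hangzhou"}}]
+    #   ```
+    # which is converted into LLMResponse.choices[0].message.tool_calls.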
class OpenAILLMModel(CustomLLMModel):
@@ -128,6 +236,24 @@ def __init__(self, llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler
)
+class OllamaLLMModel(CustomLLMModel):
+ def __init__(self, llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None,):
+ self.llm = ollama.Client()
+ self.model_name = llm_config.model_name
+
+ def __call__(self, prompt: str,
+ stop: Optional[List[str]] = None):
+ stream = self.llm.generate(
+ model=self.model_name,
+ prompt=prompt,
+ stream=True,
+ )
+ answer = ""
+ for chunk in stream:
+ answer += chunk['response']
+ return answer
+
+
class KIMILLMModel(LYWWLLMModel):
pass
@@ -143,7 +269,7 @@ def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbac
model_class_dict = {
"openai": OpenAILLMModel, "lingyiwanwu": LYWWLLMModel,
"kimi": KIMILLMModel, "moonshot": KIMILLMModel,
- "qwen": QwenLLMModel,
+ "qwen": QwenLLMModel, "ollama": OllamaLLMModel
}
model_class = model_class_dict[llm_config.model_engine]
model = model_class(llm_config, callBack)
diff --git a/muagent/memory/__init__.py b/muagent/memory/__init__.py
deleted file mode 100644
index 719b23b..0000000
--- a/muagent/memory/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .hierarchical_memory_manager import HierarchicalMemoryManager
-
-__all__ = [
- "HierarchicalMemoryManager"
-]
\ No newline at end of file
diff --git a/muagent/memory_manager/__init__.py b/muagent/memory_manager/__init__.py
new file mode 100644
index 0000000..59215ce
--- /dev/null
+++ b/muagent/memory_manager/__init__.py
@@ -0,0 +1,13 @@
+from .hierarchical_memory_manager import HierarchicalMemoryManager
+from .base_memory_manager import BaseMemoryManager
+from .local_memory_manager import LocalMemoryManager
+from .tbase_memory_manager import TbaseMemoryManager
+
+
+
+__all__ = [
+ "BaseMemoryManager",
+ "LocalMemoryManager",
+ "TbaseMemoryManager",
+ "HierarchicalMemoryManager"
+]
\ No newline at end of file
diff --git a/muagent/memory_manager/base_memory_manager.py b/muagent/memory_manager/base_memory_manager.py
new file mode 100644
index 0000000..e2ad8e2
--- /dev/null
+++ b/muagent/memory_manager/base_memory_manager.py
@@ -0,0 +1,271 @@
+from abc import abstractmethod, ABC
+from typing import (
+ List,
+ Dict,
+ Optional
+)
+from loguru import logger
+
+from ..schemas import Memory, Message
+from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from ..schemas.models import ModelConfig
+from ..db_handler import *
+
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
+
+
+class BaseMemoryManager(ABC):
+ """
+ This class represents a local memory manager that inherits from BaseMemoryManager.
+
+ Attributes:
+ - memory_type: A string representing the memory type. Default is "recall".
+ - do_init: A boolean indicating whether to initialize. Default is False.
+ - recall_memory: An instance of Memory class representing the recall memory.
+ - save_message_keys: A list of strings representing the keys for saving messages.
+
+ Methods:
+ - __init__: Initializes the LocalMemoryManager with the given user_name, unique_name, memory_type, and do_init.
+ - init_vb: Initializes the vb.
+ - append: Appends a message to the recall memory, current memory, and summary memory.
+ - extend: Extends the recall memory, current memory, and summary memory.
+ - load: Loads the memory from the specified directory and returns a Memory instance.
+ - router_retrieval: Routes the retrieval based on the retrieval type.
+ - embedding_retrieval: Retrieves messages based on embedding.
+ - text_retrieval: Retrieves messages based on text.
+ - datetime_retrieval: Retrieves messages based on datetime.
+ - recursive_summary: Performs recursive summarization of messages.
+ """
+
+ memory_manager_type: str = "base_memory_manager"
+ """The type of memory manager"""
+
+ def __init__(
+ self,
+ vb_config: Optional[VBConfig] = None,
+ db_config: Optional[DBConfig] = None,
+ gb_config: Optional[GBConfig] = None,
+ tb_config: Optional[TBConfig] = None,
+ embed_config: Optional[ModelConfig] = None,
+ do_init: bool = False,
+ ):
+ """
+ Initializes the LocalMemoryManager with the given parameters.
+
+ Args:
+ - embed_config: EmbedConfig, the embedding model config
+ - llm_config: LLMConfig, the LLM model config
+ - db_config: DBConfig, the Database config
+ - vb_config: VBConfig, the vector base config
+ - gb_config: GBConfig, the graph base config
+ - do_init: A boolean indicating whether to initialize. Default is False.
+ """
+ self.db_config = db_config
+ self.vb_config = vb_config
+ self.gb_config = gb_config
+ self.tb_config = tb_config
+ self.embed_config = embed_config
+ self.do_init = do_init
+ self.recall_memory_dict: Dict[str, Memory] = {}
+ self.save_message_keys = [
+ 'session_index', 'role_name', 'role_type', 'role_prompt', 'input_query',
+ 'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
+ 'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
+
+ def init_handler(self, ):
+ """Initializes Database VectorBase GraphDB TbaseDB"""
+ self.init_vb()
+ self.init_tb()
+ self.init_db()
+ self.init_gb()
+
+ def reinit_handler(self, do_init: bool=False):
+ self.init_vb()
+ self.init_tb()
+ self.init_db()
+ self.init_gb()
+
+ def init_vb(self, do_init: bool=None):
+ """
+ Initializes the vb.
+ """
+ if self.vb_config:
+ table_init()
+ vb_dict = {"LocalFaissHandler": LocalFaissHandler}
+ vb_class = vb_dict.get(self.vb_config.vb_type, LocalFaissHandler)
+ self.vb: LocalFaissHandler = vb_class(self.embed_config, vb_config=self.vb_config)
+
+    def init_db(self, ):
+        """Initializes the relational database handler."""
+        if self.db_config:
+            db_dict = {"LocalFaissHandler": LocalFaissHandler}
+            db_class = db_dict.get(self.db_config.db_type)
+            if db_class is None:
+                raise ValueError(f"Unsupported db_type: {self.db_config.db_type}")
+            self.db = db_class(self.db_config)
+
+ def init_tb(self, do_init: bool=None):
+ """
+ Initializes the tb.
+ """
+ if self.tb_config:
+ tb_dict = {"TbaseHandler": TbaseHandler}
+ tb_class = tb_dict.get(self.tb_config.tb_type, TbaseHandler)
+ self.tb = tb_class(self.tb_config, self.tb_config.index_name)
+
+ def init_gb(self, do_init: bool=None):
+ """
+ Initializes the gb.
+ """
+ if self.gb_config:
+ gb_dict = {"NebulaHandler": NebulaHandler}
+ gb_class = gb_dict.get(self.gb_config.gb_type, NebulaHandler)
+ self.gb = gb_class(self.gb_config)
+
+ def append(self, message: Message, role_tag: str):
+ """
+ Appends a message to the recall memory, current memory, and summary memory.
+
+ Args:
+ - message: An instance of Message class representing the message to be appended.
+ """
+ pass
+
+ def extend(self, memory: Memory, role_tag: str):
+ """
+ Extends the recall memory, current memory, and summary memory.
+
+ Args:
+ - memory: An instance of Memory class representing the memory to be extended.
+ """
+ pass
+
+ def load(self, load_dir: str = "") -> Memory:
+ """
+ Loads the memory from the specified directory and returns a Memory instance.
+
+ Args:
+ - load_dir: A string representing the directory to load the memory from. Default is KB_ROOT_PATH.
+
+ Returns:
+ - An instance of Memory class representing the loaded memory.
+ """
+ pass
+
+ def get_memory_pool(self, session_index: str) -> Memory:
+ """
+ return memory_pool
+ """
+ pass
+
+ def search_messages(self, text: str=None, n=5, **kwargs) -> List[Message]:
+ """
+ return the search messages
+
+ Args:
+ - text: A string representing the text for retrieval. Default is None.
+ - n: An integer representing the number of messages. Default is 5.
+ """
+
+ def router_retrieval(self,
+ session_index: str = "default", text: str=None, datetime: str = None,
+ n=5, top_k=5, retrieval_type: str = "embedding", **kwargs
+ ) -> Memory:
+ """
+ Routes the retrieval based on the retrieval type.
+
+ Args:
+ - text: A string representing the text for retrieval. Default is None.
+ - datetime: A string representing the datetime for retrieval. Default is None.
+ - n: An integer representing the number of messages. Default is 5.
+ - top_k: An integer representing the top k messages. Default is 5.
+ - retrieval_type: A string representing the retrieval type. Default is "embedding".
+ - **kwargs: Additional keyword arguments for retrieval.
+
+ Returns:
+ - A list of Message instances representing the retrieved messages.
+ """
+        retrieval_func_dict = {
+            "embedding": self.embedding_retrieval,
+            "text": self.text_retrieval,
+            "datetime": self.datetime_retrieval
+        }
+
+        # Make sure a valid retrieval type was provided
+        if retrieval_type not in retrieval_func_dict:
+            raise ValueError(
+                f"Invalid retrieval_type: '{retrieval_type}'. "
+                f"Available types: {list(retrieval_func_dict.keys())}"
+            )
+
+        retrieval_func = retrieval_func_dict[retrieval_type]
+        # Forward only the retrieval arguments (avoid leaking helper locals)
+        params = dict(
+            session_index=session_index, text=text, datetime=datetime,
+            n=n, top_k=top_k, **kwargs
+        )
+        return retrieval_func(**params)
+
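+    # Illustrative dispatch (assuming a concrete subclass implements text_retrieval):
+    #   memory_manager.router_retrieval(
+    #       session_index="sess-1", text="redis index", retrieval_type="text")
+    #   # -> calls self.text_retrieval(session_index="sess-1", text="redis index", ...)
+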
+ def embedding_retrieval(self, text: str, embed_model="", top_k=1, score_threshold=1.0, **kwargs) -> Memory:
+ """
+ Retrieves messages based on embedding.
+
+ Args:
+ - text: A string representing the text for retrieval.
+ - embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
+ - top_k: An integer representing the top k messages. Default is 1.
+ - score_threshold: A float representing the score threshold. Default is SCORE_THRESHOLD.
+ - **kwargs: Additional keyword arguments for retrieval.
+
+ Returns:
+ - A list of Message instances representing the retrieved messages.
+ """
+ pass
+
+ def text_retrieval(self, text: str, **kwargs) -> Memory:
+ """
+ Retrieves messages based on text.
+
+ Args:
+ - text: A string representing the text for retrieval.
+ - **kwargs: Additional keyword arguments for retrieval.
+
+ Returns:
+ - A list of Message instances representing the retrieved messages.
+ """
+ pass
+
+ def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> Memory:
+ """
+ Retrieves messages based on datetime.
+
+ Args:
+ - datetime: A string representing the datetime for retrieval.
+ - text: A string representing the text for retrieval. Default is None.
+ - n: An integer representing the number of messages. Default is 5.
+ - **kwargs: Additional keyword arguments for retrieval.
+
+ Returns:
+ - A list of Message instances representing the retrieved messages.
+ """
+ pass
+
+ def recursive_summary(self, messages: List[Message], split_n: int = 20) -> Memory:
+ """
+ Performs recursive summarization of messages.
+
+ Args:
+ - messages: A list of Message instances representing the messages to be summarized.
+ - split_n: An integer representing the split n. Default is 20.
+
+ Returns:
+ - A list of Message instances representing the summarized messages.
+ """
+ pass
+
+ def reranker(self, ):
+ """
+ rerank the retrieval message from memory
+ """
+ pass
+
diff --git a/muagent/memory/hierarchical_memory_manager.py b/muagent/memory_manager/hierarchical_memory_manager.py
similarity index 96%
rename from muagent/memory/hierarchical_memory_manager.py
rename to muagent/memory_manager/hierarchical_memory_manager.py
index 180dc35..99b40f8 100644
--- a/muagent/memory/hierarchical_memory_manager.py
+++ b/muagent/memory_manager/hierarchical_memory_manager.py
@@ -15,8 +15,8 @@
from muagent.connector.memory_manager import BaseMemoryManager
from muagent.llm_models import *
from muagent.base_configs.env_config import KB_ROOT_PATH
-from muagent.orm import table_init
-
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
from muagent.utils.common_utils import *
diff --git a/muagent/memory_manager/local_memory_manager.py b/muagent/memory_manager/local_memory_manager.py
new file mode 100644
index 0000000..bcb3180
--- /dev/null
+++ b/muagent/memory_manager/local_memory_manager.py
@@ -0,0 +1,443 @@
+from abc import abstractmethod, ABC
+from typing import List, Dict, Optional, Union
+import os, sys, copy, json, uuid, random
+from jieba.analyse import extract_tags
+from collections import Counter
+from loguru import logger
+import numpy as np
+
+from langchain_community.docstore.document import Document
+
+
+from .base_memory_manager import BaseMemoryManager
+
+from ..schemas import Memory, Message
+from ..schemas.models import ModelConfig
+from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+
+from ..models import get_model
+
+from muagent.connector.configs.generate_prompt import *
+from muagent.db_handler import *
+from muagent.llm_models import getChatModelFromConfig
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.utils.common_utils import *
+from muagent.base_configs.env_config import KB_ROOT_PATH
+
+
+class LocalMemoryManager(BaseMemoryManager):
+ """This class represents a LocalMemoryManager that inherits from BaseMemoryManager.
+ It provides functionalities to handle local memory storage and retrieval of messages.
+ """
+ memory_manager_type: str = "local_memory_manager"
+ """The type of memory manager"""
+
+ def __init__(
+ self,
+ embed_config: Union[ModelConfig, EmbedConfig],
+ llm_config: Union[LLMConfig, ModelConfig],
+ vb_config: Optional[VBConfig] = None,
+ db_config: Optional[DBConfig] = None,
+ gb_config: Optional[GBConfig] = None,
+ tb_config: Optional[TBConfig] = None,
+ do_init: bool = False,
+ kb_root_path: str = KB_ROOT_PATH,
+ ):
+ """Initialize the LocalMemoryManager with configurations.
+
+ Args:
+ embed_config (Union[ModelConfig, EmbedConfig]): Configuration for embedding.
+ llm_config (Union[LLMConfig, ModelConfig]): Configuration for LLM.
+ vb_config (Optional[VBConfig], optional): Vector database configuration.
+ db_config (Optional[DBConfig], optional): Database configuration.
+ gb_config (Optional[GBConfig], optional): Graph database configuration.
+ tb_config (Optional[TBConfig], optional): Tbase configuration.
+ do_init (bool, optional): Flag indicating if initialization is required.
+ kb_root_path (str, optional): Path for storing knowledge base files (default is KB_ROOT_PATH).
+ """
+ super().__init__(
+ vb_config or VBConfig(vb_type="LocalFaissHandler"),
+ db_config, gb_config, tb_config,
+ embed_config
+ )
+
+ self.do_init = do_init
+ self.kb_root_path = kb_root_path
+ self.embed_config: Union[ModelConfig, EmbedConfig] = embed_config
+ self.llm_config: Union[LLMConfig, ModelConfig] = llm_config
+
+ # default
+ self.session_index: str = "default"
+ self.kb_name = f"{self.session_index}"
+ self.uuid_file = os.path.join(
+ self.kb_root_path, f"{self.session_index}/conversation.jsonl")
+
+ self.recall_memory_dict: Dict[str, Memory] = {}
+ self.memory_uuids = set()
+ self.save_message_keys = [
+ 'session_index', 'message_index', 'role_name', 'role_type', 'content',
+ 'input_text', 'role_tags', 'content', 'step_content',
+ 'parsed_content', 'spec_parsed_contents', 'global_kwargs',
+ 'start_datetime', 'end_datetime',
+ "keyword", "vector"
+ ]
+ # init from config
+ if isinstance(self.llm_config, LLMConfig):
+ self.model = getChatModelFromConfig(self.llm_config)
+ else:
+ self.model = get_model(self.llm_config)
+ self.init_handler()
+ self.load(do_init)
+
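+    # Illustrative usage sketch (config objects come from the caller, e.g. a test config;
+    # Message fields follow their use elsewhere in this module):
+    #   manager = LocalMemoryManager(embed_config=embed_config, llm_config=llm_config)
+    #   manager.append(Message(session_index="sess-1", role_name="user",
+    #                          role_type="user", content="hi"))
+    #   memory = manager.get_memory_pool("sess-1")
+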
+ def clear_local(self, re_init: bool = False, handler_type: str = None):
+ """Clear local memory and reinitialize if specified.
+
+ Args:
+ re_init (bool, optional): Whether to reinitialize after clearing.
+ handler_type (str, optional): Type of handler to use (currently unused).
+ """
+        if self.vb:  # the vector store persisted locally and needs to be cleared
+ self.vb.clear_vs_local()
+ self.load(re_init)
+
+ def append(self, message: Message, role_tag: str=None) -> None:
+ """Append a message to the local memory and update vector store if necessary.
+
+ Args:
+ message (Message): The message to append.
+ role_tag (str, optional): An optional role tag for the message.
+ """
+ # update the newest uuid_name
+ self.check_uuid_name(message)
+ datetimes = self.recall_memory_dict[self.session_index].get_datetimes()
+ contents = self.recall_memory_dict[self.session_index].get_contents()
+ message_indexes = self.recall_memory_dict[
+ self.session_index].get_memory_values("message_index")
+ # if message not in chat history, no need to update
+ if message.message_index in message_indexes:
+ self.update2vb(message, role_tag)
+ elif ((message.end_datetime not in datetimes) or
+ ((message.input_text not in contents) and (message.content not in contents))
+ ):
+ self.append2vb(message, role_tag)
+
+ def append2vb(self, message: Message, role_tag: str=None) -> None:
+ """Append a message and its embeddings to the vector store (VB).
+
+ Args:
+ message (Message): The message to append to the vector store.
+ role_tag (str, optional): Optional role tag for the message.
+ """
+ if role_tag:
+ if isinstance(message.role_tags, list):
+ message.role_tags = list(set(message.role_tags + [role_tag]))
+ else:
+ message.role_tags += f", {role_tag}"
+ self.recall_memory_dict[self.session_index].append(message)
+ memory = self.recall_memory_dict[self.session_index]
+ #
+ docs, json_messages = self.message_process([message])
+ if self.embed_config:
+ self.vb.add_docs(docs, kb_name=self.kb_name)
+ #
+ if True: # resave the local
+ _, json_messages = self.message_process(memory.messages)
+ save_to_json_file(json_messages, self.uuid_file)
+
+
+ def update2vb(self, message: Message, role_tag: str=None) -> None:
+ """Update an existing message in the vector store.
+
+ Args:
+ message (Message): The message to update.
+ role_tag (str, optional): Optional role tag for the message.
+ """
+ memory = self.recall_memory_dict[self.session_index]
+ memory.update(message, role_tag)
+
+ #
+ docs, json_messages = self.message_process([message])
+ # if self.embed_config:
+ # # search
+ # # delete
+ # # add
+ # self.vb.add_docs(docs, kb_name=self.kb_name)
+ #
+ if True: # resave the local
+ _, json_messages = self.message_process(memory.messages)
+ save_to_json_file(json_messages, self.uuid_file)
+
+
+ def extend(self, memory: Memory, role_tag: str=None):
+ """Append multiple messages from a Memory object to local memory.
+
+ Args:
+ memory (Memory): The Memory object containing messages to append.
+ role_tag (str, optional): An optional role tag for messages.
+ """
+ for message in memory.messages:
+ self.append(message, role_tag)
+
+ def message_process(self, messages: List[Message]):
+ """Convert message objects to vector store/local data format.
+
+ Args:
+ messages (List[Message]): List of messages to process.
+
+ Returns:
+ Tuple[List[Document], dict]: Tuple containing documents for vector storage and a JSON representation of messages.
+ """
+ messages = [{
+ k: v for k, v in m.dict().items()
+ if k in self.save_message_keys
+ }
+ for m in messages
+ ]
+ docs = [{
+ "page_content": m["step_content"] or m["content"] or m["input_text"],
+ "metadata": m}
+ for m in messages
+ ]
+ docs = [Document(**doc) for doc in docs]
+ # convert messages to local data-format
+ memory_messages = self.recall_memory_dict[self.session_index].dict()
+ json_messages = {
+ k: [
+ {kkk: vvv for kkk, vvv in vv.items()
+ if kkk in self.save_message_keys}
+ for vv in v
+ ]
+ for k, v in memory_messages.items()
+ }
+
+ return docs, json_messages
+
+ def load(self, re_init=False) -> Memory:
+ """Load memory from files in the specified database root path.
+
+ Args:
+ re_init (bool, optional): Flag indicating if reinitialization of memory should occur.
+
+ Returns:
+ Memory: Loaded messages in memory format.
+ """
+ if not re_init:
+ for root, dirs, files in os.walk(self.kb_root_path):
+ for file in files:
+ if file != 'conversation.jsonl': continue
+ file_path = os.path.join(root, file)
+ # get uuid_name
+ relative_path = os.path.relpath(root, self.kb_root_path)
+ path_parts = relative_path.split(os.sep)
+ uuid_name = "_".join(path_parts)
+ # load to local cache
+ recall_memory = Memory(**read_json_file(file_path))
+ self.recall_memory_dict[uuid_name] = recall_memory
+ else:
+ self.recall_memory_dict = {}
+
+ def get_memory_pool(self, session_index: str = "") -> Memory:
+ """Retrieve the memory pool for a specific session index.
+
+ Args:
+ session_index (str, optional): Session index (default is empty string).
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ return self.recall_memory_dict.get(session_index, Memory(messages=[]))
+
+ def embedding_retrieval(
+ self,
+ text: str,
+ top_k=1,
+ score_threshold=0.7,
+ session_index: str = "default",
+ **kwargs
+    ) -> Memory:
+ """Retrieve messages based on text embedding.
+
+ Args:
+ text (str): The input text for embedding retrieval.
+ top_k (int, optional): The number of top results to retrieve (default is 1).
+ score_threshold (float, optional): Minimum score for message retrieval (default is 0.7).
+ session_index (str, optional): Session identifier (default is "default").
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ if text is None: return Memory(messages=[])
+
+ # kb_name = self.get_vbname_from_sessionindex(session_index)
+ kb_name = session_index
+ docs = self.vb.search(
+ text,
+ top_k=top_k,
+ score_threshold=score_threshold,
+ kb_name=kb_name
+ )
+ return Memory(messages=[Message(**doc.metadata) for doc, score in docs])
+
+ def text_retrieval(
+ self,
+ text: str,
+ session_index: str = "default",
+ **kwargs
+ ) -> Memory:
+ """Retrieve messages based on text content.
+
+ Args:
+ text (str): The text to match against messages.
+ session_index (str, optional): Session identifier (default is "default").
+
+ Returns:
+ Memory: Messages matching the text content.
+ """
+ if text is None: return Memory(messages=[])
+
+ # uuid_name = self.get_uuid_from_sessionindex(session_index)
+ messages = self.recall_memory_dict.get(
+ session_index, Memory(messages=[])).messages
+ return self._text_retrieval_from_cache(
+ messages, text, score_threshold=0.3, topK=5, **kwargs
+ )
+
+ def datetime_retrieval(
+ self,
+ session_index: str,
+ datetime: str,
+ text: str = None,
+ n: int = 5,
+ key: str = "start_datetime",
+ **kwargs
+ ) -> Memory:
+ """Retrieve messages based on date and time criteria.
+
+ Args:
+ session_index (str): The session index to filter messages.
+ datetime (str): The datetime string reference for filtering.
+ text (str, optional): Optional text to match with messages.
+ n (int, optional): Number of minutes to define the range (default is 5).
+ key (str, optional): The key for datetime filtering (default is "start_datetime").
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ if datetime is None: return Memory(messages=[])
+
+ # uuid_name = self.get_uuid_from_sessionindex(session_index)
+ messages = self.recall_memory_dict.get(
+ session_index, Memory(messages=[])).messages
+ return self._datetime_retrieval_from_cache(
+ messages, datetime, text, n, **kwargs
+ )
+
+ def _text_retrieval_from_cache(
+ self,
+ messages: List[Message],
+ text: str = None,
+ score_threshold=0.3,
+ topK=5,
+ tag_topK=5,
+ **kwargs
+ ) -> Memory:
+ keywords = extract_tags(text, topK=tag_topK)
+
+ matched_messages = []
+ for message in messages:
+ content = message.step_content or message.input_text or message.content
+ message_keywords = extract_tags(content, topK=tag_topK)
+            # calculate jaccard similarity between query and message keywords
+            intersection = Counter(keywords) & Counter(message_keywords)
+            union = Counter(keywords) | Counter(message_keywords)
+            union_total = sum(union.values())
+            similarity = sum(intersection.values()) / union_total if union_total else 0.0
+            if similarity >= score_threshold:
+                matched_messages.append((message, similarity))
+        # keep the most similar messages first
+        matched_messages = sorted(matched_messages, key=lambda x: x[1], reverse=True)
+        return Memory(messages=[m for m, s in matched_messages][:topK])
+
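+    # Illustrative similarity: query keywords ["redis", "index"] vs message keywords
+    # ["redis", "search"] -> intersection {"redis": 1}, union size 3, similarity = 1/3.
+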
+ def _datetime_retrieval_from_cache(
+ self,
+ messages: List[Message],
+ datetime: str,
+ text: str = None,
+ n: int = 5, **kwargs
+ ) -> Memory:
+ # select message by datetime
+ datetime_before, datetime_after = addMinutesToTimestamp(datetime, n)
+ select_messages = [
+ message for message in messages
+ if datetime_before<=dateformatToTimestamp(message.end_datetime, 1, message.datetime_format)<=datetime_after
+ ]
+ return self._text_retrieval_from_cache(select_messages, text)
+
+ def recursive_summary(
+ self,
+ messages: List[Message],
+ split_n: int = 20,
+ session_index: str=""
+ ) -> Memory:
+ """Generate a recursive summary of the provided messages.
+
+ Args:
+ messages (List[Message]): List of messages to summarize.
+ split_n (int, optional): Number of messages to include in each summary pass (default is 20).
+ session_index (str, optional): Session identifier for the summary.
+
+ Returns:
+ Memory: Updated messages including the summary.
+ """
+ if len(messages) == 0:
+ return Memory(messages=[])
+
+ newest_messages = messages[-split_n:]
+ summary_messages = messages[:max(0, len(messages)-split_n)]
+
+ while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"):
+ message = newest_messages.pop(0)
+ summary_messages.append(message)
+
+ # summary
+ summary_content = '\n\n'.join([
+ m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.spec_parsed_contents for k, v in parsed_output.items() if k not in ['Action Status']]))
+ for m in summary_messages if m.role_type not in ["summary"]
+ ])
+
+ summary_prompt = createSummaryPrompt(conversation=summary_content)
+ content = self.model.predict(summary_prompt)
+ summary_message = Message(
+ session_index=session_index,
+ role_name="summaryer",
+ role_type="summary",
+ content=content,
+ step_content=content,
+ spec_parsed_contents=[],
+ global_kwargs={}
+ )
+ summary_message.spec_parsed_contents.append({"summary": content})
+ newest_messages.insert(0, summary_message)
+
+ return Memory(messages=newest_messages)
+
+    def check_uuid_name(self, message: Message = None):
+        """Switch the active session and storage paths when the message belongs to a new session."""
+        if message.session_index != self.session_index:
+ self.session_index = message.session_index
+ # self.init_vb()
+
+ self.kb_name = self.session_index
+ self.uuid_file = os.path.join(self.kb_root_path, f"{self.session_index}/conversation.jsonl")
+
+ self.memory_uuids.add(self.session_index)
+ if self.session_index not in self.recall_memory_dict:
+ self.recall_memory_dict[self.session_index] = Memory(messages=[])
+
+ def modified_message(self, message: Message, update_rule_text: str) -> Message:
+        # Build a prompt (zh) asking the model to revise the current message content
+        # according to the given update rule and generate the new content.
+        prompt = f"结合以下更新内容修改当前消息内容:\n更新内容: {update_rule_text}\n\n当前消息内容:\n{message.role_content}\n\n请生成新的消息内容:"
+
+ new_content = self.model.predict(prompt)
+
+ message.content = new_content
+
+ return message
\ No newline at end of file
diff --git a/muagent/memory_manager/tbase_memory_manager.py b/muagent/memory_manager/tbase_memory_manager.py
new file mode 100644
index 0000000..49747d3
--- /dev/null
+++ b/muagent/memory_manager/tbase_memory_manager.py
@@ -0,0 +1,628 @@
+from typing import (
+    List,
+    Dict,
+    Union,
+    Optional,
+)
+import numpy as np
+from jieba.analyse import extract_tags
+import random
+from collections import Counter
+from loguru import logger
+import json
+
+from .base_memory_manager import BaseMemoryManager
+
+from ..schemas import Memory, Message
+from ..schemas.models import ModelConfig
+from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+
+from ..db_handler import *
+from ..models import get_model
+
+
+from muagent.llm_models import getChatModelFromConfig
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.connector.configs.generate_prompt import *
+from muagent.utils.common_utils import *
+
+from muagent.llm_models.get_embedding import get_embedding
+from redis.commands.search.field import (
+ TextField,
+ NumericField,
+ VectorField,
+ TagField
+)
+
+DIM = 768
+MESSAGE_SCHEMA = [
+ TextField("session_index", ),
+ TextField("message_index", ),
+ TextField("node_index"),
+ TextField("role_name",),
+ TextField("role_type", ),
+ TextField('input_text'),
+ TextField("content", ),
+ TextField("role_tags"),
+ TextField("parsed_output"),
+ TextField("global_kwargs",),
+ NumericField("start_datetime",) ,
+ NumericField("end_datetime",),
+ VectorField("vector",
+ 'FLAT',
+ {
+ "TYPE": "FLOAT32",
+ "DIM": DIM,
+ "DISTANCE_METRIC": "COSINE"
+ }),
+ TagField(name='keyword', separator='|')
+]
+
+
+
+class TbaseMemoryManager(BaseMemoryManager):
+ """
+ This class represents a TbaseMemoryManager that inherits from BaseMemoryManager.
+ """
+
+    memory_manager_type = "tbase_memory_manager"
+ """The type of memory manager for identification purposes."""
+
+ def __init__(
+ self,
+ embed_config: Union[ModelConfig, EmbedConfig],
+ llm_config: Union[LLMConfig, ModelConfig],
+ tbase_handler: TbaseHandler = None,
+ use_vector: bool = False,
+ vb_config: Optional[VBConfig] = None,
+ db_config: Optional[DBConfig] = None,
+ gb_config: Optional[GBConfig] = None,
+ tb_config: Optional[TBConfig] = None,
+ do_init: bool = False,
+ ):
+ """Initialize the TbaseMemoryManager with specified configurations.
+
+ Args:
+ embed_config (Union[ModelConfig, EmbedConfig]): Configuration for embedding.
+ llm_config (Union[LLMConfig, ModelConfig]): Configuration for the LLM.
+ tbase_handler (TbaseHandler, optional): Handler for Tbase database access.
+ use_vector (bool, optional): Flag to specify whether to use vector embeddings.
+ vb_config (Optional[VBConfig], optional): Configuration for the vector database.
+ db_config (Optional[DBConfig], optional): Configuration for the main database.
+ gb_config (Optional[GBConfig], optional): Configuration for graph database.
+ tb_config (Optional[TBConfig], optional): Configuration for Tbase.
+ do_init (bool, optional): Flag to indicate if initialization is required.
+ """
+
+ super().__init__(vb_config, db_config, gb_config, tb_config)
+ self.do_init = do_init
+ self.embed_config: Union[ModelConfig, EmbedConfig] = embed_config
+ self.llm_config: Union[LLMConfig, ModelConfig] = llm_config
+ self.tb: TbaseHandler = tbase_handler
+ self.save_message_keys = [
+ 'session_index', 'message_index', 'node_index', 'role_name', 'role_type', 'content',
+ 'input_text', 'role_tags', 'content', 'step_content',
+ 'parsed_content', 'spec_parsed_contents', 'global_kwargs',
+ 'start_datetime', 'end_datetime',
+ "keyword", "vector"
+ ]
+ self.use_vector = use_vector
+ self.init_handler()
+ self.init_tb_index()
+
+ def init_tb_index(self, do_init: bool=None):
+ """Initialize the Tbase index if it does not already exist.
+
+ Args:
+ do_init (bool, optional): Optional flag for initialization (unused here).
+ """
+ # Create index if it does not exist
+ if not self.tb.is_index_exists():
+ res = self.tb.create_index(schema=MESSAGE_SCHEMA)
+ logger.info(res)
+
+ def append(self, message: Message, role_tag: str=None) -> None:
+ """Append a message to the Tbase memory.
+
+ Args:
+ message (Message): The message to be appended.
+ role_tag (str, optional): Optional role tag for the message.
+ """
+ tbase_message = self.localMessage2TbaseMessage(message, role_tag) # Convert local message to Tbase format
+ self.tb.insert_data_hash(tbase_message) # Insert into Tbase
+
+ def extend(self, memory: Memory, role_tag: str=None) -> None:
+ """Append multiple messages from memory to Tbase.
+
+ Args:
+ memory (Memory): The memory containing messages to append.
+ role_tag (str, optional): Optional role tag for all messages.
+ """
+ for message in memory.messages:
+ self.append(message, role_tag) # Append each message
+
+ def append_tools(self, tool_information: dict, session_index: str, nodeid: str, node_index: str="default") -> None:
+ """Append tool-related information to Tbase as messages.
+
+ Args:
+ tool_information (dict): Dictionary containing tool information.
+ session_index (str): Session identifier.
+ nodeid (str): Graph node ID.
+ node_index (str, optional): Node index for differentiating nodes.
+ """
+ tool_map = {
+ "toolKey": {"role_name": "tool_selector", "role_type": "assistant",
+ "customed_keys": ["toolDef"]
+ },
+ "toolParam": {"role_name": "tool_filler", "role_type": "assistant"},
+ "toolResponse": {"role_name": "function_caller", "role_type": "observation"},
+ "toolSummary": {"role_name": "function_summary", "role_type": "Summary"},
+ }
+
+        for k, v in tool_map.items():
+            if k not in tool_information:
+                continue
+            try:
+                message = Message(
+                    session_index=session_index,
+                    message_index=f"{nodeid}_{k}",
+                    node_index=node_index,
+                    role_name=v["role_name"],     # Assign role name
+                    role_type=v["role_type"],     # Assign role type
+                    content=tool_information[k],  # Assign tool information content
+                    global_kwargs={
+                        kk: vv for kk, vv in tool_information.items()
+                        if kk in v.get("customed_keys", [])
+                    }  # Store additional tool information
+                )
+                self.append(message)  # Append the message to Tbase
+            except Exception as e:
+                logger.warning(f"skip tool field '{k}': {e}")
+
+
+ def get_memory_by_sessionindex_tags(self, session_index: str, tags: List[str], limit: int = 10) -> Memory:
+ """Retrieve memory messages by session index and tags.
+
+ Args:
+ session_index (str): The session index to search for.
+ tags (List[str]): List of tags to match against messages.
+ limit (int, optional): The maximum number of messages to retrieve (default is 10).
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ tags_str = '|'.join([f"*{tag}*" for tag in tags]) # Create a tags search string
+ querys = [
+ f"@session_index:{session_index}", # Query for session index
+ f"@role_tags:{tags_str}", # Query for role tags
+ ]
+ query = f"({')('.join(querys)})" if len(querys) >=2 else "".join(querys) # Combine queries
+ r = self.tb.search(query, limit=limit) # Search Tbase
+ return self.tbasedoc2Memory(r) # Convert results to Memory format
+
+ def get_memory_by_chatindex_tags(self, chat_index: str, tags: List[str], limit: int = 10) -> Memory:
+ """Retrieve memory messages by chat index and tags.
+
+ Args:
+ chat_index (str): The chat index to search for.
+ tags (List[str]): List of tags to match against messages.
+ limit (int, optional): The maximum number of messages to retrieve (default is 10).
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ tags_str = '|'.join([f"*{tag}*" for tag in tags]) # Create a tags search string
+ querys = [
+ f"@session_index:{chat_index}", # Query for session index
+ f"@role_tags:{tags_str}", # Query for role tags
+ ]
+ query = f"({')('.join(querys)})" if len(querys) >=2 else "".join(querys) # Combine queries
+ logger.debug(f"{query}")
+ r = self.tb.search(query, limit=limit) # Search Tbase
+ return self.tbasedoc2Memory(r) # Convert results to Memory format
+
+ def get_memory_pool(self, session_index: str = "") -> Memory:
+ """Get the memory pool for a specific session index.
+
+ Args:
+ session_index (str, optional): Session index (default is empty string).
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ return self.get_memory_pool_by_all({"session_index": session_index}) # Retrieve all memory for session
+
+ def get_memory_pool_by_content(self, content: str) -> Memory:
+ """Get memory pool based on content search.
+
+ Args:
+ content (str): Content to search for in messages.
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ r = self.tb.search(content) # Search Tbase
+ return self.tbasedoc2Memory(r) # Convert results to Memory format
+
+ def get_memory_pool_by_key_content(self, key: str, content: str) -> Memory:
+ """Get memory pool based on key and content search.
+
+ Args:
+ key (str): Key to search for in messages.
+ content (str): Content to search for in messages.
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ if key == "keyword":
+ query = f"@{key}:{{{content}}}" # Special handling for keywords
+ else:
+ query = f"@{key}:{content}" # General query
+ r = self.tb.search(query) # Search Tbase
+ return self.tbasedoc2Memory(r) # Convert results to Memory format
+
+ def get_memory_pool_by_all(self, search_key_contents: dict, limit: int =10) -> Memory:
+ """Get memory pool based on multiple search criteria.
+
+ Args:
+ search_key_contents (dict): Dictionary containing key-value pairs for searching messages.
+ limit (int, optional): The maximum number of messages to retrieve (default is 10).
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ querys = []
+ for k, v in search_key_contents.items():
+ if not v: continue
+ if k == "keyword":
+ querys.append(f"@{k}:{{{v}}}")
+ elif k == "role_tags":
+ tags_str = '|'.join([f"*{tag}*" for tag in v]) if isinstance(v, list) else f"{v}"
+ querys.append(f"@role_tags:{tags_str}")
+ elif k == "start_datetime":
+ query = f"(@start_datetime:[{v[0]} {v[1]}])"
+ querys.append(query)
+ elif k == "end_datetime":
+ query = f"(@end_datetime:[{v[0]} {v[1]}])"
+ querys.append(query)
+ else:
+ querys.append(f"@{k}:{v}")
+
+ query = f"({')('.join(querys)})" if len(querys) >=2 else "".join(querys)
+ r = self.tb.search(query, limit=limit)
+ return self.tbasedoc2Memory(r)
+
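+    # Illustrative filter (keys map to MESSAGE_SCHEMA fields; timestamps are in milliseconds):
+    #   manager.get_memory_pool_by_all({
+    #       "session_index": "sess-1",
+    #       "role_tags": ["plan"],
+    #       "start_datetime": [1700000000000, 1700003600000],
+    #   })
+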
+ def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, session_index: str = "default", **kwargs) -> Memory:
+ """Retrieve memory using vector embeddings based on input text.
+
+ Args:
+ text (str): The input text for which embeddings are generated.
+ top_k (int, optional): Number of top results to retrieve (default is 1).
+ score_threshold (float, optional): Minimum score for fetching results (default is 1.0).
+ session_index (str, optional): Session identifier (default is "default").
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+        if text is None: return Memory(messages=[])
+        if not (self.use_vector and self.embed_config):
+            logger.error(
+                f"can't use vector search: use_vector={self.use_vector}, "
+                f"embed_config={'set' if self.embed_config else 'unset'}"
+            )
+            return Memory(messages=[])
+
+        query_embedding = self._get_embedding_array(text)
+
+        base_query = f'(@session_index:{session_index})=>[KNN {top_k} @vector $vector AS distance]'
+        query_params = {"vector": query_embedding}
+        r = self.tb.vector_search(base_query, query_params=query_params)
+        return self.tbasedoc2Memory(r)
+
+ def text_retrieval(self, text: str, session_index: str = "default", **kwargs) -> Memory:
+ """Retrieve messages based on text content and session index.
+
+ Args:
+ text (str): The text to search for.
+ session_index (str, optional): Session identifier (default is "default").
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+ keywords = extract_tags(text, topK=-1)
+ if len(keywords) > 0:
+ keyword = "|".join(keywords)
+ query = f"(@session_index:{session_index})(@keyword:{{{keyword}}})"
+ else:
+ query = f"@session_index:{session_index}"
+ # logger.debug(f"text_retrieval query: {query}")
+ r = self.tb.search(query)
+ memory = self.tbasedoc2Memory(r)
+ return self._text_retrieval_from_cache(memory.messages, text)
+
+ def datetime_retrieval(
+ self,
+ session_index: str,
+ datetime: str,
+ text: str = None,
+ n: int = 5,
+ key: str = "start_datetime",
+ **kwargs
+ ) -> Memory:
+ """Retrieve messages based on datetime range and session index.
+
+ Args:
+ session_index (str): The session index to filter messages.
+ datetime (str): The timestamp used for filtering messages.
+ text (str, optional): Optional text to retrieve alongside datetime.
+ n (int, optional): Number of minutes to define the range (default is 5).
+ key (str, optional): The key for datetime filtering (default is "start_datetime").
+
+ Returns:
+ Memory: Retrieved messages in memory format.
+ """
+
+        input_timestamp = None
+        for datetime_format in ["%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"]:
+            try:
+                input_timestamp = dateformatToTimestamp(datetime, 1000, datetime_format)
+                break
+            except Exception:
+                pass
+        if input_timestamp is None:
+            raise ValueError(
+                "can't parse datetime; expected '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d %H:%M:%S.%f'"
+            )
+
+        query = f"(@session_index:{session_index})(@{key}:[{input_timestamp-n*60*1000} {input_timestamp+n*60*1000}])"
+ # logger.debug(f"datetime_retrieval query: {query}")
+ r = self.tb.search(query)
+ memory = self.tbasedoc2Memory(r)
+ return self._text_retrieval_from_cache(memory.messages, text)
+
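+    # Illustrative query built above (placeholders for the computed timestamps):
+    #   (@session_index:sess-1)(@start_datetime:[<ts - n*60*1000> <ts + n*60*1000>])
+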
+ def _text_retrieval_from_cache(
+ self,
+ messages: List[Message],
+ text: str = None,
+ score_threshold=0.3,
+ topK=5,
+ tag_topK=5,
+ **kwargs
+ ) -> Memory:
+ """Retrieve messages based on text similarity from cached messages."""
+
+ if text is None:
+ return Memory(messages=messages[:topK])
+
+ if len(messages) < topK:
+ return Memory(messages=messages)
+
+ keywords = extract_tags(text, topK=tag_topK)
+
+ matched_messages = []
+ for message in messages:
+ message_keywords = extract_tags(
+ message.step_content or message.content or message.input_text,
+ topK=tag_topK
+ )
+            # calculate jaccard similarity between query and message keywords
+            intersection = Counter(keywords) & Counter(message_keywords)
+            union = Counter(keywords) | Counter(message_keywords)
+            union_total = sum(union.values())
+            similarity = sum(intersection.values()) / union_total if union_total else 0.0
+            if similarity >= score_threshold:
+                matched_messages.append((message, similarity))
+        # keep the most similar messages first
+        matched_messages = sorted(matched_messages, key=lambda x: x[1], reverse=True)
+        return Memory(messages=[m for m, s in matched_messages][:topK])
+
+ def recursive_summary(
+ self,
+ messages: List[Message],
+ session_index: str,
+ split_n: int = 20
+ ) -> Memory:
+ """Generate a recursive summary of the provided messages.
+
+ Args:
+ messages (List[Message]): List of messages to summarize.
+ session_index (str): Session identifier for the summary.
+ split_n (int, optional): Number of messages to include in each summary pass (default is 20).
+
+ Returns:
+ Memory: Updated messages including the summary.
+ """
+
+ if len(messages) == 0:
+ return Memory(messages=messages)
+
+ newest_messages = messages[-split_n:]
+ summary_messages = messages[:len(messages)-split_n]
+
+ while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"):
+ message = newest_messages.pop(0)
+ summary_messages.append(message)
+
+ # summary
+ model = self._get_model()
+ summary_content = '\n\n'.join([
+ m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.spec_parsed_contents for k, v in parsed_output.items() if k not in ['Action Status']]))
+ for m in summary_messages if m.role_type not in ["summary"]
+ ])
+ # summary_prompt = CONV_SUMMARY_PROMPT_SPEC.format(conversation=summary_content)
+ summary_prompt = createSummaryPrompt(conversation=summary_content)
+ logger.debug(f"{summary_prompt}")
+ content = model.predict(summary_prompt)
+ summary_message = Message(
+ session_index=session_index,
+ role_name="summaryer",
+ role_type="summary",
+ content=content,
+ step_content=content,
+            spec_parsed_contents=[],
+ global_kwargs={}
+ )
+ summary_message.spec_parsed_contents.append({"summary": content})
+ newest_messages.insert(0, summary_message)
+ return Memory(messages=newest_messages)
+
+ def localMessage2TbaseMessage(self, message: Message, role_tag: str= None):
+ """Convert a local Message object to a format suitable for Tbase storage."""
+
+ r = self.tb.search(f"@message_index: {message.message_index}")
+ history_role_tags = json.loads(r.docs[0].role_tags) if r.total == 1 else []
+
+ tbase_message = {}
+ for k, v in message.dict().items():
+ v = list(set(history_role_tags+[role_tag])) if k=="role_tags" and role_tag else v
+ if isinstance(v, dict) or isinstance(v, list):
+ v = json.dumps(v, ensure_ascii=False)
+ tbase_message[k] = v
+
+ tbase_message["start_datetime"] = dateformatToTimestamp(message.start_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f")
+ tbase_message["end_datetime"] = dateformatToTimestamp(message.end_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f")
+
+ if self.use_vector and self.embed_config:
+ tbase_message["vector"] = self._get_embedding_array(message.content)
+ tbase_message["keyword"] = " | ".join(extract_tags(message.content, topK=-1)
+ + [tbase_message["message_index"].split("-")[0]])
+
+ tbase_message = {
+ k: v for k, v in tbase_message.items()
+ if k in self.save_message_keys
+ }
+ return tbase_message
+
+ def tbasedoc2Memory(self, r_docs) -> Memory:
+ """Convert Tbase documents back into Memory objects."""
+
+ memory = Memory()
+ for doc in r_docs.docs:
+ tbase_message = {}
+ for k, v in doc.__dict__.items():
+ if k in ["content", "input_text"]:
+ tbase_message[k] = v
+ continue
+ try:
+ v = json.loads(v)
+ except:
+ pass
+
+ tbase_message[k] = v
+
+ message = Message(**tbase_message)
+ memory.append(message)
+
+ for message in memory.messages:
+ message.start_datetime = timestampToDateformat(int(message.start_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f")
+ message.end_datetime = timestampToDateformat(int(message.end_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f")
+
+ memory.sort_by_key("end_datetime")
+ return memory
+
+
+ def init_global_msg(self, session_index: str, role_name: str, content: str, role_type: str = "global_value") -> bool:
+ """Initialize a global message and append it to the memory.
+
+ Args:
+ session_index (str): The session index to which the message belongs.
+ role_name (str): The role name for the message.
+ content (str): The content of the message.
+ role_type (str, optional): The role type of the message (default is "global_value").
+
+ Returns:
+ bool: True if the message was initialized successfully; otherwise, False.
+ """
+
+        msg = Message(
+            session_index=session_index, message_index=role_name,
+            role_name=role_name, role_type=role_type, content=content
+        )
+ try:
+ self.append(msg)
+ return True
+ except Exception as e:
+ logger.error(f"Failed to initialize global message: {e}")
+ return False
+
+ def get_msg_by_role_name(self, session_index: str, role_name: str) -> Optional[Message]:
+ """Retrieve a message by its role name within a session.
+
+ Args:
+ session_index (str): The session index to search within.
+ role_name (str): The role name of the desired message.
+
+ Returns:
+ Optional[Message]: The found message, or None if not found.
+ """
+
+ memory = self.get_memory_pool_by_all({"session_index": session_index, "role_name": role_name})
+ # memory = self.get_memory_pool(session_index)
+ for msg in memory.messages:
+ if msg.role_name == role_name:
+ return msg
+ return None
+
+ def get_msg_content_by_role_name(self, session_index: str, role_name: str) -> Optional[str]:
+ """Retrieve the content of a message by its role name.
+
+ Args:
+ session_index (str): The session index to search within.
+ role_name (str): The role name of the desired message.
+
+ Returns:
+ Optional[str]: The content of the found message, or None if not found.
+ """
+
+ message = self.get_msg_by_role_name(session_index, role_name)
+        if message is None:
+            return None
+        return message.content
+
+    def update_msg_content_by_rule(self, session_index: str, role_name: str, new_content: str, update_rule: str) -> bool:
+ """Update the content of a message based on an update rule.
+
+ Args:
+ session_index (str): The session index to search within.
+ role_name (str): The role name of the message to update.
+ new_content (str): The new content to apply.
+ update_rule (str): The rule to apply for the update.
+
+ Returns:
+ bool: True if the message was successfully updated; otherwise, False.
+ """
+
+ message = self.get_msg_by_role_name(session_index, role_name)
+
+        if message is None:
+ return False
+
+ prompt = f"{new_content}\n{role_name}:{message.content}\n{update_rule}"
+ model = self._get_model()
+
+ new_content = model.predict(prompt)
+
+ if new_content is not None:
+ message.content = new_content
+ self.append(message)
+ return True
+ else:
+ return False
+
+    def _get_embedding(self, text) -> Dict[str, List[float]]:
+        text_vector = {}
+        if self.embed_config and text:
+            if isinstance(self.embed_config, ModelConfig):
+                self.embed_model = get_model(self.embed_config)
+                vector = self.embed_model.embed_query(text)
+                text_vector = {text: vector}
+            else:
+                text_vector = get_embedding(
+                    self.embed_config.embed_engine, [text],
+                    self.embed_config.embed_model_path, self.embed_config.model_device,
+                    self.embed_config
+                )
+        else:
+            # no embedding backend configured: fall back to a random vector of the index dimension
+            text_vector = {text: [random.random() for _ in range(DIM)]}
+        return text_vector
+
+ def _get_embedding_array(self, text) -> Dict[str, List[bytes]]:
+ text_vector = self._get_embedding(text)
+ return np.array(text_vector[text]).\
+ astype(dtype=np.float32).tobytes()
+
+ def _get_model(self, ):
+ if isinstance(self.llm_config, LLMConfig):
+ model = getChatModelFromConfig(self.llm_config)
+ else:
+ model = get_model(self.llm_config)
+ return model
\ No newline at end of file
diff --git a/muagent/models/__init__.py b/muagent/models/__init__.py
new file mode 100644
index 0000000..b97b2d5
--- /dev/null
+++ b/muagent/models/__init__.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+""" Import modules in models package."""
+from typing import Type
+from loguru import logger
+
+from ..schemas.models import ModelConfig
+from .base_model import ModelWrapperBase
+from .openai_model import (
+ OpenAIWrapperBase,
+ OpenAIChatWrapper,
+ # OpenAIDALLEWrapper,
+ OpenAIEmbeddingWrapper,
+)
+from .dashscope_model import (
+ DashScopeChatWrapper,
+ # DashScopeImageSynthesisWrapper,
+ DashScopeTextEmbeddingWrapper,
+ # DashScopeMultiModalWrapper,
+)
+from .ollama_model import (
+ OllamaChatWrapper,
+ OllamaEmbeddingWrapper,
+ # OllamaGenerationWrapper,
+)
+from .qwen_model import (
+ QwenChatWrapper,
+ QwenTextEmbeddingWrapper
+)
+from .kimi_model import (
+ KimiChatWrapper,
+ KimiEmbeddingWrapper
+)
+# from .gemini_model import (
+# GeminiChatWrapper,
+# GeminiEmbeddingWrapper,
+# )
+# from .zhipu_model import (
+# ZhipuAIChatWrapper,
+# ZhipuAIEmbeddingWrapper,
+# )
+# from .litellm_model import (
+# LiteLLMChatWrapper,
+# )
+from .yi_model import (
+ YiChatWrapper,
+)
+
+__all__ = [
+    "ModelWrapperBase",
+    "OpenAIWrapperBase",
+    "OpenAIChatWrapper",
+    "OpenAIEmbeddingWrapper",
+    "DashScopeChatWrapper",
+    "DashScopeTextEmbeddingWrapper",
+    "OllamaChatWrapper",
+    "OllamaEmbeddingWrapper",
+    "YiChatWrapper",
+    "QwenChatWrapper",
+    "QwenTextEmbeddingWrapper",
+    "KimiChatWrapper",
+    "KimiEmbeddingWrapper",
+    "get_model",
+]
+
+
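+# Illustrative usage of get_model (defined below); the ModelConfig field names are assumed
+# to mirror ModelWrapperBase.__init__ (config_name / model_name / model_type / api_key):
+#   config = ModelConfig(config_name="my_chat", model_name="qwen-turbo",
+#                        model_type="<a registered wrapper's model_type>", api_key="sk-xxx")
+#   model = get_model(config)
+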
+def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]:
+ """Get the specific type of model wrapper
+
+ Args:
+ model_type (`str`): The model type name.
+
+ Returns:
+ `Type[ModelWrapperBase]`: The corresponding model wrapper class.
+ """
+ wrapper = ModelWrapperBase.get_wrapper(model_type=model_type)
+ if wrapper is None:
+ raise KeyError(
+ f"Unsupported model_type [{model_type}],"
+ "use PostApiModelWrapper instead.",
+ )
+ return wrapper
+
+
+def get_model(model_config: ModelConfig) -> ModelWrapperBase:
+ """Get the model by model config
+
+ Args:
+ model_config (`ModelConfig`): The model config
+
+ Returns:
+ `ModelWrapperBase`: The specific model
+ """
+ return ModelWrapperBase.from_config(model_config)
\ No newline at end of file
diff --git a/muagent/models/base_model.py b/muagent/models/base_model.py
new file mode 100644
index 0000000..7231839
--- /dev/null
+++ b/muagent/models/base_model.py
@@ -0,0 +1,504 @@
+"""
+The implementation of this _ModelWrapperMeta are borrowed from
+https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/model.py
+"""
+
+
+from __future__ import annotations
+from abc import ABCMeta, abstractmethod
+from typing import (
+ Any,
+ Optional,
+ Type,
+ Union,
+ Sequence,
+ List,
+ Generator,
+ Literal,
+ Mapping
+)
+from loguru import logger
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
+from muagent.schemas import Message, Memory
+from muagent.schemas.models import (
+ ModelConfig,
+)
+from muagent.utils.common_utils import _convert_to_str
+
+
+class _ModelWrapperMeta(ABCMeta):
+    """A metaclass that registers model wrapper subclasses by class name and by
+    `model_type`. The upstream agentscope version also wraps `__call__` with
+    error handling; that wrapping is not applied here."""
+
+    def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+        # No __call__ wrapping is applied; the hook is kept for future use.
+        if "__call__" in attrs:
+            attrs["__call__"] = attrs["__call__"]
+        return super().__new__(mcs, name, bases, attrs)
+
+ def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+ if not hasattr(cls, "_registry"):
+ cls._registry = {}
+ cls._type_registry = {}
+ cls._deprecated_type_registry = {}
+ else:
+ cls._registry[name] = cls
+ if hasattr(cls, "model_type"):
+ cls._type_registry[cls.model_type] = cls
+ if hasattr(cls, "deprecated_model_type"):
+ cls._deprecated_type_registry[
+ cls.deprecated_model_type
+ ] = cls
+ super().__init__(name, bases, attrs)
+
+
+class ModelWrapperBase(metaclass=_ModelWrapperMeta):
+ """The base class for model wrapper."""
+
+ model_type: str
+ """The type of the model wrapper, which is to identify the model wrapper
+ class in model configuration."""
+
+ config_name: str
+ """The name of the model configuration."""
+
+ model_name: str
+ """The name of the model, which is used in model api calling."""
+
+ api_key: Optional[str] = None
+ """The api key of the model, which is used in model api calling."""
+
+ api_url: Optional[str] = None
+ """The api url of the model, which is used in model api calling."""
+
+ def __init__(
+ self, # pylint: disable=W0613
+ config_name: str,
+ model_name: str,
+ model_type: str = "codefuse",
+ api_key: Optional[str] = "model_base_xxx",
+ api_url: Optional[str]="https://codefuse.ai",
+ **kwargs: Any,
+ ) -> None:
+ """Base class for model wrapper.
+
+ All model wrappers should inherit this class and implement the
+ `__call__` function.
+
+ Args:
+ config_name (`str`):
+ The id of the model, which is used to extract configuration
+ from the config file.
+ model_name (`str`):
+ The name of the model.
+ api_key (`str`):
+ The api key of the model.
+ api_url (`str`):
+ The api url of the model.
+ model_type (`str`):
+ The type of the model wrapper.
+ """
+ self.config_name = config_name
+ self.model_name = model_name
+ self.api_key = api_key
+ self.api_url = api_url
+ self.model_type = model_type
+ # logger.info(f"Initialize model by configuration [{config_name}]")
+
+ @classmethod
+    def from_config(cls, model_config: ModelConfig) -> "ModelWrapperBase":
+        """Instantiate the wrapper class that matches `model_config.model_type`."""
+        model_config_dict = model_config.dict()
+        model_type = model_config_dict.pop("model_type")
+        return cls.get_wrapper(model_type)(**model_config_dict)
+
+ @classmethod
+ def get_wrapper(cls, model_type: str) -> Type[ModelWrapperBase]:
+ """Get the specific model wrapper"""
+ if model_type in cls._type_registry:
+ return cls._type_registry[model_type] # type: ignore[return-value]
+ elif model_type in cls._registry:
+ return cls._registry[model_type] # type: ignore[return-value]
+ elif model_type in cls._deprecated_type_registry:
+ deprecated_cls = cls._deprecated_type_registry[model_type]
+ logger.warning(
+ f"Model type [{model_type}] will be deprecated in future "
+ f"releases, please use [{deprecated_cls.model_type}] instead.",
+ )
+ return deprecated_cls # type: ignore[return-value]
+ else:
+ raise KeyError(
+ f"Unsupported model_type [{model_type}],"
+ "use PostApiModelWrapper instead.",
+ )
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stream: bool = None,
+ stop: Optional[str] = '',
+ format_type: Literal["str", "dict", "raw"] = "str",
+ **kwargs: Any,
+ ) -> Generator[Union[ChatCompletion, ChatCompletionChunk, str, Mapping], None, None]:
+ """Process input with the model.
+
+ Args:
+ prompt (str, optional): The prompt string to provide to the model.
+ messages (Sequence[dict], optional): A sequence of messages for conversation context.
+ tools (Sequence[object], optional): Tools that can be utilized in the processing.
+ tool_choice (Optional[Literal['auto', 'required']], optional): Determining how to select tools.
+ parallel_tool_calls (Optional[bool], optional): If true, allows parallel calls to tools.
+ stream (bool, optional): If true, the output is streamed rather than returned all at once.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["str", "dict", "raw"], optional): The format of the output.
+ **kwargs: Additional keyword arguments for extensibility.
+
+ Returns:
+ Generator[Union[ChatCompletion, ChatCompletionChunk, str, Mapping], None, None]:
+ A generator yielding completion responses from the model.
+ """
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `__call__`"
+ f" method.",
+ )
+
+ def predict(
+ self,
+ prompt: str,
+ stop: Optional[str] = '',
+ ) -> Union[ChatCompletion, str]:
+ """Generate a prediction based on the provided prompt.
+
+ Args:
+ prompt (str): The input prompt for prediction.
+ stop (Optional[str], optional): Token to signify stopping generation.
+
+ Returns:
+ Union[ChatCompletion, str]: The model's prediction in the specified format.
+ """
+ return self.generate(prompt, stop, "str")
+
+ def generate(
+ self,
+ prompt: str,
+ stop: Optional[str] = '',
+ format_type: Literal["str", "raw"] = "raw",
+ ) -> Union[ChatCompletion, str]:
+ """Generate a response by calling the model.
+
+ Args:
+ prompt (str): The input prompt.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["str", "raw"], optional): The format of the output.
+
+ Returns:
+ Union[ChatCompletion, str]: The generated response from the model.
+ """
+        result = None
+        for result in self.__call__(prompt, stop=stop, stream=False, format_type=format_type):
+            pass
+        return result
+
+ def generate_stream(self,
+ prompt: str,
+ stop: Optional[str] = '',
+ format_type: Literal["str", "raw"] = "raw",
+ ) -> Generator[Union[ChatCompletionChunk, str], None, None]:
+ """Stream the generated response from the model.
+
+ Args:
+ prompt (str): The input prompt.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["str", "raw"], optional): The format of the output.
+
+ Yields:
+ Generator[Union[ChatCompletionChunk, str], None, None]: A generator yielding parts of the response.
+ """
+ for i in self.__call__(prompt, stop=stop, stream=True, format_type=format_type):
+ yield i
+
+ def chat(self,
+ messages: Optional[Sequence[dict]],
+ stop: Optional[str] = '',
+ format_type: Literal["str", "raw"] = "raw",
+ ) -> Union[ChatCompletion, str]:
+ """Process a chat message input and return the model's response.
+
+ Args:
+ messages (Optional[Sequence[dict]]): A sequence of messages for conversation context.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["str", "raw"], optional): The format of the output.
+
+ Returns:
+ Union[ChatCompletion, str]: The model's chat response in the specified format.
+ """
+ for i in self.__call__(None, messages, stop=stop, stream=False, format_type=format_type):
+ return i
+
+ def chat_stream(self,
+ messages: Optional[Sequence[dict]],
+ stop: Optional[str] = '',
+ format_type: Literal["str", "raw"] = "raw",
+ ) -> Generator[Union[ChatCompletionChunk, str], None, None]:
+ """Stream chat responses from the model.
+
+ Args:
+ messages (Optional[Sequence[dict]]): A sequence of messages for conversation context.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["str", "raw"], optional): The format of the output.
+
+ Yields:
+ Generator[Union[ChatCompletionChunk, str], None, None]: A generator yielding parts of the chat response.
+ """
+ for i in self.__call__(None, messages, stop=stop, stream=True, format_type=format_type):
+ yield i
+
+ def function_call(
+ self,
+ messages: Optional[Sequence[dict]] = None,
+ tools: Sequence[object] = [],
+ *,
+ prompt: Optional[str] = None,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = False,
+ stop: Optional[str] = '',
+ format_type: Literal["raw"] = "raw",
+ ) -> Union[ChatCompletion, Mapping]:
+ """Call a function to process messages with optional tools.
+
+ Args:
+ messages (Optional[Sequence[dict]], optional): A sequence of messages for context.
+ tools (Sequence[object], optional): Tools available for use.
+ prompt (Optional[str], optional): An optional prompt.
+ tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools.
+ parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls.
+ stream (Optional[bool], optional): If true, streams the output instead of returning it all at once.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["raw"], optional): Specifies to return the output in raw format.
+
+ Returns:
+ Union[ChatCompletion, Mapping]: The result of the function call processed by the model.
+ """
+        kwargs = locals()
+        kwargs.pop("self")
+        result = None
+        for result in self.__call__(**kwargs):
+            pass
+        return result
+
+ def function_call_stream(
+ self,
+ messages: Optional[Sequence[dict]] = None,
+ tools: Sequence[object] = [],
+ *,
+ prompt: Optional[str] = None,
+ tool_choice: Optional[Literal['auto', 'required']] = 'auto',
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = True,
+ stop: Optional[str] = '',
+ format_type: Literal["raw"] = "raw",
+ ) -> Generator[Union[ChatCompletionChunk, Mapping], None, None]:
+ """Stream function call outputs.
+
+ Args:
+ messages (Optional[Sequence[dict]], optional): A sequence of messages for context.
+ tools (Sequence[object], optional): Tools available for use.
+ prompt (Optional[str], optional): An optional prompt.
+ tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools.
+ parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls.
+ stream (Optional[bool], optional): If true, streams the output.
+ stop (Optional[str], optional): Token to signify stopping generation.
+ format_type (Literal["raw"], optional): Specifies to return output in raw format.
+
+ Yields:
+ Generator[Union[ChatCompletionChunk, Mapping], None, None]: A generator yielding parts of the function output.
+ """
+ kwargs = locals()
+ kwargs.pop("self")
+ for i in self.__call__(**kwargs):
+ yield i
+
+ def batch(self, *args: Any, **kwargs: Any) -> List[ChatCompletion]:
+ """Process batch inputs with the model.
+
+ This method should be implemented by subclasses.
+
+ Raises:
+ NotImplementedError: If not implemented in subclass.
+ """
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `batch`"
+ f" method.",
+ )
+
+ def embed_query(self, text: str) -> List[float]:
+ """Embed a query into a vector representation.
+
+ This method should be implemented by subclasses.
+
+ Args:
+ text (str): The text to embed.
+
+ Raises:
+ NotImplementedError: If not implemented in subclass.
+ """
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `embed_query`"
+ f" method.",
+ )
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ """Embed a list of documents into vector representations.
+
+ This method should be implemented by subclasses.
+
+ Args:
+ texts (List[str]): The list of texts to embed.
+
+ Raises:
+ NotImplementedError: If not implemented in subclass.
+ """
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `embed_documents`"
+ f" method.",
+ )
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ """Format the input messages into the format that the model
+ API required."""
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `format` method",
+ )
+
+ @staticmethod
+ def format_for_common_chat_models(
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """A common format strategy for chat models, which will format the
+ input messages into a system message (if provided) and a user message.
+
+        Note this strategy may not be suitable for all scenarios,
+ and developers are encouraged to implement their own prompt
+ engineering strategies.
+
+ The following is an example:
+
+ .. code-block:: python
+
+ prompt1 = model.format(
+ Message("system", "You're a helpful assistant", role="system"),
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ prompt2 = model.format(
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ The prompt will be as follows:
+
+ .. code-block:: python
+
+ # prompt1
+ [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant"
+ },
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+ # prompt2
+ [
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages.
+ """
+ if len(args) == 0:
+ raise ValueError(
+ "At least one message should be provided. An empty message "
+ "list is not allowed.",
+ )
+
+ # Parse all information into a list of messages
+        input_msgs = []
+ for _ in args:
+ if _ is None:
+ continue
+ if isinstance(_, Message):
+                input_msgs.append(_)
+ elif isinstance(_, list) and all(isinstance(__, Message) for __ in _):
+                input_msgs.extend(_)
+ else:
+ raise TypeError(
+ f"The input should be a Message object or a list "
+ f"of Message objects, got {type(_)}.",
+ )
+
+ # record dialog history as a list of strings
+ dialogue = []
+ sys_prompt = None
+        for i, unit in enumerate(input_msgs):
+ if i == 0 and unit.role == "system":
+ # if system prompt is available, place it at the beginning
+ sys_prompt = _convert_to_str(unit.content)
+ else:
+ # Merge all messages into a conversation history prompt
+ dialogue.append(
+ f"{unit.name}: {_convert_to_str(unit.content)}",
+ )
+
+ content_components = []
+
+ # The conversation history is added to the user message if not empty
+ if len(dialogue) > 0:
+ content_components.extend(["## Conversation History"] + dialogue)
+
+ messages = [
+ {
+ "role": "user",
+ "content": "\n".join(content_components),
+ },
+ ]
+
+ # Add system prompt at the beginning if provided
+ if sys_prompt is not None:
+ messages = [{"role": "system", "content": sys_prompt}] + messages
+
+ return messages
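+
+
+# Illustrative sketch, not part of the public wrappers: a minimal subclass is
+# all that `_ModelWrapperMeta` needs to register a new wrapper, after which
+# `ModelWrapperBase.get_wrapper("echo_chat")` resolves to it. The "echo_chat"
+# model_type and the class itself are made up for this example.
+class _EchoChatWrapperExample(ModelWrapperBase):
+    """Toy wrapper that echoes the prompt back as a single chunk."""
+
+    model_type: str = "echo_chat"
+
+    def __call__(self, prompt=None, messages=None, tools=None, **kwargs):
+        # Yield once so this behaves like the generator-based wrappers above.
+        messages = messages or []
+        text = prompt or (messages[-1]["content"] if messages else "")
+        yield text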
\ No newline at end of file
diff --git a/muagent/models/dashscope_model.py b/muagent/models/dashscope_model.py
new file mode 100644
index 0000000..4aca149
--- /dev/null
+++ b/muagent/models/dashscope_model.py
@@ -0,0 +1,514 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for DashScope models"""
+import os
+from abc import ABC
+from http import HTTPStatus
+from typing import (
+ Any,
+ Union,
+ List,
+ Sequence,
+ Optional,
+ Generator,
+ Literal
+)
+
+from loguru import logger
+
+from ..schemas import Message
+
+try:
+ import dashscope
+
+ dashscope_version = dashscope.version.__version__
+    # NOTE: plain string comparison only approximates semantic version
+    # ordering; it mirrors the upstream agentscope check.
+    if dashscope_version < "1.19.0":
+ logger.warning(
+ f"You are using 'dashscope' version {dashscope_version}, "
+ "which is below the recommended version 1.19.0. "
+ "Please consider upgrading to maintain compatibility.",
+ )
+ from dashscope.api_entities.dashscope_response import GenerationResponse
+except ImportError:
+ dashscope = None
+ GenerationResponse = None
+
+from .base_model import ModelWrapperBase
+from ..utils.common_utils import _convert_to_str
+
+
+
+class DashScopeWrapperBase(ModelWrapperBase, ABC):
+ """The model wrapper for DashScope API."""
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the DashScope wrapper.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in DashScope API.
+ api_key (`str`, default `None`):
+ The API key for DashScope API.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in DashScope api generation,
+ e.g. `temperature`, `seed`.
+ """
+ if model_name is None:
+ model_name = config_name
+ logger.warning("model_name is not set, use config_name instead.")
+
+ super().__init__(config_name=config_name, model_name=model_name)
+
+ if dashscope is None:
+ raise ImportError(
+ "The package 'dashscope' is not installed. Please install it "
+ "by running `pip install dashscope>=1.19.0`",
+ )
+
+ self.generate_args = generate_args or {}
+
+ self.api_key = api_key
+ self.max_length = None
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ raise RuntimeError(
+ f"Model Wrapper [{type(self).__name__}] doesn't "
+ f"need to format the input. Please try to use the "
+ f"model wrapper directly.",
+ )
+
+
+class DashScopeChatWrapper(DashScopeWrapperBase):
+ """The model wrapper for DashScope's chat API, refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
+ Response:
+ - Refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw
+
+ ```json
+ {
+ "status_code": 200,
+ "request_id": "a75a1b22-e512-957d-891b-37db858ae738",
+ "code": "",
+ "message": "",
+ "output": {
+ "text": null,
+ "finish_reason": null,
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "message": {
+ "role": "assistant",
+ "content": "xxx"
+ }
+ }
+ ]
+ },
+ "usage": {
+ "input_tokens": 25,
+ "output_tokens": 77,
+ "total_tokens": 102
+ }
+ }
+ ```
+ """
+
+ model_type: str = "dashscope_chat"
+
+ deprecated_model_type: str = "tongyi_chat"
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ stream: bool = False,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the DashScope wrapper.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in DashScope API.
+ api_key (`str`, default `None`):
+ The API key for DashScope API.
+ stream (`bool`, default `False`):
+ If True, the response will be a generator in the `stream`
+ field of the returned `ModelResponse` object.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in DashScope api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ super().__init__(
+ config_name=config_name,
+ model_name=model_name,
+ api_key=api_key,
+ generate_args=generate_args,
+ **kwargs,
+ )
+
+ self.stream = stream
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stop: Optional[str] = '',
+ stream: Optional[bool] = None,
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs: Any,
+ ) -> Generator:
+ """Processes a list of messages to construct a payload for the
+ DashScope API call. It then makes a request to the DashScope API
+ and returns the response. This method also updates monitoring
+ metrics based on the API response.
+
+ Each message in the 'messages' list can contain text content and
+ optionally an 'image_urls' key. If 'image_urls' is provided,
+ it is expected to be a list of strings representing URLs to images.
+ These URLs will be transformed to a suitable format for the DashScope
+ API, which might involve converting local file paths to data URIs.
+
+ Args:
+ messages (`list`):
+ A list of messages to process.
+ stream (`Optional[bool]`, default `None`):
+ The stream flag to control the response format, which will
+ overwrite the stream flag in the constructor.
+ **kwargs (`Any`):
+ The keyword arguments to DashScope chat completions API,
+ e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
+ refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+ for more detailed arguments.
+
+        Yields:
+            `Union[GenerationResponse, str]`:
+                The raw DashScope response (or the accumulated text when
+                `format_type` is `"str"`); chunks are yielded incrementally
+                when `stream` is True.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved for
+ `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and check
+ the response generated by the model, which takes the response
+ as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ The rule of roles in messages for DashScope is very rigid,
+ for more details, please refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+ """
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: checking messages
+ if not isinstance(messages, list):
+ raise ValueError(
+ "Dashscope `messages` field expected type `list`, "
+ f"got `{type(messages)}` instead.",
+ )
+ if not all("role" in msg and "content" in msg for msg in messages):
+ raise ValueError(
+ "Each message in the 'messages' list must contain a 'role' "
+ "and 'content' key for DashScope API.",
+ )
+
+ # step3: forward to generate response
+ if stream is None:
+ stream = self.stream
+
+ kwargs.update(
+ {
+ "model": self.model_name,
+ "messages": messages,
+ # Set the result to be "message" format.
+ "result_format": "message",
+ "stream": stream,
+ "tools": tools,
+ "stop": stop,
+ },
+ )
+
+ # Switch to the incremental_output mode
+ if stream:
+ kwargs["incremental_output"] = True
+
+ response = dashscope.Generation.call(api_key=self.api_key, **kwargs)
+
+        # step4: parse the response into the requested format
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk["output"]["choices"][0]["message"]["content"] or ''
+ yield content
+ else:
+ yield response["output"]["choices"][0]["message"]["content"]
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+ yield response
+
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """A common format strategy for chat models, which will format the
+ input messages into a user message.
+
+        Note this strategy may not be suitable for all scenarios,
+ and developers are encouraged to implement their own prompt
+ engineering strategies.
+
+ The following is an example:
+
+ .. code-block:: python
+
+ prompt1 = model.format(
+ Message("system", "You're a helpful assistant", role="system"),
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ prompt2 = model.format(
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ The prompt will be as follows:
+
+ .. code-block:: python
+
+ # prompt1
+ [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant"
+ },
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+ # prompt2
+ [
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+
+ Args:
+            args (`Union[Message, Sequence[Message]]`):
+                The input arguments to be formatted, where each argument
+                should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages.
+ """
+
+ return ModelWrapperBase.format_for_common_chat_models(*args)
+
+
+ def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str:
+ """Forward the input to the model.
+
+ Args:
+ args (`Union[Msg, Sequence[Msg]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Msg` object, or a list of `Msg` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `str`:
+ The formatted string prompt.
+ """
+ input_msgs: List[Message] = []
+ for _ in args:
+ if _ is None:
+ continue
+ if isinstance(_, Message):
+ input_msgs.append(_)
+ elif isinstance(_, list) and all(isinstance(__, Message) for __ in _):
+ input_msgs.extend(_)
+ else:
+ raise TypeError(
+ f"The input should be a Msg object or a list "
+ f"of Msg objects, got {type(_)}.",
+ )
+
+ sys_prompt = None
+ dialogue = []
+ for i, unit in enumerate(input_msgs):
+ if i == 0 and unit.role_type == "system":
+ # system prompt
+ sys_prompt = unit.content
+ else:
+ # Merge all messages into a conversation history prompt
+ dialogue.append(
+ f"{unit.role_name}: {unit.content}",
+ )
+
+ dialogue_history = "\n".join(dialogue)
+
+ if sys_prompt is None:
+ prompt_template = "## Conversation History\n{dialogue_history}"
+ else:
+ prompt_template = (
+ "{system_prompt}\n"
+ "\n"
+ "## Conversation History\n"
+ "{dialogue_history}"
+ )
+
+ return prompt_template.format(
+ system_prompt=sys_prompt,
+ dialogue_history=dialogue_history,
+ )
+
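+
+# Usage sketch (illustrative only, never called from library code): requires
+# a valid DashScope api_key; "qwen-turbo" is just one example of a model
+# served by DashScope.
+def _example_dashscope_chat_usage() -> None:  # pragma: no cover
+    llm = DashScopeChatWrapper(
+        config_name="demo_dashscope_chat",
+        model_name="qwen-turbo",
+        api_key="sk-xxx",  # placeholder key
+    )
+    # Non-streaming: `predict` drains the generator and returns a string.
+    print(llm.predict("Hello!"))
+    # Streaming: yields the accumulated text chunk by chunk.
+    for partial in llm.generate_stream("Hello!", format_type="str"):
+        print(partial)
+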
+class DashScopeTextEmbeddingWrapper(DashScopeWrapperBase):
+ """The model wrapper for DashScope Text Embedding API.
+
+ Response:
+ - Refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3
+
+ ```json
+ {
+ "status_code": 200, // 200 indicate success otherwise failed.
+ "request_id": "fd564688-43f7-9595-b986", // The request id.
+ "code": "", // If failed, the error code.
+ "message": "", // If failed, the error message.
+ "output": {
+ "embeddings": [ // embeddings
+ {
+ "embedding": [ // one embedding output
+ -3.8450357913970947, ...,
+ ],
+ "text_index": 0 // the input index.
+ }
+ ]
+ },
+ "usage": {
+ "total_tokens": 3 // the request tokens.
+ }
+ }
+ ```
+ """
+
+ model_type: str = "dashscope_text_embedding"
+
+ def __call__(
+ self,
+ texts: Union[list[str], str],
+ dimension: Literal[512, 768, 1024, 1536] = 768,
+ **kwargs: Any,
+ ):
+ """Embed the messages with DashScope Text Embedding API.
+
+ Args:
+ texts (`list[str]` or `str`):
+ The messages used to embed.
+ **kwargs (`Any`):
+ The keyword arguments to DashScope Text Embedding API,
+ e.g. `text_type`. Please refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15
+ for more detailed arguments.
+
+        Returns:
+            The raw DashScope response, with the embedding vectors available
+            under `output["embeddings"]`.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved
+ for `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and
+ check the response generated by the model, which takes the
+ response as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: forward to generate response
+ response = dashscope.TextEmbedding.call(
+ input=texts,
+ model=self.model_name,
+ api_key=self.api_key,
+ dimension=dimension,
+ **kwargs,
+ )
+
+ if response.status_code != HTTPStatus.OK:
+ error_msg = (
+ f" Request id: {response.request_id},"
+ f" Status code: {response.status_code},"
+ f" error code: {response.code},"
+ f" error message: {response.message}."
+ )
+ raise RuntimeError(error_msg)
+
+        # step3: return response
+ return response
+
+ def embed_query(self, text: str) -> List[float]:
+ response = self([text])
+ output = response["output"]
+ embeddings = output["embeddings"]
+ return embeddings[0]["embedding"]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ response = self(texts)
+ output = response["output"]
+ embeddings = output["embeddings"]
+ return [emb["embedding"] for emb in embeddings]
+
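+
+# Usage sketch (illustrative only, never called from library code): requires
+# a valid DashScope api_key; "text-embedding-v3" is an example of a DashScope
+# embedding model name.
+def _example_dashscope_embedding_usage() -> None:  # pragma: no cover
+    embedder = DashScopeTextEmbeddingWrapper(
+        config_name="demo_dashscope_embedding",
+        model_name="text-embedding-v3",
+        api_key="sk-xxx",  # placeholder key
+    )
+    vector = embedder.embed_query("hello world")
+    vectors = embedder.embed_documents(["doc one", "doc two"])
+    print(len(vector), len(vectors))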
\ No newline at end of file
diff --git a/muagent/models/kimi_model.py b/muagent/models/kimi_model.py
new file mode 100644
index 0000000..e2ee93c
--- /dev/null
+++ b/muagent/models/kimi_model.py
@@ -0,0 +1,372 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for OpenAI models
+The implementation of this _ModelWrapperMeta are borrowed from
+https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py
+"""
+
+
+from abc import ABC
+from typing import (
+ Union,
+ Any,
+ List,
+ Sequence,
+ Dict,
+ Optional,
+ Generator,
+ Literal
+)
+from urllib.parse import urlparse
+import os
+import base64
+from loguru import logger
+try:
+ import openai
+except ImportError as e:
+ raise ImportError(
+ "Cannot find openai package, please install it by "
+ "`pip install openai`",
+ ) from e
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types import CreateEmbeddingResponse
+from .base_model import ModelWrapperBase
+from ..schemas import Message
+
+
+
+class KimiWrapperBase(ModelWrapperBase, ABC):
+ """The model wrapper for OpenAI API.
+
+ Response:
+ - From https://platform.moonshot.cn/docs/intro
+
+ ```json
+ {
+ "id": "chatcmpl-123",
+ "object": "chat.completion",
+ "created": 1677652288,
+ "model": "gpt-4o-mini",
+ "system_fingerprint": "fp_44709d6fcb",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "Hello there, how may I assist you today?",
+ },
+ "logprobs": null,
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 9,
+ "completion_tokens": 12,
+ "total_tokens": 21
+ }
+ }
+ ```
+ """
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ api_url: str = "https://api.moonshot.cn/v1",
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the openai client.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in OpenAI API.
+ api_key (`str`, default `None`):
+ The API key for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_API_KEY`.
+ organization (`str`, default `None`):
+ The organization ID for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_ORGANIZATION`.
+ client_args (`dict`, default `None`):
+ The extra keyword arguments to initialize the OpenAI client.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in openai api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ if model_name is None:
+ model_name = config_name
+ logger.warning("model_name is not set, use config_name instead.")
+
+ init_params = locals()
+ init_params.pop("self")
+ init_params["model_type"] = self.model_type
+ super().__init__(**init_params)
+ # super().__init__(config_name=config_name, model_name=model_name)
+
+ self.generate_args = generate_args or {}
+ self.api_url = api_url or "https://api.moonshot.cn/v1"
+ self.client = openai.OpenAI(api_key=api_key, base_url=self.api_url,)
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ raise RuntimeError(
+ f"Model Wrapper [{type(self).__name__}] doesn't "
+ f"need to format the input. Please try to use the "
+ f"model wrapper directly.",
+ )
+
+
+class KimiChatWrapper(KimiWrapperBase):
+ """The model wrapper for OpenAI's chat API."""
+
+ model_type: str = "moonshot_chat"
+
+ substrings_in_vision_models_names = ["gpt-4-turbo", "vision", "gpt-4o"]
+ """The substrings in the model names of vision models."""
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ api_url: str = "https://api.moonshot.cn/v1",
+ stream: bool = False,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the openai client.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in OpenAI API.
+ api_key (`str`, default `None`):
+ The API key for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_API_KEY`.
+ organization (`str`, default `None`):
+ The organization ID for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_ORGANIZATION`.
+ client_args (`dict`, default `None`):
+ The extra keyword arguments to initialize the OpenAI client.
+ stream (`bool`, default `False`):
+ Whether to enable stream mode.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in openai api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ init_params = locals()
+ init_params.pop("self")
+ init_params["model_type"] = self.model_type
+ self.generate_args = generate_args
+ super().__init__(**init_params)
+ self.stream = stream
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stop: Optional[str] = '',
+ stream: bool = None,
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs: Any,
+ ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]:
+ """Processes a list of messages to construct a payload for the OpenAI
+ API call. It then makes a request to the OpenAI API and returns the
+ response. This method also updates monitoring metrics based on the
+ API response.
+
+ Each message in the 'messages' list can contain text content and
+ optionally an 'image_urls' key. If 'image_urls' is provided,
+ it is expected to be a list of strings representing URLs to images.
+ These URLs will be transformed to a suitable format for the OpenAI
+ API, which might involve converting local file paths to data URIs.
+
+ Args:
+ messages (`list`):
+ A list of messages to process.
+ stream (`Optional[bool]`, defaults to `None`)
+ Whether to enable stream mode, which will override the
+ `stream` argument in the constructor if provided.
+ **kwargs (`Any`):
+ The keyword arguments to OpenAI chat completions API,
+ e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to
+ https://platform.openai.com/docs/api-reference/chat/create
+ for more detailed arguments.
+
+        Yields:
+            `Union[ChatCompletion, ChatCompletionChunk, str]`:
+                The raw completion (or the accumulated text when
+                `format_type` is `"str"`); chunks are yielded incrementally
+                when `stream` is True.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved for
+ `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and check
+ the response generated by the model, which takes the response
+ as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: checking messages
+ if not isinstance(messages, list):
+ raise ValueError(
+ "Kimi `messages` field expected type `list`, "
+ f"got `{type(messages)}` instead.",
+ )
+ if not all("role" in Message and "content" in Message for Message in messages):
+ raise ValueError(
+ "Each message in the 'messages' list must contain a 'role' "
+ "and 'content' key for OpenAI API.",
+ )
+
+ # step3: forward to generate response
+ if stream is None:
+ stream = self.stream
+
+ kwargs.update(
+ {
+ "model": self.model_name,
+ "messages": messages,
+ "stream": stream,
+ "tools": tools,
+ "tool_choice": tool_choice,
+ "parallel_tool_calls": parallel_tool_calls,
+ "stop": stop
+ },
+ )
+
+ response = self.client.chat.completions.create(**kwargs)
+
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk.choices[0].delta.content or ''
+ yield content
+ else:
+ yield response.choices[0].message.content
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+ yield response
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """Format the input string and dictionary into the format that
+ OpenAI Chat API required. If you're using a OpenAI-compatible model
+ without a prefix "gpt-" in its name, the format method will
+ automatically format the input messages into the required format.
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages in the format that OpenAI Chat API
+ required.
+ """
+
+ return ModelWrapperBase.format_for_common_chat_models(*args)
+
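+# Usage sketch (illustrative only, never called from library code): requires
+# a valid Moonshot api_key; "moonshot-v1-8k" is one of the documented
+# Moonshot model names.
+def _example_kimi_chat_usage() -> None:  # pragma: no cover
+    llm = KimiChatWrapper(
+        config_name="demo_kimi_chat",
+        model_name="moonshot-v1-8k",
+        api_key="sk-xxx",  # placeholder key
+    )
+    messages = [{"role": "user", "content": "Hello!"}]
+    # `chat` drains the generator and returns the non-streaming response.
+    print(llm.chat(messages, format_type="str"))
+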
+
+class KimiEmbeddingWrapper(KimiWrapperBase):
+ """The model wrapper for OpenAI embedding API.
+
+ Response:
+ - Refer to
+ https://xx
+
+ ```json
+ {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "embedding": [
+ 0.0023064255,
+ -0.009327292,
+ .... (1536 floats total for ada-002)
+ -0.0028842222,
+ ],
+ "index": 0
+ }
+ ],
+ "model": "text-embedding-ada-002",
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8
+ }
+ }
+ ```
+ """
+
+ model_type: str = "kimi_embedding"
+
+ def __call__(
+ self,
+ texts: Union[list[str], str],
+ **kwargs: Any,
+ ) -> CreateEmbeddingResponse:
+ """Embed the messages with OpenAI embedding API.
+
+ Args:
+ texts (`list[str]` or `str`):
+ The messages used to embed.
+ **kwargs (`Any`):
+ The keyword arguments to OpenAI embedding API,
+ e.g. `encoding_format`, `user`. Please refer to
+ https://platform.openai.com/docs/api-reference/embeddings
+ for more detailed arguments.
+
+        Returns:
+            `CreateEmbeddingResponse`:
+                Not implemented yet; calling this wrapper currently raises
+                `NotImplementedError`.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved for
+ `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and check
+ the response generated by the model, which takes the response
+ as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+ raise NotImplementedError(
+ f"Model Wrapper [{type(self).__name__}]"
+ f" is missing the required `__call__` method",
+ )
+
diff --git a/muagent/models/ollama_model.py b/muagent/models/ollama_model.py
new file mode 100644
index 0000000..fd73f7e
--- /dev/null
+++ b/muagent/models/ollama_model.py
@@ -0,0 +1,490 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for Ollama models."""
+import os
+from abc import ABC
+from typing import (
+ Sequence,
+ Any,
+ Optional,
+ List,
+ Union,
+ Generator,
+ Literal,
+ Mapping
+)
+
+from .base_model import ModelWrapperBase
+from ..schemas import Message
+
+
+class OllamaWrapperBase(ModelWrapperBase, ABC):
+ """The base class for Ollama model wrappers.
+
+ To use Ollama API, please
+ 1. First install ollama server from https://ollama.com/download and
+ start the server
+ 2. Pull the model by `ollama pull {model_name}` in terminal
+ After that, you can use the ollama API.
+ """
+
+ model_type: str
+ """The type of the model wrapper, which is to identify the model wrapper
+ class in model configuration."""
+
+ model_name: str
+ """The model name used in ollama API."""
+
+ options: dict
+ """A dict contains the options for ollama generation API,
+ e.g. {"temperature": 0, "seed": 123}"""
+
+ keep_alive: str
+ """Controls how long the model will stay loaded into memory following
+ the request."""
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str,
+ api_key: str = '',
+ options: dict = None,
+ keep_alive: str = "5m",
+ api_url: Optional[Union[str, None]] = "http://127.0.0.1:11434",
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the model wrapper for Ollama API.
+
+ Args:
+ model_name (`str`):
+ The model name used in ollama API.
+ options (`dict`, default `None`):
+ The extra keyword arguments used in Ollama api generation,
+ e.g. `{"temperature": 0., "seed": 123}`.
+ keep_alive (`str`, default `5m`):
+ Controls how long the model will stay loaded into memory
+ following the request.
+            api_url (`str`, default `"http://127.0.0.1:11434"`):
+                The base url of the ollama server, which defaults to the
+                local server at 127.0.0.1:11434.
+ """
+
+ super().__init__(config_name=config_name, model_name=model_name)
+
+        self.options = options or {}
+ self.keep_alive = keep_alive
+ self.api_url = api_url or "http://127.0.0.1:11434"
+
+ try:
+ import ollama
+ except ImportError as e:
+ raise ImportError(
+ "The package ollama is not found. Please install it by "
+ 'running command `pip install "ollama>=0.1.7"`',
+ ) from e
+
+ self.client = ollama.Client(host=self.api_url)
+
+
+class OllamaChatWrapper(OllamaWrapperBase):
+ """The model wrapper for Ollama chat API.
+
+ Response:
+ - Refer to
+ https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
+
+ ```json
+ {
+ "model": "registry.ollama.ai/library/llama3:latest",
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {
+ "role": "assistant",
+ "content": "Hello! How are you today?"
+ },
+ "done": true,
+ "total_duration": 5191566416,
+ "load_duration": 2154458,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 383809000,
+ "eval_count": 298,
+ "eval_duration": 4799921000
+ }
+ ```
+ """
+
+ model_type: str = 'ollama_chat'
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str,
+ stream: bool = False,
+ options: dict = None,
+ keep_alive: str = "5m",
+ api_url: Optional[Union[str, None]] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the model wrapper for Ollama API.
+
+ Args:
+ model_name (`str`):
+ The model name used in ollama API.
+ stream (`bool`, default `False`):
+ Whether to enable stream mode.
+ options (`dict`, default `None`):
+ The extra keyword arguments used in Ollama api generation,
+ e.g. `{"temperature": 0., "seed": 123}`.
+ keep_alive (`str`, default `5m`):
+ Controls how long the model will stay loaded into memory
+ following the request.
+ api_url (`str`, default `None`):
+ The host port of the ollama server.
+ Defaults to `None`, which is 127.0.0.1:11434.
+ """
+
+ super().__init__(
+ config_name=config_name,
+ model_name=model_name,
+ options=options,
+ keep_alive=keep_alive,
+ api_url=api_url,
+ **kwargs,
+ )
+
+ self.stream = stream
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stop: Optional[str] = '',
+ stream: Optional[bool] = None,
+ options: Optional[dict] = None,
+ keep_alive: Optional[str] = None,
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs: Any,
+ ):
+ """Generate response from the given messages.
+
+ Args:
+ messages (`Sequence[dict]`):
+ A list of messages, each message is a dict contains the `role`
+ and `content` of the message.
+ stream (`bool`, default `None`):
+ Whether to enable stream mode, which will override the `stream`
+ input in the constructor.
+ options (`dict`, default `None`):
+ The extra arguments used in ollama chat API, which takes
+ effect only on this call, and will be merged with the
+ `options` input in the constructor,
+ e.g. `{"temperature": 0., "seed": 123}`.
+ keep_alive (`str`, default `None`):
+ How long the model will stay loaded into memory following
+ the request, which takes effect only on this call, and will
+ override the `keep_alive` input in the constructor.
+
+        Yields:
+            `Union[Mapping, str]`:
+                The raw ollama response (or the accumulated text when
+                `format_type` is `"str"`); chunks are yielded incrementally
+                when `stream` is True.
+ """
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+ # step1: prepare parameters accordingly
+ if options is None:
+ options = self.options or {"stop": [stop] if stop else []}
+ else:
+ options = {**self.options, **options}
+
+ keep_alive = keep_alive or self.keep_alive
+
+ # step2: forward to generate response
+ stream = self.stream if stream is None else stream
+
+ kwargs.update(
+ {
+ "model": self.model_name,
+ "messages": messages,
+ "tools": tools,
+ "stream": stream,
+ "options": options,
+ "keep_alive": keep_alive,
+ },
+ )
+
+ response = self.client.chat(**kwargs)
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk["message"]["content"] or ''
+ yield content
+ else:
+ yield response["message"]["content"]
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+                yield response
+
+    def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """Format the messages for ollama Chat API.
+
+ All messages will be formatted into a single system message with
+ system prompt and conversation history.
+
+ Note:
+            1. This strategy may not be suitable for all scenarios,
+ and developers are encouraged to implement their own prompt
+ engineering strategies.
+ 2. For ollama chat api, the content field shouldn't be empty string.
+
+ Example:
+
+ .. code-block:: python
+
+ prompt = model.format(
+ Message("system", "You're a helpful assistant", role="system"),
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ The prompt will be as follows:
+
+ .. code-block:: python
+
+ [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant"
+ },
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages.
+ """
+
+ # Parse all information into a list of messages
+ input_msgs: List[Message] = []
+ for _ in args:
+ if _ is None:
+ continue
+ if isinstance(_, Message):
+ input_msgs.append(_)
+ elif isinstance(_, list) and all(isinstance(__, Message) for __ in _):
+ input_msgs.extend(_)
+ else:
+ raise TypeError(
+ f"The input should be a Message object or a list "
+ f"of Message objects, got {type(_)}.",
+ )
+
+ # record dialog history as a list of strings
+ system_prompt = None
+ history_content_template = []
+ dialogue = []
+ # TODO: here we default the url links to images
+ images = []
+ for i, unit in enumerate(input_msgs):
+ if i == 0 and unit.role_type == "system":
+ # system prompt
+ system_prompt = unit.content
+ else:
+ # Merge all messages into a conversation history prompt
+ dialogue.append(
+ f"{unit.role_name}: {unit.content}",
+ )
+
+ if unit.image_urls is not None:
+ images.extend(unit.image_urls)
+
+ if len(dialogue) != 0:
+ dialogue_history = "\n".join(dialogue)
+
+ history_content_template.extend(
+ ["## Conversation History", dialogue_history],
+ )
+
+ history_content = "\n".join(history_content_template)
+
+ # The conversation history message
+ history_message = {
+ "role": "user",
+ "content": history_content,
+ }
+
+ if len(images) != 0:
+ history_message["images"] = images
+
+ if system_prompt is None:
+ return [history_message]
+
+ return [
+ {"role": "system", "content": system_prompt},
+ history_message,
+ ]
+
+ def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str:
+ """Forward the input to the model.
+
+ Args:
+ args (`Union[Msg, Sequence[Msg]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Msg` object, or a list of `Msg` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `str`:
+ The formatted string prompt.
+ """
+ input_msgs: List[Message] = []
+ for _ in args:
+ if _ is None:
+ continue
+ if isinstance(_, Message):
+ input_msgs.append(_)
+ elif isinstance(_, list) and all(isinstance(__, Message) for __ in _):
+ input_msgs.extend(_)
+ else:
+ raise TypeError(
+ f"The input should be a Msg object or a list "
+ f"of Msg objects, got {type(_)}.",
+ )
+
+ sys_prompt = None
+ dialogue = []
+ for i, unit in enumerate(input_msgs):
+ if i == 0 and unit.role_type == "system":
+ # system prompt
+ sys_prompt = unit.content
+ else:
+ # Merge all messages into a conversation history prompt
+ dialogue.append(
+ f"{unit.role_name}: {unit.content}",
+ )
+
+ dialogue_history = "\n".join(dialogue)
+
+ if sys_prompt is None:
+ prompt_template = "## Conversation History\n{dialogue_history}"
+ else:
+ prompt_template = (
+ "{system_prompt}\n"
+ "\n"
+ "## Conversation History\n"
+ "{dialogue_history}"
+ )
+
+ return prompt_template.format(
+ system_prompt=sys_prompt,
+ dialogue_history=dialogue_history,
+ )
+
+
+class OllamaEmbeddingWrapper(OllamaWrapperBase):
+ """The model wrapper for Ollama embedding API.
+
+ Response:
+ - Refer to
+ https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
+
+ ```json
+ {
+ "model": "all-minilm",
+ "embeddings": [[
+ 0.010071029, -0.0017594862, 0.05007221, 0.04692972,
+ 0.008599704, 0.105441414, -0.025878139, 0.12958129,
+ ]]
+ }
+ ```
+ """
+
+ model_type: str = "ollama_embedding"
+
+ def __call__(
+ self,
+        texts: Union[List[str], str],
+ options: Optional[dict] = None,
+ keep_alive: Optional[str] = None,
+ **kwargs: Any,
+ ) -> Mapping[str, Sequence[float]]:
+ """Generate embedding from the given prompt.
+
+ Args:
+            texts (`Union[List[str], str]`):
+                The text(s) to embed.
+ options (`dict`, default `None`):
+ The extra arguments used in ollama embedding API, which takes
+ effect only on this call, and will be merged with the
+ `options` input in the constructor,
+ e.g. `{"temperature": 0., "seed": 123}`.
+ keep_alive (`str`, default `None`):
+ How long the model will stay loaded into memory following
+ the request, which takes effect only on this call, and will
+ override the `keep_alive` input in the constructor.
+
+        Returns:
+            `Mapping[str, Sequence[float]]`:
+                The raw ollama embedding response, with the vectors in the
+                `embeddings` field.
+ """
+ # step1: prepare parameters accordingly
+ if options is None:
+ options = self.options
+ else:
+ options = {**self.options, **options}
+
+ keep_alive = keep_alive or self.keep_alive
+
+ # step2: forward to generate response
+ response = self.client.embed(
+ model=self.model_name,
+ input=texts,
+ options=options,
+ keep_alive=keep_alive,
+ **kwargs,
+ )
+        # step3: return response
+ return response
+
+ def embed_query(self, text: str) -> List[float]:
+ response = self([text])
+ embeddings = response["embeddings"]
+ return embeddings[0]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ response = self(texts)
+ embeddings = response["embeddings"]
+ return embeddings
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ raise RuntimeError(
+ f"Model Wrapper [{type(self).__name__}] doesn't "
+ f"need to format the input. Please try to use the "
+ f"model wrapper directly.",
+ )
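+
+
+# Usage sketch (illustrative only, never called from library code): assumes a
+# local ollama server with the "llama3" and "all-minilm" models already pulled.
+def _example_ollama_usage() -> None:  # pragma: no cover
+    chat = OllamaChatWrapper(
+        config_name="demo_ollama_chat",
+        model_name="llama3",
+    )
+    # Streaming chat: yields the accumulated text chunk by chunk.
+    for partial in chat.generate_stream("Hello!", format_type="str"):
+        print(partial)
+
+    embedder = OllamaEmbeddingWrapper(
+        config_name="demo_ollama_embedding",
+        model_name="all-minilm",
+    )
+    print(len(embedder.embed_query("hello world")))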
\ No newline at end of file
diff --git a/muagent/models/openai_model.py b/muagent/models/openai_model.py
new file mode 100644
index 0000000..338e530
--- /dev/null
+++ b/muagent/models/openai_model.py
@@ -0,0 +1,667 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for OpenAI models
+The implementation of this _ModelWrapperMeta are borrowed from
+https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py
+"""
+
+
+from abc import ABC
+from typing import (
+ Union,
+ Any,
+ List,
+ Sequence,
+ Dict,
+ Optional,
+ Generator,
+ Literal
+)
+from urllib.parse import urlparse
+import os
+import base64
+from loguru import logger
+try:
+ import openai
+except ImportError as e:
+ raise ImportError(
+ "Cannot find openai package, please install it by "
+ "`pip install openai`",
+ ) from e
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types import CreateEmbeddingResponse
+from .base_model import ModelWrapperBase
+from ..schemas import Message
+
+
+
+OPENAI_MAX_LENGTH = {
+ "update": 20231212,
+ # gpt-4
+ "gpt-4o-mini": 8192,
+ "gpt-4-1106-preview": 128000,
+ "gpt-4-vision-preview": 128000,
+ "gpt-4": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-0613": 8192,
+ "gpt-4-32k-0613": 32768,
+ "gpt-4-0314": 8192, # legacy
+ "gpt-4-32k-0314": 32768, # legacy
+ # gpt-3.5
+ "gpt-3.5-turbo-1106": 16385,
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-16k": 16385,
+ "gpt-3.5-turbo-instruct": 4096,
+ "gpt-3.5-turbo-0613": 4096, # legacy
+ "gpt-3.5-turbo-16k-0613": 16385, # deprecated on June 13th 2024
+ "gpt-3.5-turbo-0301": 4096, # deprecated on June 13th 2024
+ "text-davinci-003": 4096, # deprecated on Jan 4th 2024
+ "text-davinci-002": 4096, # deprecated on Jan 4th 2024
+ "code-davinci-002": 4096, # deprecated on Jan 4th 2024
+ # gpt-3 legacy
+ "text-curie-001": 2049,
+ "text-babbage-001": 2049,
+ "text-ada-001": 2049,
+ "davinci": 2049,
+ "curie": 2049,
+ "babbage": 2049,
+ "ada": 2049,
+    # embedding models
+ "text-embedding-3-small": 8191,
+ "text-embedding-3-large": 8191,
+ "text-embedding-ada-002": 8191,
+}
+
+
+def get_openai_max_length(model_name: str) -> int:
+ """Get the max length of the OpenAi models."""
+ try:
+ return OPENAI_MAX_LENGTH[model_name]
+ except KeyError as exc:
+ raise KeyError(
+ f"Model [{model_name}] not found in OPENAI_MAX_LENGTH. "
+ f"The last updated date is {OPENAI_MAX_LENGTH['update']}",
+ ) from exc
+
+
+
+def _to_openai_image_url(url: str) -> str:
+ """Convert an image url to openai format. If the given url is a local
+ file, it will be converted to base64 format. Otherwise, it will be
+ returned directly.
+
+ Args:
+ url (`str`):
+ The local or public url of the image.
+ """
+ # See https://platform.openai.com/docs/guides/vision for details of
+    # supported image extensions.
+ support_image_extensions = (
+ ".png",
+ ".jpg",
+ ".jpeg",
+ ".gif",
+ ".webp",
+ )
+
+ parsed_url = urlparse(url)
+
+ lower_url = url.lower()
+
+ # Web url
+ if parsed_url.scheme != "":
+ if any(lower_url.endswith(_) for _ in support_image_extensions):
+ return url
+
+ # Check if it is a local file
+ elif os.path.exists(url) and os.path.isfile(url):
+ if any(lower_url.endswith(_) for _ in support_image_extensions):
+ with open(url, "rb") as image_file:
+ base64_image = base64.b64encode(image_file.read()).decode(
+ "utf-8",
+ )
+ extension = parsed_url.path.lower().split(".")[-1]
+ mime_type = f"image/{extension}"
+ return f"data:{mime_type};base64,{base64_image}"
+
+ raise TypeError(f"{url} should be end with {support_image_extensions}.")
+
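+# Illustrative sketch of `_to_openai_image_url` (never called from library
+# code): public image URLs pass through unchanged, while an existing local
+# image file would be inlined as a base64 data URI. The paths below are
+# placeholders.
+def _example_image_url_conversion() -> None:  # pragma: no cover
+    print(_to_openai_image_url("https://example.com/cat.png"))
+    # print(_to_openai_image_url("/path/to/local/cat.png"))
+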
+
+
+class OpenAIWrapperBase(ModelWrapperBase, ABC):
+ """The model wrapper for OpenAI API.
+
+ Response:
+ - From https://platform.openai.com/docs/api-reference/chat/create
+
+ ```json
+ {
+ "id": "chatcmpl-123",
+ "object": "chat.completion",
+ "created": 1677652288,
+ "model": "gpt-4o-mini",
+ "system_fingerprint": "fp_44709d6fcb",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "Hello there, how may I assist you today?",
+ },
+ "logprobs": null,
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 9,
+ "completion_tokens": 12,
+ "total_tokens": 21
+ }
+ }
+ ```
+ """
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ api_url: str = "https://api.openai.com/v1",
+ organization: str = None,
+ client_args: dict = None,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the openai client.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in OpenAI API.
+ api_key (`str`, default `None`):
+ The API key for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_API_KEY`.
+ organization (`str`, default `None`):
+ The organization ID for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_ORGANIZATION`.
+ client_args (`dict`, default `None`):
+ The extra keyword arguments to initialize the OpenAI client.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in openai api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ if model_name is None:
+ model_name = config_name
+ logger.warning("model_name is not set, use config_name instead.")
+
+ init_params = locals()
+ init_params.pop("self")
+ init_params["model_type"] = self.model_type
+ super().__init__(**init_params)
+ # super().__init__(config_name=config_name, model_name=model_name)
+
+ self.generate_args = generate_args or {}
+
+ try:
+ from zdatafront import ZDataFrontClient
+ from zdatafront.openai import SyncProxyHttpClient
+ VISIT_DOMAIN = os.environ.get("visit_domain")
+ VISIT_BIZ = os.environ.get("visit_biz")
+ VISIT_BIZ_LINE = os.environ.get("visit_biz_line")
+ aes_secret_key = os.environ.get("aes_secret_key")
+ zdatafront_client = ZDataFrontClient(
+ visit_domain=VISIT_DOMAIN,
+ visit_biz=VISIT_BIZ,
+ visit_biz_line=VISIT_BIZ_LINE,
+ aes_secret_key=aes_secret_key
+ )
+ http_client = SyncProxyHttpClient(zdatafront_client=zdatafront_client, prefer_async=True)
+ except Exception as e:
+            logger.warning("zdatafront is unavailable, falling back to the plain OpenAI client.")
+ http_client = None
+
+ if http_client:
+ self.client = openai.OpenAI(
+ api_key=api_key,
+ http_client=http_client,
+ organization=organization,
+ **(client_args or {}),
+ timeout=120,
+ )
+ else:
+ self.client = openai.OpenAI(
+ api_key=api_key,
+ organization=organization,
+ **(client_args or {}),
+ )
+ # Set the max length of OpenAI model
+ try:
+ self.max_length = get_openai_max_length(self.model_name)
+ except Exception as e:
+ logger.warning(
+                f"Failed to get max_length for {self.model_name}: {e}",
+ )
+ self.max_length = None
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ raise RuntimeError(
+ f"Model Wrapper [{type(self).__name__}] doesn't "
+ f"need to format the input. Please try to use the "
+ f"model wrapper directly.",
+ )
+
+
+class OpenAIChatWrapper(OpenAIWrapperBase):
+ """The model wrapper for OpenAI's chat API."""
+
+ model_type: str = "openai_chat"
+
+ substrings_in_vision_models_names = ["gpt-4-turbo", "vision", "gpt-4o"]
+ """The substrings in the model names of vision models."""
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ api_url: str = "https://api.openai.com/v1",
+ organization: str = None,
+ client_args: dict = None,
+ stream: bool = False,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the openai client.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in OpenAI API.
+ api_key (`str`, default `None`):
+ The API key for OpenAI API. If not specified, it will
+ be read from the environment variable `OPENAI_API_KEY`.
+            api_url (`str`, default `"https://api.openai.com/v1"`):
+                The base url of the OpenAI(-compatible) API.
+            organization (`str`, default `None`):
+                The organization ID for OpenAI API. If not specified, it will
+                be read from the environment variable `OPENAI_ORGANIZATION`.
+ client_args (`dict`, default `None`):
+ The extra keyword arguments to initialize the OpenAI client.
+ stream (`bool`, default `False`):
+ Whether to enable stream mode.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in openai api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ init_params = locals()
+ init_params.pop("self")
+ init_params["model_type"] = self.model_type
+ super().__init__(**init_params)
+ self.stream = stream
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stream: bool = None,
+ stop: Optional[str] = '',
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs: Any,
+ ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]:
+ """Processes a list of messages to construct a payload for the OpenAI
+ API call. It then makes a request to the OpenAI API and returns the
+ response. This method also updates monitoring metrics based on the
+ API response.
+
+ Each message in the 'messages' list can contain text content and
+ optionally an 'image_urls' key. If 'image_urls' is provided,
+ it is expected to be a list of strings representing URLs to images.
+ These URLs will be transformed to a suitable format for the OpenAI
+ API, which might involve converting local file paths to data URIs.
+
+ Args:
+ messages (`list`):
+ A list of messages to process.
+            stream (`Optional[bool]`, defaults to `None`):
+ Whether to enable stream mode, which will override the
+ `stream` argument in the constructor if provided.
+ **kwargs (`Any`):
+ The keyword arguments to OpenAI chat completions API,
+ e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to
+ https://platform.openai.com/docs/api-reference/chat/create
+ for more detailed arguments.
+
+        Returns:
+            `Generator[Union[ChatCompletionChunk, ChatCompletion, str], None, None]`:
+                A generator yielding the (accumulated) text when
+                `format_type="str"`, otherwise the raw response or chunks.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved for
+ `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and check
+ the response generated by the model, which takes the response
+ as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: checking messages
+ if not isinstance(messages, list):
+ raise ValueError(
+ "OpenAI `messages` field expected type `list`, "
+ f"got `{type(messages)}` instead.",
+ )
+        if not all("role" in msg and "content" in msg for msg in messages):
+ raise ValueError(
+ "Each message in the 'messages' list must contain a 'role' "
+ "and 'content' key for OpenAI API.",
+ )
+
+ # step3: forward to generate response
+ if stream is None:
+ stream = self.stream
+
+        kwargs.update(
+            {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": stream,
+                "stop": stop,
+            },
+        )
+        # Only pass tool-related fields when tools are provided; the OpenAI API
+        # rejects an empty `tools` list.
+        if tools:
+            kwargs["tools"] = tools
+            if tool_choice is not None:
+                kwargs["tool_choice"] = tool_choice
+            if parallel_tool_calls is not None:
+                kwargs["parallel_tool_calls"] = parallel_tool_calls
+
+ if stream:
+ kwargs["stream_options"] = {"include_usage": True}
+
+ response = self.client.chat.completions.create(**kwargs)
+
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk.choices[0].delta.content or ''
+ yield content
+ else:
+ yield response.choices[0].message.content
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+ yield response
+
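+    # Usage sketch (illustrative; config and model names are placeholders): the
+    # call returns a generator, so the result must be consumed even when
+    # streaming is disabled.
+    #
+    #     llm = OpenAIChatWrapper(config_name="gpt", model_name="gpt-4o-mini", api_key="sk-...")
+    #     text = next(llm(prompt="hello", format_type="str"))       # full reply as a string
+    #     for partial in llm(prompt="hello", stream=True, format_type="str"):
+    #         print(partial)                                        # growing partial text
+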
+ @staticmethod
+ def _format_Message_with_url(
+ message: Message,
+ model_name: str,
+ ) -> Dict:
+ """Format a message with image urls into openai chat format.
+ This format method is used for gpt-4o, gpt-4-turbo, gpt-4-vision and
+ other vision models.
+ """
+ # Check if the model is a vision model
+ if not any(
+ _ in model_name
+ for _ in OpenAIChatWrapper.substrings_in_vision_models_names
+ ):
+ logger.warning(
+ f"The model {model_name} is not a vision model. "
+ f"Skip the url in the message.",
+ )
+ return {
+ "role": message.role_type,
+ "name": message.role_name,
+ "content": message.content,
+ }
+
+ # Put all urls into a list
+ urls = message.image_urls if isinstance(message.image_urls, list) else [message.image_urls]
+
+ # Check if the url refers to an image
+ checked_urls = []
+ for url in urls:
+ try:
+ checked_urls.append(_to_openai_image_url(url))
+ except TypeError:
+ logger.warning(
+ f"The url {url} is not a valid image url for "
+ f"OpenAI Chat API, skipped.",
+ )
+
+ if len(checked_urls) == 0:
+ # If no valid image url is provided, return the normal message dict
+ return {
+ "role": message.role_type,
+ "name": message.role_name,
+ "content": message.content,
+ }
+ else:
+ # otherwise, use the vision format message
+ returned_Message = {
+ "role": message.role_type,
+ "name": message.role_name,
+ "content": [
+ {
+ "type": "text",
+ "text": message.content,
+ },
+ ],
+ }
+
+ image_dicts = [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": _,
+ },
+ }
+ for _ in checked_urls
+ ]
+
+ returned_Message["content"].extend(image_dicts)
+
+ return returned_Message
+
+ @staticmethod
+ def static_format(
+ *args: Union[Message, Sequence[Message]],
+ model_name: str,
+ ) -> List[dict]:
+ """A static version of the format method, which can be used without
+ initializing the OpenAIChatWrapper object.
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+ model_name (`str`):
+ The name of the model to use in OpenAI API.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages in the format that OpenAI Chat API
+ required.
+ """
+ messages = []
+ for arg in args:
+ if arg is None:
+ continue
+ if isinstance(arg, Message):
+ if arg.image_urls is not None and arg.image_urls:
+ # Format the message according to the model type
+ # (vision/non-vision)
+ formatted_Message = OpenAIChatWrapper._format_Message_with_url(
+ arg,
+ model_name,
+ )
+ messages.append(formatted_Message)
+ else:
+ messages.append(
+ {
+ "role": arg.role_type,
+ "name": arg.role_name,
+ "content": arg.content,
+ },
+ )
+
+ elif isinstance(arg, list):
+ messages.extend(
+ OpenAIChatWrapper.static_format(
+ *arg,
+ model_name=model_name,
+ ),
+ )
+ else:
+ raise TypeError(
+ f"The input should be a Message object or a list "
+ f"of Message objects, got {type(arg)}.",
+ )
+
+ return messages
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """Format the input string and dictionary into the format that
+        OpenAI Chat API required. If you're using an OpenAI-compatible model
+        whose name does not start with "gpt-", the format method will
+        automatically format the input messages into the common chat format.
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages in the format that OpenAI Chat API
+ required.
+ """
+
+ # Format messages according to the model name
+ if self.model_name.startswith("gpt-"):
+ return OpenAIChatWrapper.static_format(
+ *args,
+ model_name=self.model_name,
+ )
+ else:
+            # The OpenAI client may be reused for OpenAI-compatible models
+ return ModelWrapperBase.format_for_common_chat_models(*args)
+
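+# Formatting sketch (illustrative; assumes the Message schema used above, i.e.
+# objects exposing role_name/role_type/content):
+#
+#     msgs = OpenAIChatWrapper.static_format(
+#         Message(role_name="system", role_type="system", content="You are helpful."),
+#         Message(role_name="user", role_type="user", content="Hi!"),
+#         model_name="gpt-4o-mini",
+#     )
+#     # -> [{"role": "system", "name": "system", "content": "You are helpful."},
+#     #     {"role": "user", "name": "user", "content": "Hi!"}]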
+
+class OpenAIEmbeddingWrapper(OpenAIWrapperBase):
+ """The model wrapper for OpenAI embedding API.
+
+ Response:
+ - Refer to
+ https://platform.openai.com/docs/api-reference/embeddings/create
+
+ ```json
+ {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "embedding": [
+ 0.0023064255,
+ -0.009327292,
+ .... (1536 floats total for ada-002)
+ -0.0028842222,
+ ],
+ "index": 0
+ }
+ ],
+ "model": "text-embedding-ada-002",
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8
+ }
+ }
+ ```
+ """
+
+ model_type: str = "openai_embedding"
+
+ def __call__(
+ self,
+ texts: Union[list[str], str],
+ dimension=768,
+ **kwargs: Any,
+ ) -> CreateEmbeddingResponse:
+ """Embed the messages with OpenAI embedding API.
+
+ Args:
+ texts (`list[str]` or `str`):
+ The messages used to embed.
+ **kwargs (`Any`):
+ The keyword arguments to OpenAI embedding API,
+ e.g. `encoding_format`, `user`. Please refer to
+ https://platform.openai.com/docs/api-reference/embeddings
+ for more detailed arguments.
+
+        Returns:
+            `dict`:
+                The raw embedding response dumped to a dict, with the
+                embeddings available under the `data` field.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved for
+ `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and check
+ the response generated by the model, which takes the response
+ as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: forward to generate response
+ response = self.client.embeddings.create(
+ input=texts,
+ model=self.model_name,
+ **kwargs,
+ )
+        # step3: return the response as a dict
+ response_json = response.model_dump()
+ return response_json
+
+ def embed_query(self, text: str) -> List[float]:
+ response = self([text])
+ output = response["data"]
+ return output[0]["embedding"]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ response = self(texts)
+ output = response["data"]
+ return [emb["embedding"] for emb in output]
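+
+    # Usage sketch (illustrative; config and model names are placeholders):
+    #
+    #     embedder = OpenAIEmbeddingWrapper(
+    #         config_name="embed", model_name="text-embedding-3-small", api_key="sk-...",
+    #     )
+    #     vector = embedder.embed_query("hello world")            # List[float]
+    #     vectors = embedder.embed_documents(["doc a", "doc b"])  # List[List[float]]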
+
\ No newline at end of file
diff --git a/muagent/models/qwen_model.py b/muagent/models/qwen_model.py
new file mode 100644
index 0000000..2844bf1
--- /dev/null
+++ b/muagent/models/qwen_model.py
@@ -0,0 +1,461 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for DashScope models"""
+import os
+from abc import ABC
+from http import HTTPStatus
+from typing import (
+ Any,
+ Union,
+ List,
+ Sequence,
+ Optional,
+ Generator,
+ Literal
+)
+
+import openai
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from loguru import logger
+
+try:
+ import dashscope
+
+ dashscope_version = dashscope.version.__version__
+ if dashscope_version < "1.19.0":
+ logger.warning(
+ f"You are using 'dashscope' version {dashscope_version}, "
+ "which is below the recommended version 1.19.0. "
+ "Please consider upgrading to maintain compatibility.",
+ )
+ from dashscope.api_entities.dashscope_response import GenerationResponse
+except ImportError:
+ dashscope = None
+ GenerationResponse = None
+
+
+from ..schemas import Message
+from .base_model import ModelWrapperBase
+from ..utils.common_utils import _convert_to_str
+
+
+
+class QwenWrapperBase(ModelWrapperBase, ABC):
+ """The model wrapper for DashScope API."""
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the DashScope wrapper.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in DashScope API.
+ api_key (`str`, default `None`):
+ The API key for DashScope API.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in DashScope api generation,
+ e.g. `temperature`, `seed`.
+ """
+ if model_name is None:
+ model_name = config_name
+ logger.warning("model_name is not set, use config_name instead.")
+
+ super().__init__(config_name=config_name, model_name=model_name)
+
+ self.generate_args = generate_args or {}
+
+ self.api_key = api_key
+ self.max_length = None
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> Union[List[dict], str]:
+ raise RuntimeError(
+ f"Model Wrapper [{type(self).__name__}] doesn't "
+ f"need to format the input. Please try to use the "
+ f"model wrapper directly.",
+ )
+
+
+class QwenChatWrapper(QwenWrapperBase):
+ """The model wrapper for DashScope's chat API, refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
+ Response:
+ - Refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw
+
+ ```json
+ {
+ "status_code": 200,
+ "request_id": "a75a1b22-e512-957d-891b-37db858ae738",
+ "code": "",
+ "message": "",
+ "output": {
+ "text": null,
+ "finish_reason": null,
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "message": {
+ "role": "assistant",
+ "content": "xxx"
+ }
+ }
+ ]
+ },
+ "usage": {
+ "input_tokens": 25,
+ "output_tokens": 77,
+ "total_tokens": 102
+ }
+ }
+ ```
+ """
+
+ model_type: str = "qwen_chat"
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str = None,
+ api_key: str = None,
+ api_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ stream: bool = False,
+ generate_args: dict = None,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the DashScope wrapper.
+
+ Args:
+ config_name (`str`):
+ The name of the model config.
+ model_name (`str`, default `None`):
+ The name of the model to use in DashScope API.
+ api_key (`str`, default `None`):
+ The API key for DashScope API.
+ stream (`bool`, default `False`):
+ If True, the response will be a generator in the `stream`
+ field of the returned `ModelResponse` object.
+ generate_args (`dict`, default `None`):
+ The extra keyword arguments used in DashScope api generation,
+ e.g. `temperature`, `seed`.
+ """
+
+ super().__init__(
+ config_name=config_name,
+ model_name=model_name,
+ api_key=api_key,
+ generate_args=generate_args,
+ **kwargs,
+ )
+ self.api_url = api_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
+ self.stream = stream
+ self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url)
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Optional[Literal['auto', 'required']] = None,
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = None,
+ stop: Optional[str] = '',
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs
+ ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]:
+        """Invoke the DashScope (Qwen) chat API by sending a list of messages."""
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+ # Checking messages
+ if not isinstance(messages, list):
+            raise ValueError(
+                "Qwen `messages` field expected type `list`, "
+                f"got `{type(messages)}` instead.",
+            )
+
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for the Qwen API.",
+            )
+ #
+
+        stream = self.stream if stream is None else stream
+ model_name = self.model_name
+ #
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+ kwargs.update(
+ {
+ "model": model_name,
+ "messages": messages,
+ "stream": stream,
+ "stop": stop
+ # "tools": tools,
+ # "tool_choice": tool_choice,
+ # "parallel_tool_calls": parallel_tool_calls,
+ },
+ )
+ if tools:
+ kwargs["tools"] = tools
+
+ response = self.client.chat.completions.create(**kwargs)
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk.choices[0].delta.content or ''
+ yield content
+ else:
+ yield response.choices[0].message.content
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+ yield response
+
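+    # Usage sketch (illustrative; the DashScope key and model name are
+    # placeholders): the wrapper talks to the OpenAI-compatible endpoint, so it
+    # is consumed the same way as OpenAIChatWrapper.
+    #
+    #     qwen = QwenChatWrapper(config_name="qwen", model_name="qwen-plus", api_key="sk-...")
+    #     answer = next(qwen(prompt="Introduce yourself", format_type="str"))
+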
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """A common format strategy for chat models, which will format the
+ input messages into a user message.
+
+        Note that this strategy may not be suitable for all scenarios,
+ and developers are encouraged to implement their own prompt
+ engineering strategies.
+
+ The following is an example:
+
+ .. code-block:: python
+
+ prompt1 = model.format(
+ Message("system", "You're a helpful assistant", role="system"),
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ prompt2 = model.format(
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ The prompt will be as follows:
+
+ .. code-block:: python
+
+ # prompt1
+ [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant"
+ },
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+ # prompt2
+ [
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+
+ Args:
+            args (`Union[Message, Sequence[Message]]`):
+                The input arguments to be formatted, where each argument
+                should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages.
+ """
+
+ return ModelWrapperBase.format_for_common_chat_models(*args)
+
+
+ def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str:
+ """Forward the input to the model.
+
+ Args:
+            args (`Union[Message, Sequence[Message]]`):
+                The input arguments to be formatted, where each argument
+                should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `str`:
+ The formatted string prompt.
+ """
+ input_msgs: List[Message] = []
+ for _ in args:
+ if _ is None:
+ continue
+ if isinstance(_, Message):
+ input_msgs.append(_)
+ elif isinstance(_, list) and all(isinstance(__, Message) for __ in _):
+ input_msgs.extend(_)
+ else:
+ raise TypeError(
+                    f"The input should be a Message object or a list "
+                    f"of Message objects, got {type(_)}.",
+ )
+
+ sys_prompt = None
+ dialogue = []
+ for i, unit in enumerate(input_msgs):
+ if i == 0 and unit.role_type == "system":
+ # system prompt
+ sys_prompt = unit.content
+ else:
+ # Merge all messages into a conversation history prompt
+ dialogue.append(
+ f"{unit.role_name}: {unit.content}",
+ )
+
+ dialogue_history = "\n".join(dialogue)
+
+ if sys_prompt is None:
+ prompt_template = "## Conversation History\n{dialogue_history}"
+ else:
+ prompt_template = (
+ "{system_prompt}\n"
+ "\n"
+ "## Conversation History\n"
+ "{dialogue_history}"
+ )
+
+ return prompt_template.format(
+ system_prompt=sys_prompt,
+ dialogue_history=dialogue_history,
+ )
+
+class QwenTextEmbeddingWrapper(QwenWrapperBase):
+ """The model wrapper for DashScope Text Embedding API.
+
+ Response:
+ - Refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3
+
+ ```json
+ {
+ "status_code": 200, // 200 indicate success otherwise failed.
+ "request_id": "fd564688-43f7-9595-b986", // The request id.
+ "code": "", // If failed, the error code.
+ "message": "", // If failed, the error message.
+ "output": {
+ "embeddings": [ // embeddings
+ {
+ "embedding": [ // one embedding output
+ -3.8450357913970947, ...,
+ ],
+ "text_index": 0 // the input index.
+ }
+ ]
+ },
+ "usage": {
+ "total_tokens": 3 // the request tokens.
+ }
+ }
+ ```
+ """
+
+ model_type: str = "qwen_text_embedding"
+
+ def __call__(
+ self,
+ texts: Union[list[str], str],
+ dimension: Literal[512, 768, 1024, 1536] = 768,
+ **kwargs: Any,
+ ):
+ """Embed the messages with DashScope Text Embedding API.
+
+ Args:
+ texts (`list[str]` or `str`):
+ The messages used to embed.
+ **kwargs (`Any`):
+ The keyword arguments to DashScope Text Embedding API,
+ e.g. `text_type`. Please refer to
+ https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15
+ for more detailed arguments.
+
+        Returns:
+            DashScope response:
+                The raw DashScope embedding response, with the embeddings
+                available under `output["embeddings"]`.
+
+ Note:
+ `parse_func`, `fault_handler` and `max_retries` are reserved
+ for `_response_parse_decorator` to parse and check the response
+ generated by model wrapper. Their usages are listed as follows:
+ - `parse_func` is a callable function used to parse and
+ check the response generated by the model, which takes the
+ response as input.
+ - `max_retries` is the maximum number of retries when the
+ `parse_func` raise an exception.
+ - `fault_handler` is a callable function which is called
+ when the response generated by the model is invalid after
+ `max_retries` retries.
+ """
+ # client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url)
+ # step1: prepare keyword arguments
+ kwargs = {**self.generate_args, **kwargs}
+
+ # step2: forward to generate response
+ response = dashscope.TextEmbedding.call(
+ input=texts,
+ model=self.model_name,
+ api_key=self.api_key,
+ dimension=dimension,
+ **kwargs,
+ )
+
+ if response.status_code != HTTPStatus.OK:
+ error_msg = (
+ f" Request id: {response.request_id},"
+ f" Status code: {response.status_code},"
+ f" error code: {response.code},"
+ f" error message: {response.message}."
+ )
+ raise RuntimeError(error_msg)
+
+        # step3: return the raw DashScope response
+ return response
+
+ def embed_query(self, text: str) -> List[float]:
+ response = self([text])
+ output = response["output"]
+ embeddings = output["embeddings"]
+ return embeddings[0]["embedding"]
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ response = self(texts)
+ output = response["output"]
+ embeddings = output["embeddings"]
+ return [emb["embedding"] for emb in embeddings]
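+
+    # Usage sketch (illustrative; model name and key are placeholders):
+    #
+    #     embedder = QwenTextEmbeddingWrapper(
+    #         config_name="qwen_embed", model_name="text-embedding-v3", api_key="sk-...",
+    #     )
+    #     vec = embedder.embed_query("hello")  # List[float], 768-d with the default dimension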
+
\ No newline at end of file
diff --git a/muagent/models/yi_model.py b/muagent/models/yi_model.py
new file mode 100644
index 0000000..4140baf
--- /dev/null
+++ b/muagent/models/yi_model.py
@@ -0,0 +1,291 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for Yi models.
+The implementation is adapted from
+https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py
+"""
+
+
+import json
+from typing import (
+ List,
+ Union,
+ Sequence,
+ Optional,
+ Generator,
+ Literal
+)
+
+import openai
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
+from .base_model import ModelWrapperBase
+from ..schemas import Message
+
+
+class YiChatWrapper(ModelWrapperBase):
+ """The model wrapper for Yi Chat API.
+
+ Response:
+ - From https://platform.lingyiwanwu.com/docs
+
+ ```json
+ {
+ "id": "cmpl-ea89ae83",
+ "object": "chat.completion",
+ "created": 5785971,
+ "model": "yi-large-rag",
+ "usage": {
+ "completion_tokens": 113,
+ "prompt_tokens": 896,
+ "total_tokens": 1009
+ },
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "Today in Los Angeles, the weather ...",
+ },
+ "finish_reason": "stop"
+ }
+ ]
+ }
+ ```
+ """
+
+ model_type: str = "yi_chat"
+
+ def __init__(
+ self,
+ config_name: str,
+ model_name: str,
+ api_key: str,
+ api_url: str="https://api.lingyiwanwu.com/v1",
+ max_tokens: Optional[int] = None,
+ top_p: float = 0.9,
+ temperature: float = 0.3,
+ stream: bool = False,
+ **kwargs,
+ ) -> None:
+ """Initialize the Yi chat model wrapper.
+
+ Args:
+ config_name (`str`):
+ The name of the configuration to use.
+ model_name (`str`):
+ The name of the model to use, e.g. yi-large, yi-medium, etc.
+ api_key (`str`):
+ The API key for the Yi API.
+ max_tokens (`Optional[int]`, defaults to `None`):
+ The maximum number of tokens to generate, defaults to `None`.
+ top_p (`float`, defaults to `0.9`):
+ The randomness parameters in the range [0, 1].
+ temperature (`float`, defaults to `0.3`):
+ The temperature parameter in the range [0, 2].
+ stream (`bool`, defaults to `False`):
+ Whether to stream the response or not.
+ """
+
+ init_params = locals()
+ init_params.pop("self")
+ init_params["model_type"] = self.model_type
+ super().__init__(**init_params)
+
+ if top_p > 1 or top_p < 0:
+ raise ValueError(
+ f"The `top_p` parameter must be in the range [0, 1], but got "
+ f"{top_p} instead.",
+ )
+
+ if temperature < 0 or temperature > 2:
+ raise ValueError(
+ f"The `temperature` parameter must be in the range [0, 2], "
+ f"but got {temperature} instead.",
+ )
+ self.api_url = api_url or "https://api.lingyiwanwu.com/v1"
+        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url)
+ self.max_tokens = max_tokens
+ self.top_p = top_p
+ self.temperature = temperature
+ self.stream = stream
+
+ def __call__(
+ self,
+ prompt: str = None,
+ messages: Sequence[dict] = [],
+ tools: Sequence[object] = [],
+ *,
+ tool_choice: Literal['auto', 'required'] = 'auto',
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = None,
+ stop: Optional[str] = '',
+ format_type: Literal['str', 'raw', 'dict'] = 'raw',
+ **kwargs
+ ) -> Generator[Union[ChatCompletionChunk, ChatCompletion, str], None, None]:
+ """Invoke the Yi Chat API by sending a list of messages."""
+
+ messages = [{"role": "user", "content": prompt}] if prompt else messages
+ # Checking messages
+ if not isinstance(messages, list):
+ raise ValueError(
+ f"Yi `messages` field expected type `list`, "
+ f"got `{type(messages)}` instead.",
+ )
+
+        if not all("role" in msg and "content" in msg for msg in messages):
+ raise ValueError(
+ "Each message in the 'messages' list must contain a 'role' "
+ "and 'content' key for Yi API.",
+ )
+ #
+        stream = self.stream if stream is None else stream
+ model_name = "yi-large-fc" if tools else self.model_name
+ # model_name = self.model_name
+ #
+ kwargs.update(
+ {
+ "model": model_name,
+ "messages": messages,
+ "stream": stream,
+ "tools": tools,
+ "tool_choice": tool_choice,
+ "parallel_tool_calls": parallel_tool_calls,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "top_p": self.top_p,
+ "stop": [stop]
+ },
+ )
+
+ response = self.client.chat.completions.create(**kwargs)
+
+ if format_type == "str":
+ content = ""
+ if stream:
+ for chunk in response:
+ content += chunk.choices[0].delta.content or ''
+ yield content
+ else:
+ yield response.choices[0].message.content
+ else:
+ if stream:
+ for chunk in response:
+ yield chunk
+ else:
+ yield response
+
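+    # Usage sketch (illustrative; key and model name are placeholders):
+    #
+    #     yi = YiChatWrapper(config_name="yi", model_name="yi-large", api_key="...")
+    #     reply = next(yi(prompt="hello", format_type="str"))
+    #     # when `tools` are supplied, the wrapper switches to "yi-large-fc" internally
+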
+ def function_call(
+ self,
+ messages: Optional[Sequence[dict]] = None,
+ tools: Sequence[object] = [],
+ *,
+ prompt: Optional[str] = None,
+ tool_choice: Literal['auto', 'required'] = 'auto',
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = False,
+ ) -> ChatCompletion:
+ """Call a function to process messages with optional tools.
+
+ Args:
+ messages (Optional[Sequence[dict]], optional): A sequence of messages for context.
+ tools (Sequence[object], optional): Tools available for use.
+ prompt (Optional[str], optional): An optional prompt.
+ tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools.
+ parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls.
+ stream (Optional[bool], optional): If true, streams the output instead of returning it all at once.
+
+ Returns:
+ Union[ChatCompletion, Mapping]: The result of the function call processed by the model.
+ """
+ kwargs = locals()
+ kwargs.pop("self")
+ kwargs.pop("__class__")
+ return super().function_call(**kwargs)
+
+ def function_call_stream(
+ self,
+ messages: Optional[Sequence[dict]] = None,
+ tools: Sequence[object] = [],
+ *,
+ prompt: Optional[str] = None,
+ tool_choice: Literal['auto', 'required'] = 'auto',
+ parallel_tool_calls: Optional[bool] = None,
+ stream: Optional[bool] = True,
+ ) -> Generator[ChatCompletionChunk, None, None]:
+ """Stream function call outputs.
+
+ Args:
+ messages (Optional[Sequence[dict]], optional): A sequence of messages for context.
+ tools (Sequence[object], optional): Tools available for use.
+ prompt (Optional[str], optional): An optional prompt.
+ tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools.
+ parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls.
+ stream (Optional[bool], optional): If true, streams the output instead of returning it all at once.
+
+ Yields:
+ Generator[Union[ChatCompletionChunk, Mapping], None, None]: A generator yielding parts of the function output.
+ """
+ kwargs = locals()
+ kwargs.pop("self")
+ kwargs.pop("__class__")
+        yield from super().function_call_stream(**kwargs)
+
+ def format(
+ self,
+ *args: Union[Message, Sequence[Message]],
+ ) -> List[dict]:
+ """Format the messages into the required format of Yi Chat API.
+
+        Note that this strategy may not be suitable for all scenarios,
+ and developers are encouraged to implement their own prompt
+ engineering strategies.
+
+ The following is an example:
+
+ .. code-block:: python
+
+ prompt1 = model.format(
+ Message("system", "You're a helpful assistant", role="system"),
+ Message("Bob", "Hi, how can I help you?", role="assistant"),
+ Message("user", "What's the date today?", role="user")
+ )
+
+ The prompt will be as follows:
+
+ .. code-block:: python
+
+ # prompt1
+ [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant"
+ },
+ {
+ "role": "user",
+ "content": (
+ "## Conversation History\\n"
+ "Bob: Hi, how can I help you?\\n"
+ "user: What's the date today?"
+ )
+ }
+ ]
+
+ Args:
+ args (`Union[Message, Sequence[Message]]`):
+ The input arguments to be formatted, where each argument
+ should be a `Message` object, or a list of `Message` objects.
+ In distribution, placeholder is also allowed.
+
+ Returns:
+ `List[dict]`:
+ The formatted messages.
+ """
+
+ # TODO: Support Vision model
+ if self.model_name == "yi-vision":
+ raise NotImplementedError(
+ "Yi Vision model is not supported in the current version, "
+ "please format the messages manually.",
+ )
+
+ return ModelWrapperBase.format_for_common_chat_models(*args)
\ No newline at end of file
diff --git a/muagent/orm/__init__.py b/muagent/orm/__init__.py
deleted file mode 100644
index 2a2c21b..0000000
--- a/muagent/orm/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from .db import _engine, Base
-from loguru import logger
-
-__all__ = [
-
-]
-
-def create_tables():
- Base.metadata.create_all(bind=_engine)
-
-def reset_tables():
- Base.metadata.drop_all(bind=_engine)
- create_tables()
-
-
-def check_tables_exist(table_name) -> bool:
- table_exist = _engine.dialect.has_table(_engine.connect(), table_name, schema=None)
- return table_exist
-
-def table_init():
- if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \
- (not check_tables_exist ("code_base")):
- create_tables()
diff --git a/muagent/project_manager.py b/muagent/project_manager.py
new file mode 100644
index 0000000..5f0250d
--- /dev/null
+++ b/muagent/project_manager.py
@@ -0,0 +1,70 @@
+from typing import (
+ Dict,
+ Optional,
+ Union
+)
+import os, sys, json, random
+from loguru import logger
+
+
+from .schemas.models import ModelConfig, LLMConfig
+from .schemas import ProjectConfig, PromptConfig, AgentConfig
+
+
+def get_project_config_from_env(
+ agent_configs: Optional[Dict[str, AgentConfig]] = None,
+ model_configs: Optional[Dict[str, Union[ModelConfig, LLMConfig]]] = None,
+ prompt_configs: Optional[Dict[str, PromptConfig]] = PromptConfig(),
+) -> ProjectConfig:
+    """Build a ProjectConfig from the given configs, falling back to environment variables."""
+ init_dict = {
+ "model_configs": [model_configs, ModelConfig],
+ "agent_configs": [agent_configs, AgentConfig],
+ "prompt_configs": [prompt_configs, PromptConfig],
+ }
+ project_configs = {
+ "model_configs": None,
+ "agent_configs": None,
+ "prompt_configs": None,
+ }
+ for k, (v, _type) in init_dict.items():
+ if v:
+ pass
+ elif k.upper() in os.environ:
+ v = json.loads(os.environ[k.upper()])
+ vc = {}
+ for kk, vv in v.items():
+ try:
+ vc[kk] = _type(**vv)
+                except Exception:
+ vc[kk] = LLMConfig(**vv)
+ v = vc
+ if v:
+ chat_list = [_type for _type in v.keys() if "chat" in _type]
+ embedding_list = [_type for _type in v.keys() if "embedding" in _type]
+            if chat_list:
+                model_type = random.choice(chat_list)
+                default_model_config = v[model_type]
+                v["default_chat"] = default_model_config
+                os.environ["DEFAULT_MODEL_TYPE"] = model_type
+                os.environ["DEFAULT_MODEL_NAME"] = default_model_config.model_name
+                os.environ["DEFAULT_API_KEY"] = default_model_config.api_key or ""
+                os.environ["DEFAULT_API_URL"] = default_model_config.api_url or ""
+            if embedding_list:
+                model_type = random.choice(embedding_list)
+                default_model_config = v[model_type]
+                v["default_embed"] = default_model_config
+                os.environ["DEFAULT_EMBED_MODEL_TYPE"] = model_type
+                os.environ["DEFAULT_EMBED_MODEL_NAME"] = default_model_config.model_name
+                os.environ["DEFAULT_EMBED_API_KEY"] = default_model_config.api_key or ""
+                os.environ["DEFAULT_EMBED_API_URL"] = default_model_config.api_url or ""
+ project_configs[k] = v
+ else:
+ logger.warning(
+                    f"Can't initialize any {k} in this environment."
+ )
+ else:
+ logger.warning(
+                f"Can't initialize any {k} in this environment."
+ )
+ return ProjectConfig(**project_configs)
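+
+# Usage sketch (illustrative; the JSON below is a made-up example of the
+# MODEL_CONFIGS environment variable this function falls back to when no
+# configs are passed in; the field names follow ModelConfig as assumed here):
+#
+#     os.environ["MODEL_CONFIGS"] = json.dumps({
+#         "codefuse_chat": {"model_type": "openai_chat", "model_name": "gpt-4o-mini", "api_key": "sk-..."},
+#     })
+#     project_config = get_project_config_from_env()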
\ No newline at end of file
diff --git a/muagent/prompt_manager/__init__.py b/muagent/prompt_manager/__init__.py
new file mode 100644
index 0000000..00ea8cf
--- /dev/null
+++ b/muagent/prompt_manager/__init__.py
@@ -0,0 +1,8 @@
+from .base_prompt_manager import BasePromptManager
+from .common_prompt_manager import CommonPromptManager
+
+
+__all__ = [
+ "BasePromptManager",
+ "CommonPromptManager"
+]
\ No newline at end of file
diff --git a/muagent/prompt_manager/base.py b/muagent/prompt_manager/base.py
new file mode 100644
index 0000000..479f1eb
--- /dev/null
+++ b/muagent/prompt_manager/base.py
@@ -0,0 +1,32 @@
+from .language.en import *
+from .language.zh import *
+
+
+TITLE_CONFIGS_LANGUAGE = {
+ "en": EN_TITLE_CONFIGS,
+ "zh": ZH_TITLE_CONFIGS,
+}
+
+TITLE_EDGES_LANGUAGE = {
+ "en": EN_TITLE_EDGES,
+ "zh": ZH_TITLE_EDGES,
+}
+
+TITLE_FORMAT_LANGUAGE = {
+ "en": EN_TITLE_FORMAT,
+ "zh": ZH_TITLE_FORMAT
+}
+
+TITLE_LANGUAGE = {
+ "en": EN_TITLES,
+ "zh": ZH_TITLES
+}
+
+ZERO_TITLES_LANGUAGE = {
+ "en": EN_ZERO_TITLES,
+ "zh": ZH_ZERO_TITLES,
+}
+COMMON_TEXT_LANGUAGE = {
+ "en": EN_COMMON_TEXT,
+ "zh": ZH_COMMON_TEXT
+}
\ No newline at end of file
diff --git a/muagent/prompt_manager/base_prompt_manager.py b/muagent/prompt_manager/base_prompt_manager.py
new file mode 100644
index 0000000..0c0c0c0
--- /dev/null
+++ b/muagent/prompt_manager/base_prompt_manager.py
@@ -0,0 +1,506 @@
+from abc import ABCMeta, abstractmethod
+from typing import (
+ Any,
+ Union,
+ Optional,
+ Type,
+ Literal,
+ Dict,
+ List,
+ Tuple,
+ Sequence,
+ Mapping
+)
+from pydantic import BaseModel
+from loguru import logger
+import os
+import uuid
+import copy
+
+from .base import *
+from .util import edges_to_graph_with_cycle_detection
+from ..sandbox import NBClientBox
+from ..tools import get_tool
+from ..schemas import Memory, Message, PromptConfig
+from ..schemas.common import ActionStatus, LogVerboseEnum
+
+from muagent.base_configs.env_config import KB_ROOT_PATH
+
+
+class _PromptManagerWrapperMeta(ABCMeta):
+    """A metaclass that registers every prompt manager subclass so that it can
+    be looked up later by class name or by its `pm_type`."""
+
+ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+ if "__call__" in attrs:
+ attrs["__call__"] = attrs["__call__"]
+ return super().__new__(mcs, name, bases, attrs)
+
+ def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+ if not hasattr(cls, "_registry"):
+ cls._registry = {}
+ cls._type_registry = {}
+ else:
+ cls._registry[name] = cls
+ if hasattr(cls, "pm_type"):
+ cls._type_registry[cls.pm_type] = cls
+ super().__init__(name, bases, attrs)
+
+
+class BasePromptManager(metaclass=_PromptManagerWrapperMeta):
+
+ pm_type: str = "BasePromptManager"
+ """The type of prompt manager."""
+
+ def __init__(
+ self,
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = "",
+ language: Literal["en", "zh"] = "en",
+ *,
+ monitored_agents=[],
+ monitored_fields=[],
+ log_verbose: str = "0",
+ workdir_path: str = KB_ROOT_PATH,
+ **kwargs
+ ):
+ #
+ self.system_prompt = system_prompt
+ self.input_template = input_template
+ self.output_template = output_template
+ self.prompt = prompt
+ self.language = language
+        # deprecated
+ self.monitored_agents = monitored_agents
+ self.monitored_fields = monitored_fields
+ #
+ self.extra_registry_titles: Dict = {}
+ self.extra_register_edges: Sequence = []
+ self.new_dfsindex_to_str_format: Dict = {}
+ """use {title name} {description/function_value}"""
+
+ #
+ self.codebox = NBClientBox(do_code_exe=True) # Initialize code execution box
+ self.workdir_path = workdir_path # Set the working directory path
+ self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose # Configure logging verbosity
+
+ @classmethod
+    def from_config(cls, prompt_config: PromptConfig, **kwargs) -> 'BasePromptManager':
+        """Build a prompt manager instance from a PromptConfig."""
+        init_kwargs = {**kwargs, **prompt_config.dict()}
+        return cls.get_wrapper(prompt_config.prompt_manager_type)(**init_kwargs)
+
+ @classmethod
+ def get_wrapper(cls, prompt_manager_type: str) -> Type['BasePromptManager']:
+ """Get the specific PromptManager wrapper"""
+ if prompt_manager_type in cls._type_registry:
+ return cls._type_registry[prompt_manager_type] # type: ignore[return-value]
+ elif prompt_manager_type in cls._registry:
+ return cls._registry[prompt_manager_type] # type: ignore[return-value]
+ else:
+ raise KeyError(
+ f"Unsupported prompt_manager_type [{prompt_manager_type}]"
+ )
+
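+    # Registry sketch (illustrative): because _PromptManagerWrapperMeta records
+    # every subclass, a manager can be resolved either by its `pm_type` or by
+    # its class name, e.g.
+    #
+    #     manager_cls = BasePromptManager.get_wrapper("CommonPromptManager")
+    #     manager = manager_cls(system_prompt="you are a helpful assistant!\n")
+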
+ def register_graph(
+ self,
+ title_configs: Mapping[str, Mapping] = {},
+        title_edges: Sequence[Sequence[str]] = [],
+ title_format: Mapping[int, str] = {},
+ titles: Mapping[str, Sequence[str]] = {},
+ zero_titles: Mapping = {},
+ common_texts: Mapping[str, str] = {}
+ ):
+ """transform title and edge into title graph to execute"""
+ # custom define
+ self.register_env(
+ title_configs, title_edges, title_format, titles,
+ zero_titles=zero_titles,
+ common_texts=common_texts
+ )
+ self.register_prompt()
+
+ # prepare title graph
+ start_nodes, self.title_graph = edges_to_graph_with_cycle_detection(self._registry_edges)
+ for title in start_nodes:
+ if title not in self._title_prefix + self._title_suffix:
+ self._title_middle.append(title)
+
+ if LogVerboseEnum.le(LogVerboseEnum.Log3Level, os.environ.get("log_verbose", "0")):
+ logger.debug(f"{self._registry_titles}, {self._registry_edges}, {self.title_graph}")
+
+ def register_env(
+ self,
+ title_configs: Mapping[str, Mapping] = {},
+        title_edges: Sequence[Sequence[str]] = [],
+ title_format: Mapping[int, str] = {},
+ titles: Mapping[str, Sequence[str]] = {},
+ *,
+ zero_titles: Mapping = {},
+ common_texts: Mapping[str, str] = {}
+ ):
+ self._registry_titles = copy.deepcopy(title_configs)
+ self._registry_titles.update(self.extra_registry_titles)
+ self._registry_edges = copy.deepcopy(title_edges)
+ self._registry_edges.extend(self.extra_register_edges)
+
+ self._dfsindex_to_str_format = copy.deepcopy(title_format)
+ self._dfsindex_to_str_format.update(self.new_dfsindex_to_str_format)
+
+ self._title_prefix = titles.get("title_prefix", [])
+ self._title_suffix = titles.get("title_suffix", [])
+ self._title_middle = titles.get("title_middle", [])
+
+ self._zero_titles = copy.deepcopy(zero_titles) # or ZERO_TITLES_LANGUAGE.get(self.language)
+ self._common_texts = copy.deepcopy(common_texts) or COMMON_TEXT_LANGUAGE.get(self.language)
+
+ @abstractmethod
+ def register_prompt(self, ):
+ """register input/output/prompt into titles and edges"""
+ raise NotImplementedError(
+ f"Prompt Manager Wrapper [{type(self).__name__}]"
+ f" is missing the required `register_prompt`"
+ f" method.",
+ )
+
+ def pre_print(self, **kwargs) -> str:
+ kwargs.update({"is_pre_print": True})
+ prompt = self.generate_prompt(**kwargs)
+ return prompt
+
+ def generate_prompt(self, **kwargs) -> str:
+        '''Generate the prompt; when pre-printing, every section is rendered whether or not it has a value.'''
+ if self.prompt:
+ return self.prompt.format(**self.handler_prompt_values(**kwargs))
+
+ is_pre_print = kwargs.get("is_pre_print", False)
+ # update title's description and function_value
+ title_values = {}
+ for title, title_config in self._registry_titles.items():
+ if hasattr(self, title_config["function"]):
+
+ handler = getattr(self, title_config["function"])
+ function_value = handler(
+ prompt=title_config.get("prompt", ""), title_key=title, **kwargs
+ ) if handler else None
+ else:
+ function_value = title_config["description"]
+
+ title_values[title] = {
+ "description": title_config["description"],
+ "function_value": function_value,
+ "display_type": title_config["display_type"],
+ "str_template": title_config.get("str_template", ""),
+ "prompt": title_config.get("prompt", ""),
+ }
+
+ # transform title values into 'markdown' prompt by title graph
+ prompt_values: List[str] = []
+ prompt_values = self._process_title_values(
+ title_values,
+ title_type="description",
+ prompt_values=prompt_values,
+ is_pre_print=is_pre_print
+ )
+
+ transition_text = self._common_texts["transition_text"]
+ prompt_values.append(self._dfsindex_to_str_format[0].format(transition_text, ""))
+
+ prompt_values = self._process_title_values(
+ title_values,
+ title_type="value",
+ prompt_values=prompt_values,
+ is_pre_print=is_pre_print
+ )
+
+ # logger.info(prompt_values)
+ reponse_text = self._common_texts["reponse_text"]
+ if not any("RESPONSE OUTPUT" in i for i in prompt_values):
+ prompt_values.append(reponse_text)
+ elif not any(["RESPONSE OUTPUT\n" in i for i in prompt_values]):
+ prompt_values.append(self._dfsindex_to_str_format[0].format("RESPONSE OUTPUT", ""))
+        # join the sections, stripping trailing newlines from each
+ prompt_values = [pv.rstrip('\n') for pv in prompt_values]
+ return '\n\n'.join(prompt_values)
+
+ def _process_title_values(
+ self,
+ title_values: Mapping[str, Mapping[str, Any]],
+ title_type: Literal["description", "value"],
+ prompt_values: Sequence[str] = [],
+ is_pre_print=False
+ ):
+ '''process title values to prompt'''
+
+ def append_prompt_dfs(titles: Sequence[str], prompt_values: Sequence=[], dfs_index=0):
+            '''Depth-first walk over the title graph, appending formatted sections.'''
+ if titles == [] or titles is None: return prompt_values
+ for title in titles:
+ title_value = title_values.get(title)
+ ctitles = self.title_graph.get(title, [])
+ ctitle_values = [
+ ctitle
+ for ctitle in ctitles
+ if title_values.get(ctitle, {}).get('function_value')
+ ]
+
+ str_template = title_value.get(
+ "str_template", self._dfsindex_to_str_format[dfs_index]
+ ) or self._dfsindex_to_str_format[dfs_index]
+ description = title_value["description"]
+ function_value = title_value["function_value"]
+ display_type = title_value["display_type"]
+ prompt = title_value["prompt"]
+
+ # logger.info(
+ # f"title={title}, description={description}, function_value={function_value} \n"
+ # f"display_type={display_type}, str_template={str_template} \n"
+ # f"ctitles= {ctitles}, ctitle_values={ctitle_values}"
+ # )
+
+ # todo display_type==only_value
+ if title_type=="description":
+ if display_type == "title":
+ prompt_values.append(str_template.format(title, description or function_value))
+ elif display_type=="description" and function_value:
+ prompt_values.append(str_template.format(title, function_value or description or prompt))
+ elif display_type == "value" and (function_value or len(ctitle_values)>0):
+ prompt_values.append(str_template.format(title, description or function_value))
+ elif display_type == "values" and len(ctitle_values)>0:
+ prompt_values.append(str_template.format(title, description or function_value))
+ elif display_type == "must_value" and (description or function_value or len(ctitle_values)>0):
+ prompt_values.append(str_template.format(title, description or function_value))
+ elif is_pre_print:
+ prompt_values.append(str_template.format(title, description or function_value))
+ elif title_type=="value":
+ if display_type == "values" and len(ctitle_values)>0:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), ""))
+ # must value
+ elif display_type == "must_value" and (function_value and len(ctitle_values)>0):
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value))
+ continue
+ elif display_type == "must_value" and function_value:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value))
+ elif display_type == "must_value" and len(ctitle_values)>0:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), ""))
+ # value
+ elif display_type == "value" and (function_value and len(ctitle_values)>0):
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value))
+ elif display_type == "value" and function_value:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value))
+ elif display_type == "value" and len(ctitle_values)>0:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), ""))
+ elif is_pre_print and display_type not in ["title", "description"]:
+ prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value))
+
+ prompt_values = append_prompt_dfs(ctitles, prompt_values, dfs_index+1)
+
+ return prompt_values
+
+ start_titles = self._title_prefix + self._title_middle + self._title_suffix
+ return append_prompt_dfs(start_titles, prompt_values)
+
+ def parser(self, message: Message) -> Message:
+ '''parse llm output into dict'''
+ return message
+
+ def step_router(
+ self,
+ msg: Message,
+ session_index: str = "",
+ **kwargs
+ ) -> Tuple[Message, ...]:
+ """Route a message to the appropriate step for processing based on its action status.
+
+ Args:
+ msg (Message): The input message that needs processing.
+ session_index (str): The session identifier for managing the conversation.
+ **kwargs: Additional parameters for processing.
+
+ Returns:
+ Tuple[Message, ...]: The processed message and any observation message.
+ """
+ session_index = msg.session_index or session_index or str(uuid.uuid4())
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"message.action_status: {msg.action_status}")
+
+ observation_msg = None
+ # Determine the action to take based on the message's action status
+ if msg.action_status == ActionStatus.CODE_EXECUTING:
+ msg, observation_msg = self.code_step(msg, session_index)
+ elif msg.action_status == ActionStatus.TOOL_USING:
+ msg, observation_msg = self.tool_step(msg, session_index, **kwargs)
+ elif msg.action_status == ActionStatus.CODING2FILE:
+ self.save_code2file(msg, self.workdir_path)
+ # Handle other action statuses as needed (currently no operations for these)
+ elif msg.action_status == ActionStatus.CODE_RETRIEVAL:
+ pass
+ elif msg.action_status == ActionStatus.CODING:
+ pass
+
+ return msg, observation_msg
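+
+    # Routing sketch (illustrative; assumes a parsed Message whose action_status
+    # is TOOL_USING and a registered tool named "metrics_query", both hypothetical):
+    #
+    #     msg, obs = prompt_manager.step_router(
+    #         msg, session_index="demo-session", tools=["metrics_query"],
+    #     )
+    #     # `obs` carries the tool/code observation to feed back into memory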
+
+    def code_step(self, msg: Message, session_index: str) -> Tuple[Message, Message]:
+ """Execute code contained in the message.
+
+ Args:
+ msg (Message): The message containing code to be executed.
+ session_index (str): The session identifier for managing the conversation.
+
+ Returns:
+ Tuple[Message, Message]: The processed message and an observation message regarding code execution.
+ """
+ # Execute the code using the codebox and capture the result
+ code_key = "code_content"
+ code_content = msg.spec_parsed_content.get(code_key, "")
+ code_answer = self.codebox.chat(
+ '```python\n{}```'.format(code_content)
+ )
+
+ # Prepare a response message based on code execution result
+        observation_title = {
+            "error": "The error returned after executing the above code is {code_answer}; it needs to be fixed.\n",
+            "accurate": "The information returned after executing the above code is {code_answer}.\n",
+            "figure": "The figure named {uid} was returned after executing the above code.\n"
+        }
+ code_prompt = (
+ observation_title["error"].format(code_answer=code_answer.code_exe_response)
+ if code_answer.code_exe_type == "error" else
+ observation_title["accurate"].format(code_answer=code_answer.code_exe_response)
+ )
+
+ # Create an observation message for logging code execution outcome
+ observation_msg = Message(
+ session_index=session_index,
+ role_name="function",
+ role_type="observation",
+ input_text=code_content,
+ )
+
+ uid = str(uuid.uuid1()) # Generate a unique identifier for related content
+ if code_answer.code_exe_type == "image/png":
+ # If the code execution produces an image, log the result and update the message
+ msg.global_kwargs[uid] = code_answer.code_exe_response
+            msg.step_content += "\n**Observation:** " + observation_title["figure"].format(uid=uid)
+            msg.parsed_contents.append({"Observation": observation_title["figure"].format(uid=uid)})
+            observation_msg.update_content("\n**Observation:** " + observation_title["figure"].format(uid=uid))
+            observation_msg.update_parsed_content({"Observation": observation_title["figure"].format(uid=uid)})
+        else:
+            # Log the standard execution result
+            msg.step_content += f"\n**Observation:** {code_prompt}\n"
+ observation_msg.update_content(code_prompt)
+ observation_msg.update_parsed_content({"Observation": f"{code_prompt}\n"})
+
+ # Log the observations at the defined verbosity level
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"Code Observation: {msg.action_status}, {observation_msg.content}")
+
+ return msg, observation_msg
+
+ def tool_step(
+ self,
+ msg: Message,
+ session_index: str,
+ **kwargs
+    ) -> Tuple[Message, Message]:
+ """Execute a tool based on parameters in the message.
+
+ Args:
+ msg (Message): The message that specifies the tool to be executed.
+ session_index (str): The session identifier for managing the conversation.
+ **kwargs: Additional parameters for processing, including available tools.
+
+ Returns:
+ Tuple[Message, ...]:
+ The processed message and an observation message regarding the tool execution.
+ """
+        observation_title = {
+            "error": "there is no tool that can be executed.\n",
+            "accurate": "",
+            "figure": "The figure named {uid} was returned after executing the above code.\n"
+        }
+        no_tool_msg = "\n**Observation:** there is no tool that can be executed.\n" # Message for missing tool
+        tool_names = kwargs.get("tools") or []  # Retrieve available tool names
+ extra_params = kwargs.get("extra_params", {})
+ tool_param = msg.spec_parsed_content.get("tool_param", {}) # Parameters for the tool execution
+ tool_param.update(extra_params)
+ tool_name = msg.spec_parsed_content.get("tool_name", "") # Name of the tool to execute
+
+ # Create a message to log the tool execution result
+ observation_msg = Message(
+ session_index=session_index,
+ role_name="function",
+ role_type="observation",
+ input_text=str(tool_param),
+ )
+ if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+ logger.debug(f"message: {msg.action_status}, {tool_param}")
+
+ if tool_name not in tool_names:
+ msg.step_content += f"\n{no_tool_msg}"
+ observation_msg.update_content(no_tool_msg)
+ observation_msg.update_parsed_content({"Observation": no_tool_msg})
+ else:
+ # Execute the specified tool and capture the result
+ tool = get_tool(tool_name)
+ tool_res = tool.run(**tool_param)
+ msg.step_content += f"\n**Observation:** {tool_res}.\n"
+ msg.parsed_contents.append({"Observation": f"{tool_res}.\n"})
+ observation_msg.update_content(f"**Observation:** {tool_res}.\n")
+ observation_msg.update_parsed_content({"Observation": f"{tool_res}.\n"})
+
+ # Log the observations at the defined verbosity level
+ if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+ logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}")
+
+ return msg, observation_msg
+
+ def save_code2file(self, msg: Message, project_dir="./"):
+ """Save the code from the message to a specified file.
+
+ Args:
+ msg (Message): The message containing the code to be saved.
+ project_dir (str): Directory path where the code file will be saved.
+ """
+ filename = msg.parsed_content.get("SaveFileName") # Retrieve filename from message content
+ code = msg.spec_parsed_content.get("code") # Extract code content from the message
+
+ # Replace HTML entities in the code
+        for k, v in {"&gt;": ">", "&ge;": ">=", "&lt;": "<", "&le;": "<="}.items():
+ code = code.replace(k, v)
+
+ project_dir_path = os.path.join(self.workdir_path, project_dir) # Construct project directory path
+ file_path = os.path.join(project_dir_path, filename) # Full path for the output code file
+
+ # Create directories if they don't exist
+ if not os.path.exists(file_path):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+ # Write the code to the file
+ with open(file_path, "w") as f:
+ f.write(code)
+
+ def handler_prompt_values(self, **kwargs) -> Mapping[str, str]:
+ """Handling prompt values from memory, message' global or content
+ or step content or spec parsed content
+ """
+ raise NotImplementedError(
+ f"Prompt Manager Wrapper [{type(self).__name__}]"
+ f" is missing the required `handler_prompt_values`"
+ f" method.",
+ )
+
+ def handle_empty_key(self, **kwargs) -> str:
+ '''return "" '''
+ return ""
+
+ def handler_input_key(self, **kwargs) -> str:
+ '''return {input_template}'''
+ return self.input_template
+
+ def handler_output_key(self, **kwargs) -> str:
+ '''return {output_template}'''
+ return self.output_template
+
\ No newline at end of file
diff --git a/muagent/prompt_manager/common_prompt_manager.py b/muagent/prompt_manager/common_prompt_manager.py
new file mode 100644
index 0000000..feb628f
--- /dev/null
+++ b/muagent/prompt_manager/common_prompt_manager.py
@@ -0,0 +1,320 @@
+from typing import (
+ List,
+ Any,
+ Union,
+ Optional,
+ Literal
+)
+import copy
+from pydantic import BaseModel
+import random
+from textwrap import dedent
+from loguru import logger
+import json
+
+from .base import *
+from .base_prompt_manager import BasePromptManager
+from ..schemas import Memory, Message, PromptConfig
+from ..tools import get_tool, BaseToolModel
+
+from muagent.connector.utils import *
+
+
+class CommonPromptManager(BasePromptManager):
+    """Markdown-style prompt manager."""
+
+ pm_type: str = "CommonPromptManager"
+ """The type of prompt manager."""
+
+ def __init__(
+ self,
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = "",
+ language: Literal["en", "zh"] = "en",
+ *,
+ extra_registry_titles: Dict = {},
+ extra_register_edges: List = [],
+ new_dfsindex_to_str_format: Dict = {},
+ monitored_agents=[],
+ monitored_fields=[],
+ **kwargs
+ ):
+ super().__init__(
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ language=language,
+ monitored_agents=monitored_agents,
+ monitored_fields=monitored_fields,
+ **kwargs
+ )
+
+ # update new titles
+ self.extra_registry_titles: Dict = extra_registry_titles
+ self.extra_register_edges: List = extra_register_edges
+ self.new_dfsindex_to_str_format: Dict = new_dfsindex_to_str_format
+
+ #
+ self.register_graph(
+ TITLE_CONFIGS_LANGUAGE[self.language],
+ TITLE_EDGES_LANGUAGE[self.language],
+ TITLE_FORMAT_LANGUAGE[self.language],
+ titles=TITLE_LANGUAGE[self.language],
+ zero_titles=ZERO_TITLES_LANGUAGE[self.language],
+ common_texts=COMMON_TEXT_LANGUAGE[self.language],
+ )
+
+ def register_prompt(self, ):
+ """register input/output/prompt into titles and edges"""
+ input_str, output_str = "", ""
+ input_values, output_values = {}, {}
+
+ if self.system_prompt:
+ input_str = extract_section(
+ self.system_prompt,
+ self._zero_titles["input"]
+ )
+ output_str = extract_section(
+ self.system_prompt,
+ self._zero_titles["output"]
+ )
+
+ input_values = parse_section_to_dict(
+ self.system_prompt,
+ self._zero_titles["input"]
+ )
+ output_values = parse_section_to_dict(
+ self.system_prompt,
+ self._zero_titles["output"]
+ )
+ self.system_prompt = extract_section(
+ self.system_prompt,
+ self._zero_titles["agent"]
+ ) or self.system_prompt
+
+ if self.input_template:
+ input_values = parse_section_to_dict(
+ self.input_template,
+ self._zero_titles["input"]
+ ) or input_values
+
+ self.input_template = extract_section(
+ self.input_template,
+ self._zero_titles["input"]
+ ) or input_str
+
+ if self.output_template:
+ output_values = parse_section_to_dict(
+ self.output_template,
+ self._zero_titles["output"]
+ ) or output_values
+ self.output_template = extract_section(
+ self.output_template,
+ self._zero_titles["output"]
+ ) or output_str
+ #
+ self._registry_titles[self._zero_titles["input"]].update({
+ "description": self.input_template or input_str,
+ })
+
+ self._registry_titles[self._zero_titles["output"]].update({
+ "description": self.output_template or output_str,
+ })
+ self._registry_titles.update(
+ {k: {
+ "description": v,
+ "function": "handle_custom_data",
+ "display_type": "value",
+ "str_template": "**{}:** {}",
+ }
+ for k,v in (input_values|output_values).items()}
+ )
+ self._registry_edges.extend(
+ [(self._zero_titles["output"], k) for k in input_values.keys()]
+ )
+ self._registry_edges.extend(
+ [(self._zero_titles["output"], k) for k in output_values.keys()]
+ )
+
+ def pre_print(self, **kwargs):
+ kwargs.update({"is_pre_print": True})
+ prompt = self.generate_prompt(**kwargs)
+
+        output_keys = parse_section(self.system_prompt, self._zero_titles["output"])
+        llm_predict = "\n".join([f"**{k}:**" for k in output_keys])
+ return_prompt = (
+ f"{prompt}\n\n"
+ f"{'#'*19}"
+ "\n<<<>>>\n"
+ f"{'#'*19}"
+ f"\n\n{llm_predict}\n"
+ )
+ return return_prompt
+
+ def parser(self, message: Message) -> Message:
+ '''parse llm output into dict'''
+ content = message.content
+ # parse start
+ parsed_dict = parse_text_to_dict(content)
+ spec_parsed_dict = parse_dict_to_dict(parsed_dict)
+ # select parse value
+ action_value = parsed_dict.get('Action Status')
+ if action_value:
+ action_value = action_value.lower()
+
+ code_content_value = spec_parsed_dict.get('python') or \
+ spec_parsed_dict.get('java')
+ if action_value == 'tool_using':
+            tool_params_value = spec_parsed_dict.get('json') or {}
+ else:
+ tool_params_value = {}
+
+ # add parse value to message
+ message.action_status = action_value or "default"
+ spec_parsed_dict["code_content"] = code_content_value
+ spec_parsed_dict["tool_param"] = tool_params_value.get("tool_params")
+ spec_parsed_dict["tool_name"] = tool_params_value.get("tool_name")
+ #
+ message.update_parsed_content(parsed_dict)
+ message.update_spec_parsed_content(spec_parsed_dict)
+ return message
+
+ def handler_prompt_values(self, **kwargs) -> Dict[str, str]:
+ memory: Memory = kwargs.get("memory", None)
+ query: Message = kwargs.get("query", None)
+        result = {
+            "query": (query.content or query.input_text) if query else "",
+            "memory": memory.to_format_messages(format_type="str") if memory else ""
+        }
+ return result
+
+ def handle_custom_data(self, **kwargs):
+ '''get key-value from parsed_output_list or global_kargs'''
+ key: str = kwargs.get("title_key", "")
+ query: Message = kwargs.get('query')
+
+ keys = [
+ "_".join([i.title() for i in key.split(" ")]),
+ " ".join([i.title() for i in key.split("_")]),
+ key
+ ]
+ keys = list(set(keys))
+
+ content = ""
+ for key in keys:
+ if key in query.spec_parsed_content:
+ content = query.spec_parsed_content.get(key)
+ content = "\n".join(content) if isinstance(content, list) else content
+ break
+ if key in query.global_kwargs:
+ content = query.global_kwargs.get(key)
+ content = "\n".join(content) if isinstance(content, list) else content
+ break
+
+ return content
+
+ def handle_tool_data(self, **kwargs):
+ if 'tools' not in kwargs: return ""
+
+ tools: List = kwargs.get('tools')
+ prompt: str = kwargs.get('prompt')
+ tools: List[BaseToolModel] = [get_tool(tool) for tool in tools if isinstance(tool, str)]
+
+ if len(tools) == 0: return ""
+
+ tool_strings = []
+ for tool in tools:
+ args_str = f'args: {str(tool.intput_to_json_schema())}' if tool.ToolInputArgs else ""
+ tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
+ formatted_tools = "\n".join(tool_strings)
+
+ tool_names = ", ".join([tool.name for tool in tools])
+
+ tool_prompt = dedent(prompt.format(formatted_tools=formatted_tools, tool_names=tool_names))
+ while "\n " in tool_prompt:
+ tool_prompt = tool_prompt.replace("\n ", "\n")
+
+ return tool_prompt
+
+ def handle_agent_data(self, **kwargs):
+        """Format the managed agents into an agent-selection prompt."""
+ if 'agent_names' not in kwargs or "agent_descs" not in kwargs:
+ return ""
+
+ agent_names: List = kwargs.get('agent_names')
+ agent_descs: List = kwargs.get('agent_descs')
+ prompt: str = kwargs.get('prompt')
+
+ if len(agent_names) == 0: return ""
+
+        # shuffle names and descriptions together so the pairing stays aligned
+        paired = list(zip(agent_names, agent_descs))
+        random.shuffle(paired)
+        agent_descriptions = []
+        for agent_name, desc in paired:
+ while "\n\n" in desc:
+ desc = desc.replace("\n\n", "\n")
+ desc = desc.replace("\n", ",")
+ agent_descriptions.append(
+ f'"role name: {agent_name}\nrole description: {desc}"'
+ )
+
+ agent_description = "\n".join(agent_descriptions)
+ agent_prompt = dedent(
+ prompt.format(agents=agent_description, agent_names=agent_names)
+ )
+
+ while "\n " in agent_prompt:
+ agent_prompt = agent_prompt.replace("\n ", "\n")
+
+ return agent_prompt
+
+ def handle_current_query(self, **kwargs) -> str:
+        """Return the current query's input text."""
+ query: Message = kwargs.get('query')
+ if query:
+ return query.input_text
+ return ""
+
+ def handle_session_records(self, **kwargs) -> str:
+
+ memory: Memory = kwargs.get('memory', Memory(messages=[]))
+ return memory.to_format_messages(
+ content_key='parsed_contents',
+ format_type='str',
+ with_tag=True
+ )
+
+ def handle_agent_profile(self, **kwargs) -> str:
+ return extract_section(self.system_prompt, 'AGENT PROFILE') or self.system_prompt
+
+ def handle_output_format(self, **kwargs) -> str:
+ return extract_section(self.system_prompt, self._zero_titles["output"])
+
+ def handle_react_memory(self, **kwargs) -> str:
+ react_memory: Memory = kwargs.get('react_memory')
+
+ if react_memory:
+ return react_memory.to_format_messages(format_type="str")
+ return ""
+
+ def handle_task_memory(self, **kwargs) -> str:
+ if 'task_memory' not in kwargs:
+ return ""
+
+ task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
+ if task_memory is None:
+ return ""
+
+ return "\n".join([
+ "\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]])
+ for _dict in task_memory.get_memory_values("parsed_content")
+ ])
+
+ def handle_current_plan(self, **kwargs) -> str:
+ if 'query' not in kwargs:
+ return ""
+ query: Message = kwargs['query']
+ return query.global_kwargs.get("CURRENT_STEP", "")
+
\ No newline at end of file
diff --git a/muagent/prompt_manager/language/en.py b/muagent/prompt_manager/language/en.py
new file mode 100644
index 0000000..7fca273
--- /dev/null
+++ b/muagent/prompt_manager/language/en.py
@@ -0,0 +1,89 @@
+EN_TITLE_EDGES = [
+ ("AGENT PROFILE", "ROLE"),
+ ("AGENT PROFILE", "AGENT INFORMATION"),
+ ("AGENT PROFILE", "TOOL INFORMATION"),
+ ("CONTEXT FORMAT", "SESSION RECORDS"),
+ ("CONTEXT FORMAT", "CURRENT QUERY"),
+]
+
+EN_TITLE_CONFIGS = {
+ "AGENT PROFILE": {
+ "description": "",
+ "function": "handle_empty_key",
+ "display_type": "title"
+ },
+ "CONTEXT FORMAT": {
+ "description": "Use the content provided in the context.",
+ "function": "handle_empty_key",
+ "display_type": "values"
+ },
+ "INPUT FORMAT": {
+ "description": "",
+ "function": "handle_empty_key",
+ "display_type": "values"
+ },
+ "RESPONSE OUTPUT FORMAT": {
+ "description": "",
+ "function": "handle_react_memory",
+ "display_type": "must_value"
+ },
+ "ROLE": {
+ "description": "",
+ "prompt": "",
+ "function": "handle_agent_profile",
+ "display_type": "description"
+ },
+ "TOOL INFORMATION": {
+ "description": "",
+        "prompt": """Below is a list of tools that are available for your use:{formatted_tools}\nValid "tool_name" values are:\n{tool_names}""",
+ "function": "handle_tool_data",
+ "display_type": "description"
+ },
+ "AGENT INFORMATION": {
+ "description": "",
+        "prompt": '''Please ensure your selection is one of the listed roles. Available roles for selection:\n{agents}\nPlease ensure you select the Role from the agent names, such as {agent_names}''',
+ "function": "handle_agent_data",
+ "display_type": "description"
+ },
+ "SESSION RECORDS": {
+        "description": "In this part, we will supply the context for this question.",
+ "function": "handle_session_records",
+ "display_type": "value"
+ },
+ "CURRENT QUERY": {
+        "description": "In this part, we will supply the current question to be addressed.",
+ "function": "handle_current_query",
+ "display_type": "value"
+ },
+}
+
+
+
+
+EN_TITLE_FORMAT = {
+ 0: "#### {}\n{}",
+ 1: "### {}\n{}",
+ 2: "## {}\n{}",
+ 3: "# {}\n{}",
+}
+
+
+EN_ZERO_TITLES = {
+ "agent": "AGENT PROFILE",
+ "context": "CONTEXT FORMAT",
+ "input": "INPUT FORMAT",
+ "output": "RESPONSE OUTPUT FORMAT"
+}
+
+
+EN_TITLES = {
+ "title_prefix": [EN_ZERO_TITLES["agent"], EN_ZERO_TITLES["context"]],
+ "title_suffix": [EN_ZERO_TITLES["input"], EN_ZERO_TITLES["output"]],
+ "title_middle": [],
+}
+
+
+EN_COMMON_TEXT = {
+ "transition_text": "BEGIN!!!",
+    "reponse_text": "Please respond:"
+}
diff --git a/muagent/prompt_manager/language/zh.py b/muagent/prompt_manager/language/zh.py
new file mode 100644
index 0000000..cfaf427
--- /dev/null
+++ b/muagent/prompt_manager/language/zh.py
@@ -0,0 +1,87 @@
+ZH_TITLE_EDGES = [
+ ("智能体配置", "角色"),
+ ("智能体配置", "智能体信息"),
+ ("智能体配置", "工具信息"),
+ ("上下文", "会话记录"),
+ ("上下文", "当前问题"),
+]
+
+ZH_TITLE_CONFIGS = {
+ "智能体配置": {
+ "description": "",
+ "function": "handle_empty_key",
+ "display_type": "title"
+ },
+ "上下文": {
+ "description": "使用下面内容作为上下文的信息。",
+ "function": "handle_empty_key",
+ "display_type": "values"
+ },
+ "输入": {
+ "description": "",
+ "function": "handle_empty_key",
+ "display_type": "values"
+ },
+ "输出": {
+ "description": "",
+ "function": "handle_react_memory",
+ "display_type": "must_value"
+ },
+ "角色": {
+ "description": "",
+ "prompt": "",
+ "function": "handle_agent_profile",
+ "display_type": "description"
+ },
+ "工具信息": {
+ "description": "",
+ "prompt": """以下是您可以使用的工具列表:{formatted_tools}\n有效的 "tool_name" 值是:\n{tool_names}""",
+ "function": "handle_tool_data",
+ "display_type": "description"
+ },
+ "智能体信息": {
+ "description": "",
+ "prompt": '''请确保您的选择是列出的角色之一。可供选择的角色有:\n{agents}请确保从代理名称中选择角色,例如 {agent_names}''',
+ "function": "handle_agent_data",
+ "display_type": "description"
+ },
+ "会话记录": {
+ "description": "在这个部分,我们将提供有关这个问题的上下文。",
+ "function": "handle_session_records",
+ "display_type": "value"
+ },
+ "当前问题": {
+ "description": "在这个部分,我们将提供当前需要处理的问题。",
+ "function": "handle_current_query",
+ "display_type": "value"
+ },
+}
+
+
+
+
+ZH_TITLE_FORMAT = {
+ 0: "#### {}\n{}",
+ 1: "### {}\n{}",
+ 2: "## {}\n{}",
+ 3: "# {}\n{}",
+}
+
+ZH_ZERO_TITLES = {
+ "agent": "智能体配置",
+ "context": "上下文",
+ "input": "输入",
+ "output": "输出"
+}
+
+ZH_TITLES = {
+ "title_prefix": [ZH_ZERO_TITLES["agent"], ZH_ZERO_TITLES["context"]],
+ "title_suffix": [ZH_ZERO_TITLES["input"], ZH_ZERO_TITLES["output"]],
+ "title_middle": [],
+}
+
+
+ZH_COMMON_TEXT = {
+ "transition_text": "开始",
+ "reponse_text": "请回答:"
+}
diff --git a/muagent/prompt_manager/util.py b/muagent/prompt_manager/util.py
new file mode 100644
index 0000000..0885b1d
--- /dev/null
+++ b/muagent/prompt_manager/util.py
@@ -0,0 +1,94 @@
+from collections import defaultdict
+
+
+
+class GraphCycleError(Exception):
+ """Custom exception for graph cycle detection."""
+ pass
+
+
+
+def edges_to_graph_with_cycle_detection(intervals):
+    """Converts a list of directed edges into an adjacency-list graph and checks for cycles.
+
+    Args:
+        intervals (list of tuple): List of directed edges, each defined by a (start, end) pair.
+
+    Returns:
+        tuple: A tuple containing a list of start nodes (nodes with indegree of 0) and the constructed graph.
+
+ Raises:
+ GraphCycleError: If the graph contains a cycle.
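+
+    Example (illustrative):
+        >>> starts, graph = edges_to_graph_with_cycle_detection([("A", "B"), ("B", "C")])
+        >>> starts
+        ['A']
+        >>> dict(graph)
+        {'A': ['B'], 'B': ['C']}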
+ """
+
+ graph = defaultdict(list) # Adjacency list for the graph
+ indegree = defaultdict(int) # Count of incoming edges for each node
+
+ # Build the graph and the indegree table
+ for start, end in intervals:
+ graph[start].append(end) # Add directed edge from start to end
+ indegree[end] += 1 # Increment indegree of end node
+ # Ensure every node is in the graph (even nodes without outgoing edges)
+ if start not in indegree:
+ indegree[start] = 0 # Initialize indegree for start node
+
+ # Find all starting nodes (indegree of 0)
+ start_nodes = [node for node in indegree if indegree[node] == 0]
+
+ # Detect cycle in the graph
+ if detect_cycle(graph):
+ raise GraphCycleError("Graph contains a cycle!") # Raise error if cycle is found
+
+ return start_nodes, graph
+
+
+
+def detect_cycle(graph):
+ """Detects if a directed graph contains a cycle using DFS.
+
+ Args:
+ graph (dict): The adjacency list of the graph.
+
+ Returns:
+ bool: True if a cycle is detected, False otherwise.
+ """
+
+ visited = set() # To keep track of visited nodes
+ rec_stack = set() # To keep track of nodes currently in the recursion stack
+
+ def dfs(node):
+ """Performs a DFS on the graph to detect cycles.
+
+ Args:
+ node: Current node being visited.
+
+ Returns:
+ bool: True if a cycle is detected.
+ """
+ # If node is in recursion stack, a cycle is found
+ if node in rec_stack:
+ return True
+ # If node is already visited, no need to check it again
+ if node in visited:
+ return False
+
+ # Mark the current node as visited and add to recursion stack
+ visited.add(node)
+ rec_stack.add(node)
+
+ # Use list() to copy neighbors to avoid modifying while iterating
+ for neighbor in list(graph[node]):
+ if dfs(neighbor): # Recursive call for each neighbor
+ return True # Cycle detected in the neighbor
+
+ # Remove the node from the recursion stack after visiting
+ rec_stack.remove(node)
+ return False # No cycle detected in this path
+
+ # Iterate over each node in the graph to detect cycles
+ for node in list(graph.keys()):
+ if node not in visited: # Proceed if the node hasn't been visited yet
+ if dfs(node): # Start DFS
+ return True # Cycle found
+
+ return False # No cycles found in the graph
\ No newline at end of file
diff --git a/muagent/sandbox/__init__.py b/muagent/sandbox/__init__.py
index 435da9b..ac51b6b 100644
--- a/muagent/sandbox/__init__.py
+++ b/muagent/sandbox/__init__.py
@@ -1,6 +1,7 @@
from .basebox import CodeBoxResponse
from .pycodebox import PyCodeBox
+from .nbclient import NBClientBox, NoteBookExecutor
__all__ = [
- "CodeBoxResponse", "PyCodeBox"
+    "CodeBoxResponse", "PyCodeBox", "NBClientBox", "NoteBookExecutor"
]
\ No newline at end of file
diff --git a/muagent/sandbox/nbclient.py b/muagent/sandbox/nbclient.py
new file mode 100644
index 0000000..c52ce01
--- /dev/null
+++ b/muagent/sandbox/nbclient.py
@@ -0,0 +1,297 @@
+"""Service for executing jupyter notebooks interactively
+Partially referenced the implementation of
+https://github.com/modelscope/agentscope/blob/main/src/agentscope/service/execute_code/exec_notebook.py
+"""
+import os
+import re
+import asyncio
+import base64
+from typing import List, Optional
+from loguru import logger
+
+try:
+    import nbclient
+    import nbformat
+except ImportError:
+    nbclient = None
+    nbformat = None
+
+from ..base_configs.env_config import KB_ROOT_PATH
+from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus
+
+
+
+class NoteBookExecutor:
+ """
+    Class for executing jupyter notebook code blocks interactively.
+    To use it, first initialize the class, then call the
+    run_code_on_notebook method.
+
+    Example (illustrative):
+
+        ```python
+        from muagent.sandbox.nbclient import NoteBookExecutor
+
+        nbe = NoteBookExecutor()
+        outputs = nbe.run_code_on_notebook("print('helloworld')")
+        # `outputs` holds the raw cell outputs, e.g. a stream output
+        # whose "text" field is "helloworld\n".
+        ```
+    """  # noqa
+
+ def __init__(
+ self,
+ timeout: int = 300,
+ work_path: str = KB_ROOT_PATH,
+ ) -> None:
+ """
+ The construct function of the NoteBookExecutor.
+        Args:
+            timeout (`int`, optional):
+                The timeout in seconds for each cell execution.
+                Defaults to 300.
+            work_path (`str`, optional):
+                The working directory in which the notebook kernel starts.
+                Defaults to KB_ROOT_PATH.
+ """
+
+ if nbclient is None or nbformat is None:
+ raise ImportError(
+ "The package nbclient or nbformat is not found. Please "
+ "install it by `pip install notebook nbclient nbformat`",
+ )
+
+ self.nb = nbformat.v4.new_notebook()
+ self.nb_client = nbclient.NotebookClient(nb=self.nb)
+ self.work_path = work_path
+ self.ori_path = os.getcwd()
+ self.timeout = timeout
+
+ asyncio.run(self._start_client())
+
+ def _output_parser(self, output: dict) -> str:
+ """Parse the output of the notebook cell and return str"""
+ if output["output_type"] == "stream":
+ return output["text"]
+ elif output["output_type"] == "execute_result":
+ return output["data"]["text/plain"]
+ elif output["output_type"] == "display_data":
+ if "image/png" in output["data"]:
+ file_path = self._save_image(output["data"]["image/png"])
+ return f"Displayed image saved to {file_path}"
+ else:
+ return "Unsupported display type"
+ elif output["output_type"] == "error":
+ return output["traceback"]
+ else:
+ logger.info(f"Unsupported output encountered: {output}")
+ return "Unsupported output encountered"
+
+ async def _start_client(self) -> None:
+ """start notebook client"""
+ if self.nb_client.kc is None or not await self.nb_client.kc.is_alive():
+ os.chdir(self.work_path)
+ self.nb_client.create_kernel_manager()
+ self.nb_client.start_new_kernel()
+ self.nb_client.start_new_kernel_client()
+ os.chdir(self.ori_path)
+
+ async def _kill_client(self) -> None:
+ """kill notebook client"""
+ if (
+ self.nb_client.km is not None
+ and await self.nb_client.km.is_alive()
+ ):
+ await self.nb_client.km.shutdown_kernel(now=True)
+ await self.nb_client.km.cleanup_resources()
+
+ self.nb_client.kc.stop_channels()
+ self.nb_client.kc = None
+ self.nb_client.km = None
+
+ async def _restart_client(self) -> None:
+ """Restart the notebook client"""
+ await self._kill_client()
+ self.nb_client = nbclient.NotebookClient(self.nb, timeout=self.timeout)
+ await self._start_client()
+
+ async def _run_cell(self, cell_index: int):
+ """Run a cell in the notebook by its index"""
+ try:
+ self.nb_client.execute_cell(self.nb.cells[cell_index], cell_index)
+ return self.nb.cells[cell_index].outputs
+ except nbclient.exceptions.DeadKernelError:
+ await self.reset_notebook()
+ return "DeadKernelError when executing cell, reset kernel"
+ except nbclient.exceptions.CellTimeoutError:
+ assert self.nb_client.km is not None
+ await self.nb_client.km.interrupt_kernel()
+ return (
+ "CellTimeoutError when executing cell"
+ ", code execution timeout"
+ )
+ except Exception as e:
+ return str(e)
+
+ @property
+ def cells_length(self) -> int:
+ """return cell length"""
+ return len(self.nb.cells)
+
+ async def async_run_code_on_notebook(self, code: str):
+ """
+ Run the code on interactive notebook
+ """
+ self.nb.cells.append(nbformat.v4.new_code_cell(code))
+ cell_index = self.cells_length - 1
+ return await self._run_cell(cell_index)
+
+ def run_code_on_notebook(self, code: str):
+ """
+ Run the code on interactive jupyter notebook.
+
+ Args:
+ code (`str`):
+ The Python code to be executed in the interactive notebook.
+
+        Returns:
+            The raw outputs of the executed cell (a list of output dicts),
+            or an error string if execution failed.
+ """
+ return asyncio.run(self.async_run_code_on_notebook(code))
+
+ def reset_notebook(self) -> str:
+ """
+ Reset the notebook
+ """
+ asyncio.run(self._restart_client())
+ return "Reset notebook"
+
+
+
+
+
+class NBClientBox(BaseBox):
+
+ enter_status: bool = False
+
+ def __init__(
+ self,
+ do_code_exe: bool = False,
+ work_path: str = KB_ROOT_PATH,
+ ):
+ self.nbe = NoteBookExecutor(work_path=work_path)
+ self.do_code_exe = do_code_exe
+
+ def decode_code_from_text(self, text: str) -> str:
+ pattern = r'```.*?```'
+ code_blocks = re.findall(pattern, text, re.DOTALL)
+ code_text: str = "\n".join([block.strip('`') for block in code_blocks])
+ code_text = code_text[6:] if code_text.startswith("python") else code_text
+ code_text = code_text.replace("python\n", "").replace("code", "")
+ return code_text
+
+ def run(
+ self, code_text: Optional[str] = None,
+ file_path: Optional[os.PathLike] = None,
+ retry = 3,
+ ) -> CodeBoxResponse:
+ if not code_text and not file_path:
+ return CodeBoxResponse(
+                code_exe_response="Code or file_path must be specified!",
+ code_text=code_text,
+ code_exe_type="text",
+ code_exe_status=502,
+ do_code_exe=self.do_code_exe,
+ )
+
+ if code_text and file_path:
+ return CodeBoxResponse(
+                code_exe_response="Can only specify code or the file to read from!",
+ code_text=code_text,
+ code_exe_type="text",
+ code_exe_status=502,
+ do_code_exe=self.do_code_exe,
+ )
+
+ if file_path:
+ with open(file_path, "r", encoding="utf-8") as f:
+ code_text = f.read()
+
+
+ def _output_parser(output: dict) -> str:
+ """Parse the output of the notebook cell and return str"""
+ if output["output_type"] == "stream":
+ return CodeBoxResponse(
+ code_exe_type="text",
+ code_text=code_text,
+ code_exe_response=output["text"] or "Code run successfully (no output)",
+ code_exe_status=200,
+ do_code_exe=self.do_code_exe
+ )
+ elif output["output_type"] == "execute_result":
+ return CodeBoxResponse(
+ code_exe_type="text",
+ code_text=code_text,
+ code_exe_response=output["data"]["text/plain"] or "Code run successfully (no output)",
+ code_exe_status=200,
+ do_code_exe=self.do_code_exe
+ )
+ elif output["output_type"] == "display_data":
+ if "image/png" in output["data"]:
+ return CodeBoxResponse(
+ code_exe_type="image/png",
+ code_text=code_text,
+ code_exe_response=output["data"]["image/png"],
+ code_exe_status=200,
+ do_code_exe=self.do_code_exe
+ )
+ else:
+ return CodeBoxResponse(
+ code_exe_type="error",
+ code_text=code_text,
+ code_exe_response="Unsupported display type",
+ code_exe_status=420,
+ do_code_exe=self.do_code_exe
+ )
+ elif output["output_type"] == "error":
+ return CodeBoxResponse(
+ code_exe_type="error",
+ code_text=code_text,
+                    code_exe_response="\n".join(output.get("traceback", ["error"])),
+ code_exe_status=500,
+ do_code_exe=self.do_code_exe
+ )
+ else:
+ return CodeBoxResponse(
+ code_exe_type="error",
+ code_text=code_text,
+ code_exe_response=f"Unsupported output encountered: {output}",
+ code_exe_status=420,
+ do_code_exe=self.do_code_exe
+ )
+
+        contents = self.nbe.run_code_on_notebook(code_text)
+        if isinstance(contents, str):  # kernel/timeout error message
+            contents = [{"output_type": "stream", "text": contents}]
+        content = contents[0] if contents else {"output_type": "stream", "text": ""}
+        return _output_parser(content)
+
+ def restart(self, ) -> CodeBoxStatus:
+        return CodeBoxStatus(status="restarted")
+
+ def stop(self, ) -> CodeBoxStatus:
+ pass
+
+ def __del__(self):
+ self.stop()
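+
+
+# Illustrative usage (sketch):
+# box = NBClientBox(do_code_exe=True)
+# resp = box.run(code_text="print('hello')")
+# print(resp.code_exe_response)  # -> "hello\n"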
\ No newline at end of file
diff --git a/muagent/schemas/__init__.py b/muagent/schemas/__init__.py
index e69de29..a43edef 100644
--- a/muagent/schemas/__init__.py
+++ b/muagent/schemas/__init__.py
@@ -0,0 +1,12 @@
+from .message import Message
+from .memory import Memory
+from .agent_config import PromptConfig, AgentConfig
+from .project_config import ProjectConfig, EKGProjectConfig
+from .models import LLMConfig, ModelConfig
+
+
+__all__ = [
+ "Message", "Memory",
+ "PromptConfig", "AgentConfig", "LLMConfig", "ModelConfig",
+ "EKGProjectConfig", "ProjectConfig",
+]
diff --git a/muagent/schemas/agent_config.py b/muagent/schemas/agent_config.py
new file mode 100644
index 0000000..fe1763d
--- /dev/null
+++ b/muagent/schemas/agent_config.py
@@ -0,0 +1,68 @@
+
+from pydantic import BaseModel, root_validator
+from typing import List, Dict, Optional, Union, Literal
+
+
+class PromptConfig(BaseModel):
+ """The dataclass for prompt config."""
+
+ config_name: str = "codefuse"
+ """The config name of prompt."""
+
+ prompt_manager_type: str = "CommonPromptManager"
+ """The type of prompt manager."""
+
+ language: Literal['en', 'zh'] = 'en'
+ """The language of prompt manager."""
+
+
+class AgentConfig(BaseModel):
+ """The dataclass for agent config"""
+
+ config_name: str
+    """The name of the agent configuration. It defaults to agent_name when not provided."""
+
+ agent_type: str
+    """The type of the agent wrapper, which is used to identify the agent
+    wrapper class in the agent configuration."""
+
+ agent_name: str
+    """The name of the agent, which is used in agent api calling. It equals the role name."""
+
+ agent_desc: str = ""
+ """The role description of this role."""
+
+ system_prompt: str = ""
+ """The system prompt of this role."""
+
+ input_template: Union[str, BaseModel] = ""
+ """The input template for role."""
+
+ output_template: Union[str, BaseModel] = ""
+ """The output template for role."""
+
+ prompt: str = ""
+    """The full prompt of this role. It overrides system prompt + input template + output template."""
+
+ tools: List[str] = []
+    """The tool names available to this role. It uses these tools to complete tasks."""
+
+ agents: List[str] = []
+    """This role can manage several agents and will delegate the task to one of them."""
+
+ #
+ llm_config_name: Optional[str]
+ """The name of the llm model configuration."""
+
+ em_config_name: Optional[str]
+ """The name of the embedding model configuration."""
+
+ prompt_config_name: Optional[str]
+    """The name of the prompt configuration."""
+
+ @root_validator(pre=True)
+ def set_default_config_name(cls, values):
+ """Set config_name to model_name if config_name is not provided."""
+ if 'config_name' not in values or values['config_name'] is None:
+ values['config_name'] = values.get('agent_name')
+ return values
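+
+
+# Illustrative usage (sketch; the names below are examples only):
+# config = AgentConfig(
+#     agent_type="BaseAgent",
+#     agent_name="codefuse_simpler",
+#     system_prompt="you are a helpful assistant!\n",
+#     llm_config_name="qwen_chat",
+#     em_config_name=None,
+#     prompt_config_name="codefuse",
+# )
+# assert config.config_name == "codefuse_simpler"  # falls back to agent_name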
\ No newline at end of file
diff --git a/muagent/schemas/apis/ekg_api_schema.py b/muagent/schemas/apis/ekg_api_schema.py
index 5c677df..f4d6eef 100644
--- a/muagent/schemas/apis/ekg_api_schema.py
+++ b/muagent/schemas/apis/ekg_api_schema.py
@@ -1,8 +1,9 @@
from pydantic import BaseModel
-from typing import List, Dict, Optional, Literal
+from typing import List, Dict, Optional, Literal, Union
from enum import Enum
from muagent.schemas.common import GNode, GEdge
+from muagent.schemas.models import ChatMessage, Choice
@@ -41,6 +42,20 @@ class LLMRequest(BaseModel):
text: str
stop: Optional[str]
+
+class LLMFCRequest(BaseModel):
+ messages: List[ChatMessage]
+ system_prompt: Optional[str] = None
+ tools: List[Union[str, object]] = []
+ tool_choice: Optional[Literal["auto", "required"]] = "auto"
+ parallel_tool_calls: bool = False
+ stop: Optional[str]
+
+
+class LLMFCResponse(EKGResponse):
+ choices: List[Choice]
+
+
class LLMResponse(EKGResponse):
successCode: int
errorMessage: str
@@ -123,7 +138,7 @@ class SearchAncestorRequest(BaseModel):
class LLMParamsResponse(BaseModel):
url: Optional[str] = None
model_name: str
- model_type: Literal["openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen"] = "ollama"
+ model_type: str = "ollama"
api_key: str = ""
stop: Optional[str] = None
temperature: float = 0.3
@@ -137,7 +152,7 @@ class LLMParamsRequest(LLMParamsResponse):
class EmbeddingsParamsResponse(BaseModel):
# ollama embeddings
url: Optional[str] = None
- embedding_type: Literal["openai", "ollama"] = "ollama"
+ embedding_type: str = "ollama"
model_name: str = "qwen2.5:0.5b"
api_key: str = ""
diff --git a/muagent/schemas/common/__init__.py b/muagent/schemas/common/__init__.py
index d47cbe9..d633342 100644
--- a/muagent/schemas/common/__init__.py
+++ b/muagent/schemas/common/__init__.py
@@ -1,8 +1,12 @@
from .auto_extract_graph_schema import *
-
+from .actions import *
+from .log import LogVerboseEnum
__all__ = [
"GNodeAbs", "GEdgeAbs", "GRelationAbs", "Attribute",
"GNode", "GEdge", "Graph", "GEdgeRequst", "GNodeRequest", "GRelation",
- "ThemeEnums", "GbaseExecStatus"
+ "ThemeEnums", "GbaseExecStatus",
+
+ "ActionStatus",
+ "LogVerboseEnum",
]
\ No newline at end of file
diff --git a/muagent/schemas/common/actions.py b/muagent/schemas/common/actions.py
new file mode 100644
index 0000000..60d4bb6
--- /dev/null
+++ b/muagent/schemas/common/actions.py
@@ -0,0 +1,76 @@
+from pydantic import BaseModel
+from enum import Enum
+
+
+
+class ActionStatus(Enum):
+ DEFAUILT = "default"
+
+ FINISHED = "finished"
+ STOPPED = "stopped"
+ CONTINUED = "continued"
+
+ TOOL_USING = "tool_using"
+ CODING = "coding"
+ CODE_EXECUTING = "code_executing"
+ CODING2FILE = "coding2file"
+
+ PLANNING = "planning"
+ UNCHANGED = "unchanged"
+ ADJUSTED = "adjusted"
+ CODE_RETRIEVAL = "code_retrieval"
+
+ def __eq__(self, other):
+ if isinstance(other, str):
+ return self.value.lower() == other.lower()
+ return super().__eq__(other)
+
+
+class Action(BaseModel):
+ action_name: str
+ description: str
+
+class FinishedAction(Action):
+ action_name: str = ActionStatus.FINISHED
+ description: str = "provide the final answer to the original query to break the chain answer"
+
+class StoppedAction(Action):
+ action_name: str = ActionStatus.STOPPED
+ description: str = "provide the final answer to the original query to break the agent answer"
+
+class ContinuedAction(Action):
+ action_name: str = ActionStatus.CONTINUED
+    description: str = "can't provide the final answer to the original query"
+
+class ToolUsingAction(Action):
+ action_name: str = ActionStatus.TOOL_USING
+ description: str = "proceed with using the specified tool."
+
+class CodingdAction(Action):
+ action_name: str = ActionStatus.CODING
+ description: str = "provide the answer by writing code"
+
+class Coding2FileAction(Action):
+ action_name: str = ActionStatus.CODING2FILE
+ description: str = "provide the answer by writing code and filename"
+
+class CodeExecutingAction(Action):
+ action_name: str = ActionStatus.CODE_EXECUTING
+ description: str = "provide the answer by writing executable code"
+
+class PlanningAction(Action):
+ action_name: str = ActionStatus.PLANNING
+ description: str = "provide a sequence of tasks"
+
+class UnchangedAction(Action):
+ action_name: str = ActionStatus.UNCHANGED
+ description: str = "this PLAN has no problem, just set PLAN_STEP to CURRENT_STEP+1."
+
+class AdjustedAction(Action):
+ action_name: str = ActionStatus.ADJUSTED
+ description: str = "the PLAN is to provide an optimized version of the original plan."
+
+# extended action example
+class CodeRetrievalAction(Action):
+ action_name: str = ActionStatus.CODE_RETRIEVAL
+ description: str = "execute the code retrieval to acquire more code information"
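+
+
+# Illustrative usage (sketch): ActionStatus compares equal to its string value (case-insensitive)
+# assert ActionStatus.TOOL_USING == "tool_using"
+# assert ActionStatus.FINISHED == "Finished"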
diff --git a/muagent/schemas/common/log.py b/muagent/schemas/common/log.py
new file mode 100644
index 0000000..426815e
--- /dev/null
+++ b/muagent/schemas/common/log.py
@@ -0,0 +1,38 @@
+from enum import Enum
+from typing import Union
+
+
+class LogVerboseEnum(Enum):
+ Log0Level = "0" # don't print log
+ Log1Level = "1" # print level-1 log
+ Log2Level = "2" # print level-2 log
+ Log3Level = "3" # print level-3 log
+
+ def __eq__(self, other):
+ if isinstance(other, str):
+ return self.value.lower() == other.lower()
+ if isinstance(other, LogVerboseEnum):
+ return self.value == other.value
+ return False
+
+ def __ge__(self, other):
+ if isinstance(other, LogVerboseEnum):
+ return int(self.value) >= int(other.value)
+ if isinstance(other, str):
+ return int(self.value) >= int(other)
+ return NotImplemented
+
+ def __le__(self, other):
+ if isinstance(other, LogVerboseEnum):
+ return int(self.value) <= int(other.value)
+ if isinstance(other, str):
+ return int(self.value) <= int(other)
+ return NotImplemented
+
+    @classmethod
+    def ge(cls, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
+        """Return True if `other` is at least as verbose as `enum_value`."""
+        return enum_value <= other
+
+    @classmethod
+    def le(cls, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
+        """Return True if `other` is at most as verbose as `enum_value`."""
+        return enum_value >= other
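+
+
+# Illustrative usage (sketch; `log_verbose` is a "0"-"3" string or a LogVerboseEnum value):
+# log_verbose = "2"
+# if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, log_verbose):
+#     print("printed whenever the configured verbosity is 1 or higher")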
\ No newline at end of file
diff --git a/muagent/schemas/kb/base_schema.py b/muagent/schemas/kb/base_schema.py
index fd14756..e9dda17 100644
--- a/muagent/schemas/kb/base_schema.py
+++ b/muagent/schemas/kb/base_schema.py
@@ -1,6 +1,6 @@
from sqlalchemy import Column, Integer, String, DateTime, func
-from muagent.orm.db import Base
+from muagent.db_handler.db import Base
class KnowledgeBaseSchema(Base):
diff --git a/muagent/schemas/memory.py b/muagent/schemas/memory.py
new file mode 100644
index 0000000..7e0ba80
--- /dev/null
+++ b/muagent/schemas/memory.py
@@ -0,0 +1,193 @@
+from pydantic import BaseModel, PrivateAttr
+from typing import List, Union, Dict, Optional, Literal, Any
+from loguru import logger
+
+from .message import Message
+
+
+class Memory(BaseModel):
+ '''The base dataclass of Memory'''
+
+ messages: List[Message] = []
+    _limit: Optional[int] = PrivateAttr(default=None)
+
+ def set_limit(self, limit: Optional[int] = None):
+ self._limit = limit
+
+ def _limit_messages(self, ):
+ if self._limit:
+ self.messages = self.messages[-self._limit:]
+
+ def append(self, message: Message):
+ self.messages.append(message)
+ self._limit_messages()
+
+ def extend(self, memory: 'Memory'):
+ self.messages.extend(memory.messages)
+ self._limit_messages()
+
+    def update(self, message: Message, role_tag: str = None):
+        """Add `role_tag` to the stored message with the same message_index, if present."""
+        if role_tag is None:
+            return
+        message_index = message.message_index
+        for idx, msg in enumerate(self.messages):
+            if msg.message_index != message_index:
+                continue
+            if isinstance(msg.role_tags, list):
+                msg.role_tags = list(set(msg.role_tags + [role_tag]))
+            else:
+                msg.role_tags += f", {role_tag}"
+            break
+
+ def sort_by_key(self, key: str):
+ self.messages = sorted(self.messages, key=lambda x: getattr(x, key, f"No this {key}"))
+
+ def clear(self, k: int = None):
+ '''save the messages by k limit'''
+ if k is None:
+ self.messages = []
+ else:
+ self.messages = self.messages[-k:]
+
+ def get_messages(self, k=0) -> List[Message]:
+ """Return the most recent k memories, return all when k=0"""
+ return self.messages[-k:]
+
+    def get_datetimes(self) -> List[Any]:
+        """get datetime values. default: end_datetime"""
+        return self.get_memory_values("end_datetime")
+
+    def get_contents(self) -> List[Any]:
+        """get content values"""
+        return self.get_memory_values("content")
+
+    def get_memory_values(self, key: str) -> List[Any]:
+        return [message.get_value(key) for message in self.messages]
+
+ def split_by_role_type(self) -> List[Dict[str, 'Memory']]:
+ """
+ Split messages into rounds of conversation based on role_type.
+ Each round consists of consecutive messages of the same role_type.
+ User messages form a single round, while assistant and function messages are combined into a single round.
+ Each round is represented by a dict with 'role' and 'memory' keys, with assistant and function messages
+ labeled as 'assistant'.
+ """
+ rounds = []
+ current_memory = Memory()
+ current_role = None
+
+ for msg in self.messages:
+ # Determine the message's role, considering 'function' as 'assistant'
+ message_role = 'assistant' if msg.role_type in ['assistant', 'function'] else 'user'
+
+ # If the current memory is empty or the current message is of the same role_type as current_role, add to current memory
+ if not current_memory.messages or current_role == message_role:
+ current_memory.append(msg)
+ else:
+ # Finish the current memory and start a new one
+ rounds.append({'role': current_role, 'memory': current_memory})
+ current_memory = Memory()
+ current_memory.append(msg)
+
+ # Update the current_role, considering 'function' as 'assistant'
+ current_role = message_role
+
+ # Don't forget to add the last memory if it exists
+ if current_memory.messages:
+ rounds.append({'role': current_role, 'memory': current_memory})
+
+ return rounds
+
+ def format_rounds_to_html(self) -> str:
+ formatted_html_str = ""
+ rounds = self.split_by_role_type()
+
+ for round in rounds:
+ role = round['role']
+ memory = round['memory']
+
+            # Convert the Memory of the current round to a string
+            messages_str = memory.to_format_messages(format_type="str")
+
+            # Add HTML-style tags based on the role type
+            if role == 'user':
+                formatted_html_str += f"<user>\n{messages_str}\n</user>\n\n"
+            else:  # treat 'assistant' and 'function' roles as 'assistant'
+                formatted_html_str += f"<assistant>\n{messages_str}\n</assistant>\n\n"
+
+ return formatted_html_str
+
+ def to_format_messages(
+ self,
+        attributes: Dict[str, Union[Any, List[Any]]] = {},
+ filter_type: Optional[Literal['select', 'filter']] = None,
+ *,
+ return_all: bool = True,
+ content_key: str = "content",
+ with_tag: bool = False,
+ format_type: Literal['raw', 'tuple', 'dict', 'str']='raw',
+ logic: Literal['or', 'and'] = 'and'
+    ) -> Union[List[Message], List[tuple], List[dict], str]:
+ '''Filter messages by attributes'''
+ def _logic_check(values: List[bool], logic):
+ # default: not filter any message
+ if values == []: return True
+ return any(values) if logic == "or" else all(values)
+
+        def _select(message, attrs, select_type="filter"):
+            if select_type == "filter":
+                return [message.get_value(key) not in value if isinstance(value, list) else
+                        message.get_value(key) != value
+                    for key, value in attrs.items()
+                ]
+            else:
+                return [message.get_value(key) in value if isinstance(value, list) else
+                        message.get_value(key) == value
+                    for key, value in attrs.items()
+                ]
+ #
+ messages = [
+ message for message in self.messages
+ if _logic_check(_select(message, attributes, filter_type), logic)
+ ]
+
+ #
+ if format_type == "tuple":
+ return [
+ message.to_tuple_message(return_all, content_key)
+ for message in messages
+ ]
+ elif format_type == "dict":
+ return [
+ message.to_dict_message()
+ for message in messages
+ ]
+ elif format_type == "str":
+ return "\n\n".join([
+ message.to_str_content(content_key, with_tag=with_tag)
+ for message in messages
+ ])
+
+ return messages
+
+ @classmethod
+ def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
+ return cls(messages=[message for memory in memorys for message in memory.get_messages()])
+
+ def __len__(self, ):
+ return len(self.messages)
+
+    def __str__(self) -> str:
+        return self.to_format_messages(format_type="str")
+
+ def __add__(self, other: Union[Message, 'Memory']) -> 'Memory':
+ if isinstance(other, Message):
+ return Memory(messages=self.messages + [other])
+ elif isinstance(other, Memory):
+ return Memory(messages=self.messages + other.messages)
+ else:
+            raise ValueError(f"can't add unsupported type {type(other)}")
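+
+
+# Illustrative usage (sketch):
+# memory = Memory()
+# memory.append(Message(role_name="user", role_type="user", content="hello"))
+# memory.append(Message(role_name="codefuse", role_type="assistant", content="hi there"))
+# print(memory.to_format_messages(format_type="str", with_tag=True))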
+
+
+
\ No newline at end of file
diff --git a/muagent/schemas/message.py b/muagent/schemas/message.py
new file mode 100644
index 0000000..a444a58
--- /dev/null
+++ b/muagent/schemas/message.py
@@ -0,0 +1,258 @@
+from pydantic import BaseModel, root_validator
+from typing import List, Dict, Optional, Literal, Union, Sequence, Tuple
+from loguru import logger
+import uuid
+from muagent.utils.common_utils import getCurrentDatetime
+
+
+
+class Message(BaseModel):
+ '''The base dataclass of Message
+
+ The following is an example:
+
+ .. code-block:: python
+
+ from muagent.schemas.message import Message
+ msg = Message(
+ role_name="system",
+ role_type="system",
+ content="You're a helpful assistant",
+ )
+ '''
+
+ #
+ role_name: str = "muagent"
+ '''The role name of agent to generate this message.'''
+
+ role_type: Literal[
+ 'system',
+ 'user',
+ 'assistant',
+ 'observation',
+ 'tool_call',
+ 'function',
+ 'codefuse',
+ 'summary'
+ ] = "codefuse"
+ '''The role type of agent to generate this message. such as system/user/assistant/observation/tool_call'''
+ #
+ role_tags: Union[Sequence[str], str] = ''
+ '''The tags of this message.'''
+
+ embedding: Optional[Sequence] = None
+ '''The embedding from LLM of this message.'''
+
+ image_urls: Optional[Sequence[str]] = None
+ '''The image_urls from LLM of this message.'''
+
+ action_status: str = "default"
+    '''The execution status from llm/tool/code of this message.'''
+
+ content: Optional[str] = ""
+ '''The last response from LLM of this message.'''
+
+ step_content: Optional[str] = ''
+    '''The accumulated content from LLM responses of this message, joined by newlines.'''
+
+    parsed_content: Dict = {}
+    '''The structured content parsed from the LLM output of this message'''
+
+    parsed_contents: List[Dict] = []
+    '''The history of structured contents parsed from the LLM outputs of this message'''
+
+    spec_parsed_content: Dict = {}
+    '''The special structured content parsed from the LLM output of this message'''
+
+    spec_parsed_contents: List[Dict] = []
+    '''The history of special structured contents parsed from the LLM outputs of this message'''
+
+    global_kwargs: Dict = {}
+    '''User's custom kwargs for init or end actions'''
+
+ # input from last message
+ input_text: Optional[str] = ""
+ '''The input text from last message.'''
+
+    parsed_input: Dict = {}
+    '''The structured input parsed from the last message'''
+
+    parsed_inputs: List[Dict] = []
+    '''The history of structured inputs parsed from the last messages'''
+
+    spec_parsed_input: Dict = {}
+    '''The special structured input parsed from the last message'''
+
+    spec_parsed_inputs: List[Dict] = []
+    '''The history of special structured inputs parsed from the last messages'''
+
+ #
+ session_index: Optional[str] = None
+ '''The session index of this message.'''
+
+ message_index: Optional[str] = None
+ '''The message index of this message.'''
+
+ node_index: Optional[str] = "default"
+ '''The node index of this message.'''
+
+ #
+ start_datetime: str = None
+ '''The first record time of this message.'''
+
+ end_datetime: str = None
+ '''The last update time of this message.'''
+
+ datetime_format: str = "%Y-%m-%d %H:%M:%S.%f"
+
+ @root_validator(pre=True)
+    def set_default_content(cls, values):
+        """Default `content` to `input_text` when content is not provided."""
+ input_text = values.get("input_text")
+ content = values.get("content")
+ if content is None:
+ values["content"] = content or input_text
+ return values
+
+ @root_validator(pre=True)
+ def check_datetime(cls, values):
+ start_datetime = values.get("start_datetime")
+ end_datetime = values.get("end_datetime")
+ datetime_format = values.get("datetime_format", "%Y-%m-%d %H:%M:%S.%f")
+ if start_datetime is None:
+ values["start_datetime"] = getCurrentDatetime(datetime_format)
+ if end_datetime is None:
+ values["end_datetime"] = getCurrentDatetime(datetime_format)
+ return values
+
+ @root_validator(pre=True)
+ def check_message_index(cls, values):
+ message_index = values.get("message_index")
+ session_index = values.get("session_index")
+ if message_index is None or message_index == "":
+ values["message_index"] = str(uuid.uuid4()).replace("-", "_")
+
+ if session_index is None or session_index == "":
+ values["session_index"] = str(uuid.uuid4()).replace("-", "_")
+ return values
+
+ def update_input(self, input: Union[str, 'Message'], parsed_input: Dict = {}):
+ if isinstance(input, str):
+ self.update_attributes({"input_text": input})
+ else:
+ self.update_attributes({"input_text": input.content})
+
+ def update_parsed_input(self, parsed_input: Dict):
+ self.update_attributes({"parsed_input": parsed_input})
+ self.update_attributes({"parsed_inputs": self.parsed_inputs + [parsed_input]})
+
+ def update_spec_parsed_input(self, spec_parsed_input: Dict):
+ self.update_attributes({"spec_parsed_input": spec_parsed_input})
+ self.update_attributes({"spec_parsed_inputs": self.spec_parsed_inputs + [spec_parsed_input]})
+
+ def update_content(self, content: Union[str, 'Message'], parsed_content: Dict = {}):
+ if isinstance(content, str):
+ self.update_attributes({"content": content})
+ self.update_attributes({"step_content": self.step_content + f"\n{content}"})
+ else:
+ self.update_attributes({"content": content.content})
+ self.update_attributes({"step_content": self.step_content + f"\n{content.content}"})
+
+ def update_parsed_content(self, parsed_content: Dict = {}):
+ self.update_attributes({"parsed_content": parsed_content})
+ self.update_attributes({"parsed_contents": self.parsed_contents + [parsed_content]})
+
+ def update_spec_parsed_content(self, spec_parsed_content: Dict = {}):
+ self.update_attributes({"spec_parsed_content": spec_parsed_content})
+ self.update_attributes({"spec_parsed_contents": self.spec_parsed_contents + [spec_parsed_content]})
+
+ def update_attributes(self, attributes: dict):
+ '''update message attributes'''
+ for k, v in attributes.items():
+ self.update_attribute(k, v)
+
+ def update_attribute(self, key: str, value):
+ if hasattr(self, key):
+ setattr(self, key, value)
+ self.end_datetime = getCurrentDatetime(self.datetime_format)
+ else:
+ raise AttributeError(f"{key} is not a valid property of {self.__class__.__name__}")
+
+ def to_dict_message(self, ) -> Dict:
+ return vars(self)
+
+ def to_tuple_message(
+ self,
+ return_all: bool = True,
+ content_key: Literal[
+ 'input_text',
+ 'content',
+            'step_content',
+ 'parsed_content',
+ 'spec_parsed_contents',
+ ] = "content",
+ ) -> Union[str, Tuple[str, str]]:
+        content = self.to_str_content(content_key)
+ if return_all:
+ return (self.role_name, content)
+ else:
+ return (content)
+
+ def to_str_content(
+ self,
+ content_key: Literal[
+ 'input_text',
+ 'content',
+            'step_content',
+ 'parsed_content',
+ 'parsed_contents',
+ 'spec_parsed_content',
+ 'spec_parsed_contents',
+ ] = "content",
+ with_tag=False
+ ) -> str:
+ # TODO while role_type is USER return input_query, else return role_content
+ response = self.content or self.input_text
+ if content_key == "content":
+ content = response
+ elif content_key == "input_text":
+ content = self.input_text
+ elif content_key == "step_content":
+ content = self.step_content or response
+ elif content_key == "parsed_content":
+ content = "\n".join([v for k, v in self.parsed_content.items()]) or response
+ # content = "\n".join([f"**{k}:** {v}" for k, v in self.parsed_content.items()]) or response
+ elif content_key == "spec_parsed_content":
+ content = "\n".join([f"**{k}:** {v}" for k, v in self.spec_parsed_content.items()]) or response
+ elif content_key == "parsed_contents":
+ content = "\n".join([v for po in self.parsed_contents for k,v in po.items()]) or response
+ elif content_key == "spec_parsed_contents":
+ content = "\n".join([f"**{k}:** {v}" for po in self.spec_parsed_contents for k,v in po.items()]) or response
+ else:
+ content = response
+
+ if with_tag:
+ start_tag = f"<{self.role_type}-{self.role_name}-message>"
+            end_tag = f"</{self.role_type}-{self.role_name}-message>"
+ return f"{start_tag}\n{content}\n{end_tag}"
+ else:
+ return content
+
+ def get_value(self, key: str) -> any:
+ """
+ Get the value of the given key from the message.
+
+ :param key: The key of the attribute to retrieve.
+ :return: The value associated with the key.
+ """
+ if hasattr(self, key):
+ return getattr(self, key, None)
+ raise AttributeError(f"Message don't have attribute {key}")
+
+ def get_attribute_type(self, key):
+ return type(getattr(self, key, None))
+
+ def __str__(self) -> str:
+ # key_str = '\n'.join([k for k, v in vars(self).items()])
+ # logger.debug(f"{key_str}")
+ return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()])
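+
+
+# Illustrative usage (sketch):
+# msg = Message(role_name="user", role_type="user", input_text="1+1=?")
+# msg.update_content("2")
+# print(msg.to_str_content(content_key="content", with_tag=True))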
+
\ No newline at end of file
diff --git a/muagent/schemas/models/__init__.py b/muagent/schemas/models/__init__.py
new file mode 100644
index 0000000..046c4b7
--- /dev/null
+++ b/muagent/schemas/models/__init__.py
@@ -0,0 +1,11 @@
+from .model import ModelConfig, LLMConfig
+from .llm_shemas import *
+
+
+__all__ = [
+    "ModelConfig", "LLMConfig",
+
+ "ChatMessage", "FunctionCallData", "ToolCall", "LLMOuputMessage",
+ "Choice", "UsageData", "LLMResponse",
+
+]
\ No newline at end of file
diff --git a/muagent/schemas/models/llm_shemas.py b/muagent/schemas/models/llm_shemas.py
new file mode 100644
index 0000000..45f54e9
--- /dev/null
+++ b/muagent/schemas/models/llm_shemas.py
@@ -0,0 +1,50 @@
+from pydantic import BaseModel, Field
+from typing import List, Dict, Optional, Union
+from enum import Enum
+
+
+
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+
+class FunctionCallData(BaseModel):
+ name: str
+ arguments: Union[str, dict]
+
+
+class ToolCall(BaseModel):
+ id: Optional[Union[str, int]] = None
+ type: str = "function"
+ function: FunctionCallData
+
+
+class LLMOuputMessage(BaseModel):
+ content: Optional[str] = None
+ role: str
+ tool_calls: List[ToolCall] = []
+
+
+class Choice(BaseModel):
+ finish_reason: str
+ index: int = 0
+ message: LLMOuputMessage
+
+
+class UsageData(BaseModel):
+ completion_tokens: int
+ prompt_tokens: int
+ total_token: int
+
+
+class LLMResponse(BaseModel):
+ choices: List[Choice]
+ created: int = 0
+ id: str
+ model: str
+ object: str
+ usage: Optional[UsageData] = None
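+
+
+# Illustrative usage (sketch; id/model values are examples only):
+# response = LLMResponse(
+#     id="chatcmpl-123",
+#     model="qwen2.5:0.5b",
+#     object="chat.completion",
+#     choices=[Choice(
+#         finish_reason="stop",
+#         message=LLMOuputMessage(role="assistant", content="hello"),
+#     )],
+# )
+# print(response.choices[0].message.content)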
+
+
+
diff --git a/muagent/schemas/models/model.py b/muagent/schemas/models/model.py
new file mode 100644
index 0000000..4e4a4fb
--- /dev/null
+++ b/muagent/schemas/models/model.py
@@ -0,0 +1,56 @@
+
+
+from pydantic import BaseModel, root_validator
+from typing import List, Dict, Optional, Union, Literal
+
+
+
+class ModelConfig(BaseModel):
+ """The dataclass for model config."""
+
+ config_name: Optional[str] = None
+    """The name of the model configuration. It defaults to model_name when not provided."""
+
+ model_type: str
+ """The type of the model wrapper, which is to identify the model wrapper
+ class in model configuration."""
+
+ model_name: str
+ """The name of the model, which is used in model api calling."""
+
+ api_key: Optional[str] = None
+ """The api key of the model, which is used in model api calling."""
+
+ api_url: Optional[str] = None
+ """The api url of the model, which is used in model api calling."""
+
+ max_tokens: Optional[int] = None
+ """The max_tokens of the model, which is used in model api calling."""
+
+ top_p: float = 0.9
+ """The top_p of the model, which is used in model api calling."""
+
+ temperature: float = 0.3
+ """The temperature of the model, which is used in model api calling."""
+
+ stream: bool = False
+ """The stream mode of the model, which is used in model api calling."""
+
+ @root_validator(pre=True)
+ def set_default_config_name(cls, values):
+ """Set config_name to model_name if config_name is not provided."""
+ if 'config_name' not in values or values['config_name'] is None:
+ values['config_name'] = values.get('model_name')
+ return values
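+
+
+# Illustrative usage (sketch; model_type/model_name/api_key values are examples only):
+# config = ModelConfig(model_type="openai", model_name="gpt-4o-mini", api_key="your-api-key")
+# assert config.config_name == "gpt-4o-mini"  # defaults to model_name via the root validator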
+
+
+
+class LLMConfig(BaseModel):
+    """Temporary legacy config; it will be removed."""
+ model_name: str = "gpt-3.5-turbo"
+ model_engine: str = "openai"
+ temperature: float = 0.3
+    stop: Optional[Union[List[str], str]] = None
+ api_key: str = ""
+ api_base_url: str = ""
+ llm: Optional[str] = ""
\ No newline at end of file
diff --git a/muagent/schemas/project_config.py b/muagent/schemas/project_config.py
new file mode 100644
index 0000000..ff19c19
--- /dev/null
+++ b/muagent/schemas/project_config.py
@@ -0,0 +1,114 @@
+
+from pydantic import BaseModel, Field
+from typing import (
+ List,
+ Dict,
+ Optional,
+ Union,
+ Literal,
+ Any
+)
+
+from .models import ModelConfig, LLMConfig
+from .agent_config import AgentConfig, PromptConfig
+from .db import GBConfig, TBConfig
+
+
+
+class ProjectConfig(BaseModel):
+ """The dataclass of project config"""
+
+    agent_configs: Optional[Dict[str, AgentConfig]]
+    """The agent configurations of this project, keyed by config name."""
+
+    prompt_configs: Optional[Dict[str, PromptConfig]]
+    """The prompt configurations of this project, keyed by config name."""
+
+    model_configs: Optional[Dict[str, Any]]
+    """The model configurations of this project, keyed by config name."""
+
+    graph: Any = None
+    """The (optional) graph definition of this project."""
+
+ def extend_agent_configs(
+ self,
+ agent_configs: Union[AgentConfig, List[AgentConfig], Dict[str, AgentConfig]]
+ ):
+
+ if isinstance(agent_configs, AgentConfig):
+ self.agent_configs.update({agent_configs.config_name: agent_configs})
+
+        if isinstance(agent_configs, List):
+            self.agent_configs.update({
+                i.config_name: i for i in agent_configs
+                if isinstance(i, AgentConfig)
+            })
+ elif isinstance(agent_configs, Dict):
+ self.agent_configs.update(agent_configs)
+
+ def extend_prompt_configs(
+ self,
+ prompt_configs: Union[PromptConfig, List[PromptConfig], Dict[str, PromptConfig]]
+ ):
+ if isinstance(prompt_configs, PromptConfig):
+ self.prompt_configs.update({prompt_configs.config_name: prompt_configs})
+
+        if isinstance(prompt_configs, List):
+            self.prompt_configs.update({
+                i.config_name: i for i in prompt_configs
+                if isinstance(i, PromptConfig)
+            })
+ elif isinstance(prompt_configs, Dict):
+ self.prompt_configs.update(prompt_configs)
+
+ def extend_model_configs(
+ self,
+ model_configs: Union[ModelConfig, List[ModelConfig], Dict[str, ModelConfig]]
+ ):
+ if isinstance(model_configs, ModelConfig):
+ self.model_configs.update({model_configs.config_name: model_configs})
+
+        if isinstance(model_configs, List):
+            self.model_configs.update({
+                i.config_name: i for i in model_configs
+                if isinstance(i, ModelConfig)
+            })
+ elif isinstance(model_configs, Dict):
+ self.model_configs.update(model_configs)
+
+ def extend_graph(self, graph):
+ """wait"""
+ pass
+
+ def __add__(self, other: 'ProjectConfig') -> 'ProjectConfig':
+ if isinstance(other, ProjectConfig):
+ self.extend_agent_configs(other.agent_configs)
+            self.extend_model_configs(other.model_configs)
+ self.extend_prompt_configs(other.prompt_configs)
+ self.extend_graph(other.graph)
+ return self
+ else:
+            raise ValueError(f"can't add unsupported type {type(other)}")
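+
+
+# Illustrative usage (sketch):
+# base = ProjectConfig(agent_configs={}, prompt_configs={}, model_configs={})
+# base.extend_model_configs(ModelConfig(model_type="openai", model_name="gpt-4o-mini"))
+# merged = base + ProjectConfig(agent_configs={}, prompt_configs={}, model_configs={})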
+
+
+
+class EKGProjectConfig(BaseModel):
+    """The dataclass of EKG project config"""
+
+ config_name: str = "default"
+ """The config name of EKG Project"""
+
+    model_configs: Optional[Dict[str, Union[ModelConfig, Any]]]
+    """The LLM model configurations, keyed by config name."""
+
+    embed_configs: Optional[Dict[str, ModelConfig]]
+    """The embedding model configurations, keyed by config name."""
+
+    agent_configs: Optional[Dict[str, AgentConfig]]
+    """The agent configurations, keyed by config name."""
+
+    prompt_configs: Optional[Dict[str, PromptConfig]]
+    """The prompt configurations, keyed by config name."""
+
+    db_configs: Optional[Dict[str, Union[GBConfig, TBConfig]]]
+    """The database configurations (graph base / table base), keyed by config name."""
\ No newline at end of file
diff --git a/muagent/service/ekg_construct/ekg_construct_base.py b/muagent/service/ekg_construct/ekg_construct_base.py
index f13466e..be19262 100644
--- a/muagent/service/ekg_construct/ekg_construct_base.py
+++ b/muagent/service/ekg_construct/ekg_construct_base.py
@@ -22,7 +22,8 @@
from muagent.schemas.db import *
from muagent.schemas.common import *
from muagent.db_handler import *
-from muagent.orm import table_init
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
from muagent.base_configs.env_config import EXTRA_KEYWORDS_PATH
from muagent.connector.configs.generate_prompt import *
@@ -193,6 +194,12 @@ def init_gb(self, do_init: bool=None):
self.create_gb_tags_and_edgetypes()
self.waiting_tags_edgetypes_initialize()
+            # print('Initializing Node Tags and Edge Types, waiting 20 seconds......')
+            # time.sleep(20)
+ else:
+ self.gb.add_hosts('storaged0', 9779)
+            # create node tags and edge types
+ self.create_gb_tags_and_edgetypes()
else:
self.gb = None
@@ -323,6 +330,28 @@ def _dfs(node, current_path: List):
def create_gb_tags_and_edgetypes(self):
+
+ def _check():
+ node_types = [i for i in TYPE2SCHEMA.keys() if i!='edge']
+ tags = self.gb.show_tags()
+ tag_names = [tag["Name"] for tag in tags]
+ tag_flag = set(tag_names) == set(node_types)
+
+ edges = self.gb.show_edge_type()
+ edge_names = [edge["Name"] for edge in edges]
+ inset_edges = [f"{i}{k}{j}"
+ for i in node_types
+ for j in node_types
+ for k in ["_route_", "_extend_", "_conclude_"]
+ ]
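+            # e.g. with a node type like "opsgptkg_task" (illustrative), the expected
+            # edge names look like "opsgptkg_task_route_opsgptkg_task".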
+ edge_flag = set(edge_names) == set(inset_edges)
+ logger.info(f"tag_flag={tag_flag}, edge_flag={edge_flag}")
+ #
+ return tag_flag and edge_flag
+
+        # if the tags and edge types already exist, return early
+ if _check(): return
+
# 节点标签和属性 (done)
for node_type, schema in TYPE2SCHEMA.items():
if node_type == 'edge':
@@ -364,7 +393,7 @@ def create_gb_tags_and_edgetypes(self):
# 边类型(名称)
node_types = list(TYPE2SCHEMA.keys())
- logger.info(node_types)
+ node_types = [i for i in TYPE2SCHEMA.keys() if i!='edge']
for i in range(len(node_types)):
for j in range(len(node_types)):
if node_types[i] != 'edge' and node_types[j] != 'edge': # 排除 node_type 为 'edge'
@@ -374,9 +403,11 @@ def create_gb_tags_and_edgetypes(self):
self.gb.create_edge_type(edge_type2, edge_attributes_dict)
edge_type3 = f"{node_types[i]}_conclude_{node_types[j]}"
self.gb.create_edge_type(edge_type3, edge_attributes_dict)
-
-
+ time.sleep(5)
+ while not _check():
+            logger.info('Node tags and edge types are initializing, waiting 5 seconds...')
+ time.sleep(5)
def update_graph(
self,
@@ -874,6 +905,7 @@ def get_node_by_id(
) -> GNode:
if service_type=="gbase":
node = self.gb.get_current_node({'id': nodeid}, node_type=node_type)
+ if node is None: return node
node = self._normalized_nodes_type(nodes=[node])[0]
else:
node = GNode(id=nodeid, type="", attributes={})
diff --git a/muagent/service/ekg_inference/intention_match_rule.py b/muagent/service/ekg_inference/intention_match_rule.py
index f94b6ab..19a821c 100644
--- a/muagent/service/ekg_inference/intention_match_rule.py
+++ b/muagent/service/ekg_inference/intention_match_rule.py
@@ -1,5 +1,5 @@
import re
-import Levenshtein
+import edit_distance as ed
from muagent.schemas.common import GNode
@@ -13,13 +13,14 @@ def edit_distance(cls, node: GNode, pattern=None, **kwargs):
desc: str = node.attributes.get('description', '')
if pattern is None:
- return -Levenshtein.distance(desc, s)
+ return -ed.edit_distance(desc, s)[0]
desc_list = re.findall(pattern, desc)
if not desc_list:
return -float('inf')
- return max([-Levenshtein.distance(x, s) for x in desc_list])
+ return max([-ed.edit_distance(x, s)[0] for x in desc_list])
+
@classmethod
def edit_distance_integer(cls, node: GNode, **kwargs):
diff --git a/muagent/service/ekg_inference/intention_router.py b/muagent/service/ekg_inference/intention_router.py
index 2bb5136..f41d063 100644
--- a/muagent/service/ekg_inference/intention_router.py
+++ b/muagent/service/ekg_inference/intention_router.py
@@ -103,8 +103,7 @@ def _func(node: GNode, rule: Callable):
return select_node, error_msg
def get_intention_by_node_info_match(
- self, root_node_id: str, filter_attribute: Optional[dict] = None,
- gb_handler: Optional[GBHandler] = None,
+ self, root_node_id: str, gb_handler: Optional[GBHandler] = None,
rule: Union[Rule_type, list[Rule_type]] = None, **kwargs
) -> dict[str, Any]:
gb_handler = gb_handler if gb_handler is not None else self.gb_handler
@@ -124,10 +123,7 @@ def get_intention_by_node_info_match(
RuleRetInfo(node_id=root_node_id, error_msg=error_msg, status=RouterStatus.OTHERS.value))
if not (root_node_id and self._node_exist(root_node_id, gb_handler)):
- if not root_node_id:
- error_msg = f'No node matches attribute {filter_attribute}.'
- else:
- error_msg = f'Node(id={root_node_id}, type={self._node_type}) does not exist!'
+ error_msg = f'Node(id={root_node_id}, type={self._node_type}) does not exist!'
return asdict(
RuleRetInfo(node_id=root_node_id, error_msg=error_msg, status=RouterStatus.OTHERS.value))
@@ -284,6 +280,8 @@ def get_intention_whether_execute(self, query: str, agent=None) -> bool:
return False
def get_intention_consult_which(self, query: str, agent=None, root_node_id: Optional[str]=None) -> str:
+ if isinstance(query, (list, tuple)):
+ query = query[0]
agent = agent if agent else self.agent
query_consult_which = itp.CONSULT_WHICH_PROMPT.format(query=query)
ans = agent.predict(query_consult_which)
@@ -386,13 +384,12 @@ def _dfs(s: str, ancestor: str, path: str, out: dict, visited: set):
if child in nodes:
if ancestor in out:
out.pop(ancestor)
+ visited.add(ancestor)
temp_ancestor = child
else:
temp_ancestor = ancestor
child_path = split.join((path, child))
_dfs(child, temp_ancestor, child_path, out, visited)
- if s in nodes:
- visited.add(s)
if len(nodes) == 0:
return dict()
diff --git a/muagent/service/ui_file_service/code_base_cds.py b/muagent/service/ui_file_service/code_base_cds.py
index f1d0e31..43cc0da 100644
--- a/muagent/service/ui_file_service/code_base_cds.py
+++ b/muagent/service/ui_file_service/code_base_cds.py
@@ -6,7 +6,7 @@
@desc:
'''
from loguru import logger
-from muagent.orm.db import with_session
+from muagent.db_handler.db import with_session
from muagent.schemas.kb.base_schema import CodeBaseSchema
diff --git a/muagent/service/ui_file_service/document_base_cds.py b/muagent/service/ui_file_service/document_base_cds.py
index 43f8700..c8bc73a 100644
--- a/muagent/service/ui_file_service/document_base_cds.py
+++ b/muagent/service/ui_file_service/document_base_cds.py
@@ -1,4 +1,4 @@
-from muagent.orm.db import with_session
+from muagent.db_handler.db import with_session
from muagent.schemas.kb.base_schema import KnowledgeBaseSchema
diff --git a/muagent/service/ui_file_service/document_file_cds.py b/muagent/service/ui_file_service/document_file_cds.py
index 9cdd27e..801ff39 100644
--- a/muagent/service/ui_file_service/document_file_cds.py
+++ b/muagent/service/ui_file_service/document_file_cds.py
@@ -1,4 +1,4 @@
-from muagent.orm.db import with_session
+from muagent.db_handler.db import with_session
from muagent.schemas.kb.base_schema import KnowledgeFileSchema, KnowledgeBaseSchema
from muagent.schemas.kb.file_schema import DocumentFile
diff --git a/muagent/service/utils.py b/muagent/service/utils.py
index ae8f7e9..2a2592f 100644
--- a/muagent/service/utils.py
+++ b/muagent/service/utils.py
@@ -33,10 +33,10 @@ def decode_biznodes(
**{**{"id": node.id, "type": node.type}, **node.attributes}
)
- if node.type == "opsgptkg_task":
- logger.debug(f"schema:{ schema}")
- logger.debug(f"node_data:{ type(node_data)}")
- logger.debug(f"node_data:{ node_data}")
+ # if node.type == "opsgptkg_task":
+ # logger.debug(f"schema:{ schema}")
+ # logger.debug(f"node_data:{ type(node_data)}")
+ # logger.debug(f"node_data:{ node_data}")
node_data = {
k:v
@@ -44,8 +44,8 @@ def decode_biznodes(
if k not in ["type", "ID", "id", "extra"]
}
- if node.type == "opsgptkg_task":
- logger.debug(f"node_data:{ node_data}")
+ # if node.type == "opsgptkg_task":
+ # logger.debug(f"node_data:{ node_data}")
# update agent/tool nodes and edges
agents = node_data.pop("agents", [])
@@ -70,9 +70,9 @@ def decode_biznodes(
attributes={}
))
- if node.type == "opsgptkg_task":
- logger.debug(f"node_data:{ node_data}")
- logger.debug(f"node.attributes:{ node.attributes}")
+ # if node.type == "opsgptkg_task":
+ # logger.debug(f"node_data:{ node_data}")
+ # logger.debug(f"node.attributes:{ node.attributes}")
new_nodes.append(GNode(**{
"id": node.id,
diff --git a/muagent/tools/__init__.py b/muagent/tools/__init__.py
index 5200c04..5a6a5c3 100644
--- a/muagent/tools/__init__.py
+++ b/muagent/tools/__init__.py
@@ -12,12 +12,14 @@
from .ocr_tool import BaiduOcrTool
from .stock_tool import StockInfo, StockName
from .codechat_tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
+from .undercover import *
+from .werewolf import *
IMPORT_TOOL = [
WeatherInfo, DistrictInfo, Multiplier, WorldTimeGetTimezoneByArea,
KSigmaDetector, MetricsQuery, DDGSTool, DocRetrieval, CodeRetrieval,
- BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
+ BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code,
]
TOOL_SETS = [tool.__name__ for tool in IMPORT_TOOL]
@@ -29,3 +31,6 @@
"toLangchainTools", "get_tool_schema", "tool_sets", "BaseToolModel"
] + TOOL_SETS
+
+def get_tool(tool_name: str) -> BaseToolModel:
+ return BaseToolModel._from_name(tool_name)
\ No newline at end of file
diff --git a/muagent/tools/base_tool.py b/muagent/tools/base_tool.py
index 507822a..814a210 100644
--- a/muagent/tools/base_tool.py
+++ b/muagent/tools/base_tool.py
@@ -1,16 +1,109 @@
+from abc import ABCMeta
+
from langchain.agents import Tool
from langchain.tools import StructuredTool
from langchain.tools.base import ToolException
from pydantic import BaseModel, Field
-from typing import List, Dict
-# import jsonref
-import json
-
+from typing import List, Dict, Any, Type
+try:
+    import jsonref
+except ImportError:
+    jsonref = None
-class BaseToolModel:
+import json
+import copy
+
+
+def simplify_schema(schema: Dict[str, Any], definitions, no_required=False, depth=0) -> Dict[str, Any]:
+    """Simplify a schema by inlining $ref references and dropping definitions."""
+    if definitions is None: return schema
+    schema_new = copy.deepcopy(schema)
+    # drop the title field
+    schema_new.pop('title', None)
+    # walk through the properties
+    if 'properties' in schema:
+        for key, value in schema['properties'].items():
+            for k, v in value.items():
+
+                if k == "allOf":
+                    ref_model_name = v[0]['$ref'].split('/')[-1]  # extract the referenced model name
+                    ref_model_value = simplify_schema(definitions[ref_model_name], definitions, no_required=True, depth=depth+1)
+                    schema_new["properties"][key].pop(k)
+                    schema_new["properties"][key].update(ref_model_value)
+
+                if isinstance(v, dict) and '$ref' in v:
+                    ref_model_name = v['$ref'].split('/')[-1]  # extract the referenced model name
+                    ref_model_value = simplify_schema(definitions[ref_model_name], definitions, no_required=True, depth=depth+1)
+                    schema_new["properties"][key][k] = ref_model_value
+
+                schema_new["properties"][key].pop("title", None)
+    # drop 'required' (for nested refs) and the 'definitions' section
+    if no_required:
+        schema_new.pop('required', None)
+    schema_new.pop('definitions', None)
+
+ return schema_new
+
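+# Rough illustration of simplify_schema (hypothetical schema): given
+#   {"properties": {"cfg": {"allOf": [{"$ref": "#/definitions/Sub"}]}},
+#    "definitions": {"Sub": {"type": "object", "properties": {"x": {"type": "string"}}}}}
+# the "$ref" is inlined into "cfg" and the "title"/"definitions" keys are dropped,
+# yielding a flat JSON schema that is easier to feed to an LLM.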
+
+class _ToolWrapperMeta(ABCMeta):
+ """A meta call to replace the tool wrapper's run function with
+ wrapper about error handling."""
+
+ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+ if "__call__" in attrs:
+ attrs["__call__"] = attrs["__call__"]
+ return super().__new__(mcs, name, bases, attrs)
+
+ def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+ if not hasattr(cls, "_registry"):
+ cls._registry = {} # class name
+ cls._toolname_registry = {} # class attribute name
+ else:
+ cls._registry[name] = cls
+ cls._toolname_registry[cls.name] = cls
+ super().__init__(name, bases, attrs)
+
+
+class BaseToolModel(metaclass=_ToolWrapperMeta):
name = "BaseToolModel"
description = "Tool Description"
+ @classmethod
+ def _from_name(cls, tool_name: str) -> 'BaseToolModel':
+
+ """Get the specific model wrapper"""
+ if tool_name in cls._registry:
+ return cls._registry[tool_name]() # type: ignore[return-value]
+ elif tool_name in cls._toolname_registry:
+ return cls._toolname_registry[tool_name]() # type: ignore[return-value]
+ else:
+ raise KeyError(
+ f"Tool Library is missiong"
+ f" {tool_name}, please check your tool name"
+ )
+
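+    # Usage sketch: both the class name and the tool's `name` attribute resolve to
+    # the registered class, e.g. BaseToolModel._from_name("Multiplier") or
+    # get_tool("谁是卧底-座位分配") return an instance of the matching tool
+    # (illustrative lookups, assuming those tools have been imported).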
+ @classmethod
+ def intput_to_json_schema(cls) -> Dict[str, Any]:
+        '''Transform the ToolInputArgs schema into a plain JSON structure'''
+        try:
+            return jsonref.loads(cls.ToolInputArgs.schema_json())
+        except Exception:
+ return simplify_schema(
+ cls.ToolInputArgs.schema(),
+ cls.ToolInputArgs.schema().get("definitions")
+ )
+
+ @classmethod
+ def output_to_json_schema(cls) -> Dict[str, Any]:
+        '''Transform the ToolOutputArgs schema into a plain JSON structure'''
+        try:
+            return jsonref.loads(cls.ToolOutputArgs.schema_json())
+        except Exception:
+ return simplify_schema(
+ cls.ToolOutputArgs.schema(),
+ cls.ToolOutputArgs.schema().get("definitions")
+ )
+
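+    # Rough illustration (assuming pydantic v1 `.schema()`): for a ToolOutputArgs
+    # with a single `key2: str` field, the simplified schema is roughly
+    #   {"type": "object", "properties": {"key2": {"type": "string", ...}}, "required": ["key2"]}
+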
class ToolInputArgs(BaseModel):
"""
Input for MoveFileTool.
@@ -32,7 +125,7 @@ class ToolOutputArgs(BaseModel):
key2: str = Field(..., description="hello world!!")
@classmethod
- def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
+ def run(cls) -> ToolOutputArgs:
"""excute your tool!"""
pass
diff --git a/muagent/tools/metrics_query.py b/muagent/tools/metrics_query.py
index c3336c7..6198663 100644
--- a/muagent/tools/metrics_query.py
+++ b/muagent/tools/metrics_query.py
@@ -25,7 +25,8 @@ class ToolOutputArgs(BaseModel):
datas: List[float] = Field(..., description="监控时序数组")
- def run(machine_ip, time):
+ @classmethod
+ def run(cls, machine_ip, time):
"""excute your tool!"""
data = [0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890,
16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567,
diff --git a/muagent/tools/undercover.py b/muagent/tools/undercover.py
new file mode 100644
index 0000000..6746bf6
--- /dev/null
+++ b/muagent/tools/undercover.py
@@ -0,0 +1,315 @@
+import os
+from typing import (
+ List,
+ Dict
+)
+from loguru import logger
+from pydantic import BaseModel, Field
+import random
+
+from .base_tool import BaseToolModel
+from ..models import get_model, ModelConfig
+
+
+class SeatAssignerTool(BaseToolModel):
+ """
+ This tool assigns seat positions to players and formats them in a markdown table.
+ Example Output:
+ ```
+ | 座位 | 玩家 |
+ |---|---|
+ | 1 | **张伟** |
+ | 2 | **李静** |
+ | 3 | **王鹏** |
+ | 4 | **人类玩家** |
+ ```
+ """
+ name: str = "谁是卧底-座位分配"
+ description: str = "谁是卧底的座位分配工具,可以将玩家顺序打乱随机分配座位"
+
+ class ToolInputArgs(BaseModel):
+ """Input for SeatAssigner."""
+ pass # No specific parameters required for this tool
+
+ class ToolOutputArgs(BaseModel):
+ """Output for SeatAssigner."""
+ table: str = Field(..., description="Markdown table of seating arrangement")
+
+ @classmethod
+ def run(cls, **kwargs) -> str:
+ """Execute the seat assignment tool."""
+ players = [["张伟", "agent_张伟"], ["李静", "agent_李静"], ["王鹏", "agent_王鹏"], ["人类玩家", "agent_人类玩家"]]
+ # Shuffle players to assign them to random seats
+ random.shuffle(players)
+ # Create the markdown table
+ markdown_table = "\n\n| 座位 | 玩家 |\n|---|---|\n" + "\n".join(
+ f"| {i+1} | **{players[i][0]}** |" for i in range(len(players))
+ )
+ return markdown_table
+
+
+class RoleAssignerTool(BaseToolModel):
+ """
+ This class assigns roles and words to players in a game.
+ The output will include player names, agent names, agent descriptions, and secret words based on their role type.
+ """
+ name: str = "谁是卧底-角色分配"
+ description: str = "谁是卧底的角色分配工具,可以为每一位玩家分配一个单词和人物角色。"
+
+ class ToolInputArgs(BaseModel):
+ """Input for assigning roles."""
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ """Output for assigned roles."""
+ roles: List[Dict[str, str]] = Field(..., description="List of roles assigned to players")
+
+ @classmethod
+ def run(cls, **kwargs) -> List[Dict[str, str]]:
+ words = [
+ ["苹果", "梨"],
+ ["猫", "狗"],
+ ["摩托车", "自行车"],
+ ["太阳", "月亮"],
+ ["红色", "粉色"],
+ ["大象", "长颈鹿"],
+ ["铅笔", "钢笔"],
+ ["牛奶", "豆浆"],
+ ["河", "湖"],
+ ["面包", "蛋糕"],
+ ["饺子", "包子"],
+ ["冬天", "夏天"],
+ ["电视", "电脑"],
+ ["铅笔", "橡皮"],
+ ["跑步", "游泳"],
+ ["手机", "平板"],
+ ["鱼", "虾"],
+ ["空调", "风扇"],
+ ["马", "驴"],
+ ["书", "杂志"],
+ ["草", "树"],
+ ["杯子", "碗"],
+ ["米饭", "面条"],
+ ["饼干", "蛋糕"],
+ ["雨伞", "雨衣"],
+ ["猪", "牛"],
+ ["白菜", "生菜"],
+ ["吉他", "钢琴"],
+ ["飞机", "火车"],
+ ["镜子", "眼镜"]
+ ]
+
+ player_names = ["张伟", "李静", "王鹏", "人类玩家"]
+ roles = ["平民_1", "平民_2", "平民_3", "卧底_1"]
+ random.shuffle(player_names)
+ random.shuffle(roles)
+
+ word_idx = random.randint(0, len(words) - 1)
+ under_cover_word_idx = random.randint(0, 1)
+
+ result = []
+ for i in range(len(player_names)):
+ r = {
+ "player_name": player_names[i],
+ "agent_name": f"agent_{player_names[i]}",
+ "agent_description": roles[i],
+ "单词": words[word_idx][1 - under_cover_word_idx] if roles[i].startswith("平民") else words[word_idx][under_cover_word_idx]
+ }
+ result.append(r)
+
+ return result
+
+
+class GameActionTool(BaseToolModel):
+ name = "谁是卧底-游戏行动"
+ description = "谁是卧底的游戏行动工具,需要根据记忆的上下文信息,当前任务、以及你拿到的单词信息来进行回答响应。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ content: str
+
+ @classmethod
+ def run(cls, **kwargs) -> str:
+ """Execute your tool!"""
+
+ template = (
+ '##背景##\n'
+ '您正在参加“谁是卧底”这款游戏,您的目标是:想办法击杀与自己身份不同的所有玩家,获得胜利。\n'
+ '\n'
+ '##游戏介绍##\n'
+ '在“谁是卧底”游戏中,每位玩家会被分配一个[单词](玩家可见)和一个身份(玩家不可见,包括[平民]和[卧底]两种身份),卧底的[单词]跟[平民]不同,但有许多共同的特征。\n'
+ '\n'
+ '##任务##\n'
+ '1. 根据**游戏进展中主持人的最新通知**,感知当前的任务:讨论 or 票选卧底,准备发言。\n'
+ '2. 如果任务是讨论,感知分配给您的[单词],描述它的某一特征(**描述内容可真可假,禁止描述已经提到过的特征**),您的目标是:让其他玩家相信该特征与他们的[单词]是相符的;否则,投票给某个当前存活玩家,并说明理由,您的目标是:让其他玩家相信,该玩家给出的特征与大家的[单词]都不符。\n'
+ '\n'
+ '##发言示例##\n'
+ '(任务是讨论)一种植物,可食用。\n'
+ '(任务是票选卧底)我投票给李静,因为对比所有人的发言,他的描述和其他的有明显区别。\n'
+ '\n'
+ '##游戏进展##\n'
+ '{memory}\n'
+ '\n'
+ '##注意##\n'
+ '- 无论您的任务是什么,**禁止泄露自己的[单词],发言内容尽可能简洁!!!**。\n'
+ '- 如果您的任务是讨论,**描述的特征可真可假,但要避免已经提到过的特征**;如果是票选卧底,**一定要明确表示投票给哪一位玩家(禁止给自己或已经死亡的玩家投票)**。\n'
+ '- 禁止描述任何没有发生过的事情。\n'
+ '\n'
+ '##游戏经验##\n'
+ '如果任务是讨论,以下是描述[单词]特征时的一些经验:\n'
+ '1. 保持模糊性:特征不宜过于明显(尤其当您是首位发言的玩家时),这样很容易别人推测出自己的[单词]。\n'
+ '2. 逐渐清晰:与其他玩家给出的特征相比,您的特征应该更清晰,否则很容易被其他玩家怀疑。\n'
+ '3. 定位身份:如果您发现多个玩家的特征跟您的[单词]都不符,那么自己的身份很可能是[卧底],应该推测他们的[单词]是什么,**编造**跟他们[单词]相符的特征。\n'
+ '\n'
+ '##输出##\n'
+ 'Python可直接解析的jsonstr,格式如下:\n'
+ '{{\"thought\": 感知自己的名字、位置(根据主持人的【身份通知】!!!注意您不是人类玩家)、[单词]、当前任务、哪些特征已经被提出来、推测其他玩家的[单词]是什么、自己是否是[卧底]、如何保护自己,分析内容不超过120字, \"output\": 您的发言(避免泄露[单词],避免投票给自己,避免重复特征,直接说出符合您的身份的话,不要输出其他信息)}}\n'
+ '以{{开头,任何其他内容都是不允许的!\n'
+ )
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
+
+
+class AgentZhangweiTool(GameActionTool):
+ name = "谁是卧底-张伟"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是李静"
+ )
+
+class AgentLijingTool(GameActionTool):
+ name = "谁是卧底-李静"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是李静"
+ )
+
+
+class AgentWangpengTool(GameActionTool):
+ name = "谁是卧底-王鹏"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是李静"
+ )
+
+
+class GameEndCheckerTool(BaseToolModel):
+ name = "谁是卧底-胜利条件判断"
+ description = "谁是卧底的胜利条件判断工具,判断当前谁是卧底游戏是否结束。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ content: str
+
+ @classmethod
+ def run(cls, **kwargs) -> str:
+ """Execute your tool!"""
+
+ template = (
+ '##本局游戏历史记录##\n'
+ '{memory}\n\n'
+ '##背景##\n'
+ '你是一个逻辑判断大师,你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,现在需要判断当前游戏是否结束。\n\n'
+ '##任务##\n'
+ '你的任务是判断当前游戏是否结束,规则如下:\n'
+ '根据【重要信息】感知每一轮被投票死亡的玩家。 统计目前存活的[平民]玩家数量、[卧底]玩家数量。格式{{\"存活的卧底\":[player_name], \"存活的平民\":[player_name]}},判断以下条件中的一个是否满足:\n'
+ '1. \t卧底玩家全部已经死亡(即 存活[卧底]数量为0)。\n'
+ '2. 存活的[平民]数量与存活的[卧底]数量相等。\n'
+ '如果其中一个条件满足,则游戏结束;否则,游戏需要继续。\n\n'
+ '##输出##\n'
+ '返回jsonstr 格式。{{\"thought\": str, \"存活的玩家信息\": {{\"存活的卧底\":[player_name], \"存活的平民\":[player_name]}}, \"isEnd\": \"是\" or \"否\"}}\n'
+ '-thought **根据本局游戏历史记录** 分析 游戏最开始有哪些玩家, 他们的身份是什么, 投票导致死亡的玩家有哪些? 分析当前存活的玩家有哪些 ? 是否触发了游戏结束条件? 等等\n\n'
+ '##注意事项##\n'
+ '1. 所有玩家的座位、身份、agent_name、存活状态、游戏进展等信息在开头部分已给出。\n'
+ '2. \"是\" or \"否\" 如何选择?若游戏结束,则为\"是\",否则为\"否\"。\n'
+ '3. 请直接输出jsonstr,不用输出markdown格式。\n\n'
+ '4. 游戏可能进行了不只一轮,可能有1个或者2个玩家已经死亡,请注意感知\n'
+ '##结果##\n\n'
+ )
+
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
+
+class GameOutcomeCheckerTool(BaseToolModel):
+ name = "谁是卧底-结果输出"
+ description = "谁是卧底的结果输出工具,判断谁是卧底游戏中最终的胜利方是谁,并输出角色分配情况"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ content: str
+
+ @classmethod
+ def run(cls, **kwargs) -> str:
+ """Execute your tool!"""
+
+ template = (
+ '##本局游戏历史记录##\n'
+ '{memory}\n'
+ '\n'
+ '##背景##\n'
+ '您正在参与“谁是卧底”这个游戏,角色是[主持人]。现在游戏已经结束,您需要判断胜利的一方是谁。\n'
+ '\n'
+ '##任务##\n'
+ '统计目前存活的[平民]玩家数量、[卧底]玩家数量。判断以下条件中的哪一个满足:\n'
+ '1.[卧底]数量为0。\n'
+ '2.[平民]数量与[卧底]数量相等。\n'
+ '如果条件1满足,则[平民]胜利;如果条件2满足,则[卧底]胜利。\n'
+ '\n'
+ '##输出##\n'
+ 'Python可直接解析的jsonstr,格式如下:\n'
+ '{{\"原因是\": 获胜者为[平民]或[卧底]的原因, \"角色分配结果为\": 所有玩家的身份和单词(根据本局游戏历史记录), \"获胜方为\": \"平民\" or \"卧底\"}}\n'
+ '以{{开头,任何其他内容都是不允许的!\n'
+ '\n'
+ '##输出示例##\n'
+ '{{\"原因是\": \"卧底数量为0\", \"角色分配结果为\": \"李静:身份为卧底,单词为香蕉;人类玩家:身份为平民, 单词为梨子; 张伟:身份为平民, 单词为梨子; 王鹏:身份为平民, 单词为梨子。\", \"获胜方为\": \"平民\"}}\n'
+ '\n'
+ '##注意##\n'
+ '请输出所有玩家的角色分配结果,不要遗漏信息\n'
+ '\n'
+ '##结果##\n\n'
+ )
+
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
\ No newline at end of file
diff --git a/muagent/tools/werewolf.py b/muagent/tools/werewolf.py
new file mode 100644
index 0000000..6457f37
--- /dev/null
+++ b/muagent/tools/werewolf.py
@@ -0,0 +1,314 @@
+import os
+from typing import (
+ List,
+ Dict
+)
+from loguru import logger
+from pydantic import BaseModel, Field
+import random
+
+
+from .base_tool import BaseToolModel
+from ..models import get_model, ModelConfig
+
+
+
+class RoleAssignmentTool(BaseToolModel):
+ name = "狼人杀-角色分配工具"
+ description = "狼人杀的角色分配工具,可以为每一位玩家分配一个单词和人物角色。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ roles: list
+
+ @classmethod
+ def run(cls, **kwargs) -> ToolOutputArgs:
+ """Execute your tool!"""
+
+ players = [
+ ["朱丽", "agent_朱丽"],
+ ["周杰", "agent_周杰"],
+ ["沈强", "agent_沈强"],
+ ["韩刚", "agent_韩刚"],
+ ["梁军", "agent_梁军"],
+ ["周欣怡", "agent_周欣怡"],
+ ["贺子轩", "agent_贺子轩"],
+ ["人类玩家", "agent_人类玩家"]
+ ]
+ random.shuffle(players)
+ roles = ["平民_1", "平民_2", "平民_3", "狼人_1", "狼人_2", "狼人_3", "女巫", "预言家"]
+ random.shuffle(roles)
+
+ assigned_roles = []
+ for i in range(len(players)):
+ assigned_roles.append({
+ "player_name": players[i][0],
+ "agent_name": players[i][1],
+ "agent_description": roles[i]
+ })
+ return assigned_roles
+
+
+
+class PlayerSeatingTool(BaseToolModel):
+ name: str = "狼人杀-座位分配"
+ description: str = "狼人杀的座位分配工具,可以将玩家顺序打乱随机分配座位"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ seating_chart: str
+
+ @classmethod
+ def run(cls, **kwargs) -> ToolOutputArgs:
+ """Execute your tool!"""
+
+ players = [
+ ["朱丽", "agent_朱丽"],
+ ["周杰", "agent_周杰"],
+ ["沈强", "agent_沈强"],
+ ["韩刚", "agent_韩刚"],
+ ["梁军", "agent_梁军"],
+ ["周欣怡", "agent_周欣怡"],
+ ["贺子轩", "agent_贺子轩"],
+ ["人类玩家", "agent_人类玩家"]
+ ]
+ n = len(players)
+ random.shuffle(players)
+
+ seating_chart = "\n\n| 座位 | 玩家 |\n|---|---|\n"
+ seating_chart += "\n".join(f"| {i} | **{players[i-1][0]}** |" for i in range(1, n + 1))
+
+ return seating_chart
+
+
+class WerewolfGameInstructionTool(BaseToolModel):
+ name = "狼人杀-游戏指令"
+ description = "狼人杀的游戏指令工具,需要根据记忆的上下文信息,当前任务、以及你拿到的身份信息来进行回应。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ instruction: str
+
+ @classmethod
+ def run(cls, **kwargs) -> ToolOutputArgs:
+ """Execute your tool!"""
+ template = (
+ '##狼人杀游戏说明##\n'
+ '这个游戏基于文字交流, 以下是游戏规则:\n'
+ '角色:\n'
+ '主持人也是游戏的组织者,你需要正确回答他的指示。游戏中有五种角色:狼人、平民、预言家、女巫和猎人,三个狼人,一个预言家,一个女巫,一个猎人,两个平民。\n'
+ '好人阵营: 村民、预言家、猎人和女巫。\n'
+ '游戏阶段:游戏分为两个交替的阶段:白天和黑夜。\n'
+ '黑夜:\n'
+ '在黑夜阶段,你与主持人的交流内容是保密的。你无需担心其他玩家和主持人知道你说了什么或做了什么。\n'
+ '- 如果你是狼人,你需要和队友一起选择袭击杀死一个玩家\n'
+ '- 如果你是女巫,你有一瓶解药,可以拯救被狼人袭击的玩家,以及一瓶毒药,可以在黑夜后毒死一个玩家。解药和毒药只能使用一次。\n'
+ '- 如果你是预言家,你可以在每个晚上检查一个玩家是否是狼人,这非常重要。\n'
+ '- 如果你是猎人,当你在黑夜被狼人杀死时可以选择开枪杀死任意一名玩家。\n'
+ '- 如果你是村民,你在夜晚无法做任何事情。\n'
+ '白天:\n'
+ '你与存活所有玩家(包括敌人)讨论。讨论结束后,玩家投票来淘汰一个自己怀疑是狼人的玩家。获得最多票数的玩家将被淘汰。主持人将告诉谁被杀,否则将没有人被杀。\n'
+ '如果你是猎人,当你在白天被投票杀死之后可以选择开枪杀死任意一名玩家。\n'
+ '游戏目标:\n'
+ '狼人的目标是杀死所有的好人阵营中的玩家,并且不被好人阵营的玩家识别出狼人身份;\n'
+ '好人阵营的玩家,需要找出并杀死所有的狼人玩家。\n'
+ '##注意##\n'
+ '你正在参与狼人杀这个游戏,你应该感知自己的名字、座位号和角色。\n'
+ '1. 若你的角色为狼人,白天的发言应该尽可能隐藏身份。\n'
+ '2. 若你的角色属于好人阵营,白天的发言应该根据游戏进展尽可能分析出谁是狼人。\n'
+ '##以下为目前游戏进展##\n'
+ '{memory}\n'
+ '##发言格式##\n'
+ '你的回答中需要包含你的想法并给出简洁的理由,注意请有理有据,白天的发言尽量不要与别人的发言内容重复。发言的格式应该为Python可直接解析的jsonstr,格式如下:\n'
+ '{{\"thought\": 以“我是【座位号】号玩家【名字】【角色】”开头,根据主持人的通知感知自己的【名字】、【座位号】、【角色】,根据游戏进展和自己游戏角色的当前任务分析如何发言,字数不超过150字, \"output\": 您的发言应该符合目前游戏进展和自己角色的逻辑,白天投票环节不能投票给自己。}}\n'
+ '##开始发言##\n'
+ )
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
+
+
+class AgentZhuliTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_朱丽"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是朱丽"
+ )
+
+
+class AgentZhoujieTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_周杰"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是周杰"
+ )
+
+
+class AgentShenqiangTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_沈强"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是沈强"
+ )
+
+class AgentHangangTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_韩刚"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是韩刚"
+ )
+
+
+class AgentLiangjunTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_梁军"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是梁军"
+ )
+
+class AgentZhouxinyiTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_周欣怡"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是周欣怡"
+ )
+
+
+class AgentHezixuanTool(WerewolfGameInstructionTool):
+ name = "狼人杀-agent_贺子轩"
+ description = (
+ f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是贺子轩"
+ )
+
+
+class WerewolfGameEndCheckerTool(BaseToolModel):
+ name = "狼人杀-胜利条件判断"
+ description = "狼人杀的胜利条件判断工具,判断当前狼人杀游戏是否结束。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ thought: str
+ players: dict
+ isEnd: str
+
+ @classmethod
+ def run(cls, **kwargs) -> ToolOutputArgs:
+ """Execute your tool!"""
+ template = (
+ '##本局游戏历史记录##\n'
+ '{memory}\n'
+ '\n'
+ '##背景##\n'
+ '你是一个逻辑判断大师,你正在参与“狼人杀”这个游戏,你的角色是[主持人]。你熟悉“狼人杀”游戏的完整流程,现在需要判断当前游戏是否结束。\n'
+ '\n'
+ '##任务##\n'
+ '你的任务是判断当前游戏是否结束,规则如下:\n'
+ '根据【重要信息】感知每一轮被投票死亡、被狼人杀死、被女巫毒死、被猎人带走的玩家。 统计目前存活的[好人]玩家数量、[狼人]玩家数量。格式{{\"存活的好人\":[player_name], \"存活的狼人\":[player_name]}},判断以下条件中的一个是否满足:\n'
+ '1. 存活的“狼人”玩家数量为0。\n'
+ '2. “狼人”数量超过了“好人”数量。\n'
+ '3. “狼人”数量等于“好人”数量,“女巫”已死亡或者她的毒药已经使用。\n'
+ '若某个条件满足,游戏结束;否则游戏没有结束。\n'
+ '\n'
+ '##输出##\n'
+ '返回JSON格式,格式为:{{\"thought\": str, \"存活的玩家信息\": {{\"存活的好人\":[player_name], \"存活的狼人\":[player_name]}}, \"isEnd\": \"是\" or \"否\"}}\n'
+ '-thought **根据本局游戏历史记录** 分析 游戏最开始有哪些玩家, 他们的身份是什么, 投票导致死亡的玩家有哪些? 被狼人杀死的玩家有哪些? 被女巫毒死的玩家是谁? 被猎人带走的玩家是谁?分析当前存活的玩家有哪些? 是否触发了游戏结束条件? 等等。\n'
+ '\n'
+ '##example##\n'
+ '{{\"thought\": \"**游戏开始时** 有 小杭、小北、小赵、小钱、小孙、小李、小夏、小张 八位玩家, 其中 小杭、小北、小赵是[狼人], 小钱、小孙是[平民], 小李是[预言家],小夏是[女巫],小张是[猎人],小张在第一轮被狼人杀死了,猎人没有开枪,[狼人]数量大于[好人]数量,因此游戏未结束。\", \"存活的玩家信息\": {{\"存活的狼人\":[\"小杭\", \"小北\", \"小赵\"]}}, {{\"存活的好人\":[\"小钱\", \"小孙\", \"小李\", \"小夏\"]}}, \"isEnd\": \"否\" }}\n'
+ '##注意事项##\n'
+ '1. 所有玩家的座位、身份、agent_name、存活状态、游戏进展等信息在开头部分已给出。\n'
+ '2. \"是\" or \"否\" 如何选择?若游戏结束,则为\"是\",否则为\"否\"。\n'
+ '3. 请直接输出jsonstr,不用输出markdown格式。\n'
+ '4. 游戏可能进行了不只一轮,可能有1个或者2个玩家已经死亡,请注意感知。\n'
+ )
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
+
+
+
+class WerewolfGameOutcomeTool(BaseToolModel):
+ name = "狼人杀-结果输出"
+ description = "狼人杀的结果输出工具,判断狼人杀游戏中最终的胜利方是谁。"
+
+ class ToolInputArgs(BaseModel):
+ pass
+
+ class ToolOutputArgs(BaseModel):
+ reason: str
+ 角色分配结果为: str
+ 获胜方为: str
+
+ @classmethod
+ def run(cls, **kwargs) -> ToolOutputArgs:
+ """Execute your tool!"""
+ template = (
+ '##本局游戏历史记录##\n'
+ '{memory}\n'
+ '\n'
+ '##背景##\n'
+ '您正在参与“狼人杀”这个游戏,角色是[主持人]。现在游戏已经结束,您需要判断胜利的一方是谁。\n'
+ '\n'
+ '##任务##\n'
+ '统计目前存活的[好人]玩家数量、[狼人]玩家数量。判断以下条件中的哪一个满足:\n'
+ '1. 存活的“狼人”玩家数量为0。\n'
+ '2. “狼人”数量超过了“好人”数量。\n'
+ '3. “狼人”数量等于“好人”数量,“女巫”已死亡或者她的毒药已经使用。\n'
+ '如果条件1满足,则[好人]胜利;如果条件2或者条件3满足,则[狼人]胜利。\n'
+ '\n'
+ '##输出##\n'
+ 'Python可直接解析的jsonstr,格式如下:\n'
+ '{{\"原因是\": 获胜者为[好人]或[狼人]的原因, \"角色分配结果为\": 所有玩家的角色(根据本局游戏历史记录), \"获胜方为\": \"好人\" or \"狼人\"}}\n'
+ '以{{开头,任何其他内容都是不允许的!\n'
+ '\n'
+ '##输出示例##\n'
+ '{{\"原因是\": \"狼人数量为0\", \"角色分配结果为\": \"沈强:身份为狼人_1;周欣怡:身份为狼人_2;梁军:身份为狼人_3;贺子轩:身份为平民_1;人类玩家:身份为平民_2;朱丽:身份为预言家;韩刚:身份为女巫;周杰:身份为猎人。\", \"获胜方为\": \"好人\"}}\n'
+ '\n'
+ '##注意##\n'
+ '请输出所有玩家的角色分配结果,不要遗漏信息。\n'
+ '\n'
+ '##结果##\n'
+ '\n'
+ )
+
+ model_config = None
+ try:
+ model_config = ModelConfig(
+ config_name="codefuse_default",
+ model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+ model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+ api_key=os.environ.get("DEFAULT_API_KEY"),
+ api_url=os.environ.get("DEFAULT_API_URL"),
+ )
+ memory = kwargs.get("memory") or ""
+ model = get_model(model_config)
+ content = model.predict(template.format(memory=memory))
+ except Exception as e:
+ content = f"无法正确调用模型: {e}, {model_config}"
+ return content
\ No newline at end of file
diff --git a/muagent/utils/common_utils.py b/muagent/utils/common_utils.py
index 977420b..5939390 100644
--- a/muagent/utils/common_utils.py
+++ b/muagent/utils/common_utils.py
@@ -31,6 +31,13 @@ def addMinutesToTime(input_time: str, n: int = 5, dateformat=DATE_FORMAT):
new_time_after = dt + timedelta(minutes=n)
return new_time_before.strftime(dateformat), new_time_after.strftime(dateformat)
+def addMinutesToTimestamp(input_time: str, n: int = 5, dateformat=DATE_FORMAT):
+ dt = datetime.strptime(input_time, dateformat)
+
+    # shift the time N minutes before and after
+ new_time_before = dt - timedelta(minutes=n)
+ new_time_after = dt + timedelta(minutes=n)
+ return new_time_before.timestamp(), new_time_after.timestamp()
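+# Usage sketch (assuming DATE_FORMAT is a "%Y-%m-%d %H:%M:%S"-style format):
+#   addMinutesToTimestamp("2024-01-01 10:00:00", n=5)
+#   -> (unix timestamp of 09:55:00, unix timestamp of 10:05:00)
+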
def timestampToDateformat(ts, interval=1000, dateformat=DATE_FORMAT):
'''将标准时间戳转换标准指定时间格式'''
@@ -131,6 +138,44 @@ def double_hashing(s: str, modulus: int = 10e12) -> int:
return int((hash1 + hash2) % modulus)
+def _convert_to_str(content: Any) -> str:
+ """Convert the content to string.
+
+ The implementation of this _convert_to_str are borrowed from
+ https://github.com/modelscope/agentscope/blob/main/src/agentscope/utils/common.py
+
+ Note:
+ For prompt engineering, simply calling `str(content)` or
+ `json.dumps(content)` is not enough.
+
+ - For `str(content)`, if `content` is a dictionary, it will turn double
+ quotes to single quotes. When this string is fed into prompt, the LLMs
+ may learn to use single quotes instead of double quotes (which
+ cannot be loaded by `json.loads` API).
+
+ - For `json.dumps(content)`, if `content` is a string, it will add
+ double quotes to the string. LLMs may learn to use double quotes to
+ wrap strings, which leads to the same issue as `str(content)`.
+
+ To avoid these issues, we use this function to safely convert the
+ content to a string used in prompt.
+
+ Args:
+ content (`Any`):
+ The content to be converted.
+
+ Returns:
+ `str`: The converted string.
+ """
+
+ if isinstance(content, str):
+ return content
+ elif isinstance(content, (dict, list, int, float, bool, tuple)):
+ return json.dumps(content, ensure_ascii=False)
+ else:
+ return str(content)
+
+
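+# e.g. _convert_to_str({"a": 1}) -> '{"a": 1}' (valid JSON, double quotes kept),
+# while _convert_to_str("hi") -> 'hi' (no extra quotes added).
+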
@contextlib.contextmanager
def timer(seconds: Optional[Union[int, float]] = None) -> Generator:
"""
diff --git a/requirements.txt b/requirements.txt
index 115cf6a..47a6dfb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,7 +19,7 @@ Pyarrow
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
docker
-Levenshtein
+edit_distance
redis==5.0.1
pydantic<=1.10.14
aliyun-log-python-sdk==0.9.0
diff --git a/setup.py b/setup.py
index 1553d9c..7fb0ef2 100644
--- a/setup.py
+++ b/setup.py
@@ -35,8 +35,12 @@
"notebook",
"docker",
"sseclient",
- "Levenshtein",
+ "edit_distance",
"urllib3==1.26.6",
+ "ollama",
+ "colorama",
+ "pycryptodome",
+ "dashscope"
#
"chromadb==0.4.17",
"javalang==0.13.0",
diff --git a/tests/agents/funccall_agent_test.py b/tests/agents/funccall_agent_test.py
new file mode 100644
index 0000000..256e173
--- /dev/null
+++ b/tests/agents/funccall_agent_test.py
@@ -0,0 +1,97 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Message, Memory
+from muagent.agents import FunctioncallAgent
+from muagent import get_agent, get_project_config_from_env
+
+
+# log-level, print prompt and llm predict
+os.environ["log_verbose"] = "0"
+
+AGENT_CONFIGS = {
+ "codefuse_function_caller": {
+ "config_name": "codefuse_function_caller",
+ "agent_type": "FunctioncallAgent",
+ "agent_name": "codefuse_function_caller",
+ "llm_config_name": "qwener"
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+project_config = get_project_config_from_env()
+tools = ["KSigmaDetector", "MetricsQuery"]
+tools = [
+ "谁是卧底-座位分配", "谁是卧底-角色分配", "谁是卧底-结果输出", "谁是卧底-胜利条件判断",
+ "谁是卧底-张伟", "谁是卧底-李静", "谁是卧底-王鹏",
+]
+
+# tools = [
+# "狼人杀-角色分配工具", "狼人杀-座位分配", "狼人杀-胜利条件判断", "狼人杀-结果输出",
+# '狼人杀-agent_朱丽', '狼人杀-agent_周杰', '狼人杀-agent_沈强', '狼人杀-agent_韩刚',
+# '狼人杀-agent_梁军', '狼人杀-agent_周欣怡', '狼人杀-agent_贺子轩'
+# ]
+
+agent = FunctioncallAgent(
+ agent_name="codefuse_function_caller",
+ project_config=project_config,
+ tools=tools
+)
+
+
+memory_content = "[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.89, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.89, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.89, 26.789, 28.901]"
+memory = Memory(
+ messages=[Message(
+ role_type="observation",
+ content=memory_content
+ )]
+)
+query_content = "帮我查询下127.0.0.1这个服务器的在10点的数据"
+query_content = "帮我判断这个数据是否异常"
+query_content = "开始分配座位"
+query_content = "开始分配身份"
+query_content = "游戏是否结束"
+query_content = "游戏的胜利玩家是谁"
+
+memory_content = "3号玩家说今天天气很好"
+memory = Memory(
+ messages=[Message(
+ role_type="observation",
+ content=memory_content
+ )]
+)
+query_content = "我要使用工具,工具描述为agent_张伟"
+# query_content = "我要使用工具,工具描述为'agent_周杰'"
+
+
+query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+)
+# agent.pre_print(query)
+# output_message = agent.step(query, memory=memory)
+output_message = agent.step(query, extra_params={"memory": memory_content})
+print("### intput ###\n", output_message.input_text)
+print("### content ###\n", output_message.content)
+print("### observation ###\n", output_message.parsed_contents[-1]["Observation"])
+print("### step content ###\n", output_message.step_content)
\ No newline at end of file
diff --git a/tests/agents/group_agent_test.py b/tests/agents/group_agent_test.py
new file mode 100644
index 0000000..182aa9c
--- /dev/null
+++ b/tests/agents/group_agent_test.py
@@ -0,0 +1,71 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.tools import TOOL_SETS
+from muagent.schemas import Message
+from muagent.agents import BaseAgent
+from muagent.project_manager import get_project_config_from_env
+
+
+tools = list(TOOL_SETS)
+tools = ["KSigmaDetector", "MetricsQuery"]
+role_prompt = "you are a helpful assistant!"
+
+AGENT_CONFIGS = {
+ "grouper": {
+ "agent_type": "GroupAgent",
+ "agent_name": "grouper",
+ "agents": ["codefuse_reacter_1", "codefuse_reacter_2"]
+ },
+ "codefuse_reacter_1": {
+ "agent_type": "ReactAgent",
+ "agent_name": "codefuse_reacter_1",
+ "tools": tools,
+ },
+ "codefuse_reacter_2": {
+ "agent_type": "ReactAgent",
+ "agent_name": "codefuse_reacter_2",
+ "tools": tools,
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+# log-level, print prompt and llm predict
+os.environ["log_verbose"] = "0"
+
+#
+project_config = get_project_config_from_env()
+agent = BaseAgent.init_from_project_config(
+ "grouper", project_config
+)
+
+query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+)
+# agent.pre_print(query)
+output_message = agent.step(query)
+print("input:", output_message.input_text)
+print("content:", output_message.content)
+print("step_content:", output_message.step_content)
\ No newline at end of file
diff --git a/tests/agents/react_agent_test.py b/tests/agents/react_agent_test.py
new file mode 100644
index 0000000..3703812
--- /dev/null
+++ b/tests/agents/react_agent_test.py
@@ -0,0 +1,62 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.tools import TOOL_SETS
+from muagent.schemas import Message
+from muagent.agents import BaseAgent
+from muagent import get_project_config_from_env
+
+# log-level, print prompt and llm predict
+os.environ["log_verbose"] = "0"
+
+tools = list(TOOL_SETS)
+tools = ["KSigmaDetector", "MetricsQuery"]
+role_prompt = "you are a helpful assistant!"
+
+AGENT_CONFIGS = {
+ "reacter": {
+ "system_prompt": role_prompt,
+ "agent_type": "ReactAgent",
+ "agent_name": "reacter",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+#
+project_config = get_project_config_from_env()
+agent = BaseAgent.init_from_project_config(
+ "reacter", project_config
+)
+
+query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+)
+# agent.pre_print(query)
+output_message = agent.step(query)
+print("### intput ###\n", output_message.input_text)
+print("### content ###\n", output_message.content)
+print("### step content ###\n", output_message.step_content)
\ No newline at end of file
diff --git a/tests/agents/single_agent_test.py b/tests/agents/single_agent_test.py
new file mode 100644
index 0000000..ededa33
--- /dev/null
+++ b/tests/agents/single_agent_test.py
@@ -0,0 +1,119 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
+
+from muagent.schemas import Message
+from muagent.models import ModelConfig
+from muagent.agents import SingleAgent, BaseAgent
+from muagent import get_project_config_from_env
+
+
+
+# log-level, print prompt and llm predict
+os.environ["log_verbose"] = "0"
+
+role_prompt = "you are a helpful assistant!"
+role_prompt = """#### AGENT PROFILE
+you are a helpful assistant!
+
+#### RESPONSE OUTPUT FORMAT
+**Action Status:** Set to 'stopped' or 'code_executing'.
+If it's 'stopped', the action is to provide the final answer to the session records and executed steps.
+If it's 'code_executing', the action is to write the code.
+
+**Action:**
+```python
+# Write your code here
+...
+```
+"""
+
+role_prompt = """#### AGENT PROFILE
+you are a helpful assistant!
+
+#### RESPONSE OUTPUT FORMAT
+**Action Status:** Set to either 'stopped' or 'tool_using'. If 'stopped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
+
+**Action:** Use the tools by formatting the tool action in JSON. The format should be:
+
+```json
+{
+ "tool_name": "$TOOL_NAME",
+ "tool_params": "$INPUT"
+}
+```
+"""
+
+tools = list(TOOL_SETS)
+tools = ["KSigmaDetector", "MetricsQuery"]
+
+
+AGENT_CONFIGS = {
+ "codefuse_simpler": {
+ "agent_type": "SingleAgent",
+ "agent_name": "codefuse_simpler",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+
+project_config = get_project_config_from_env()
+agent = BaseAgent.init_from_project_config(
+ "codefuse_simpler", project_config
+)
+# base_agent = SingleAgent(
+# system_prompt=role_prompt,
+# project_config=project_config,
+# tools=tools
+# )
+
+
+question = "用python画一个爱心"
+query = Message(
+ session_index="agent_test",
+ role_type="user",
+ role_name="user",
+ content=question,
+)
+
+# base_agent.pre_print(query)
+# output_message = base_agent.step(query)
+# print(output_message.input_text)
+# print(output_message.content)
+
+
+
+
+query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+query = Message(
+ role_name="human",
+ role_type="user",
+ input_text=query_content,
+)
+# base_agent.pre_print(query)
+output_message = agent.step(query)
+print("### intput ###\n", output_message.input_text)
+print("### content ###\n", output_message.content)
+print("### step content ###\n", output_message.step_content)
\ No newline at end of file
diff --git a/tests/agents/task_agent_test.py b/tests/agents/task_agent_test.py
new file mode 100644
index 0000000..c446ed2
--- /dev/null
+++ b/tests/agents/task_agent_test.py
@@ -0,0 +1,66 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
+
+from muagent.schemas import Message
+from muagent.models import ModelConfig
+from muagent.agents import BaseAgent
+from muagent import get_project_config_from_env
+
+
+# log-level, print prompt and llm predict
+os.environ["log_verbose"] = "0"
+
+
+tools = list(TOOL_SETS)
+tools = ["KSigmaDetector", "MetricsQuery"]
+role_prompt = "you are a helpful assistant!"
+
+AGENT_CONFIGS = {
+ "tasker": {
+ "system_prompt": role_prompt,
+ "agent_type": "TaskAgent",
+ "agent_name": "tasker",
+ "tools": tools,
+ "llm_config_name": "qwen_chat"
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+#
+project_config = get_project_config_from_env()
+agent = BaseAgent.init_from_project_config(
+ "tasker", project_config
+)
+
+query_content = "先帮我获取下127.0.0.1这个服务器在10点的数,然后在帮我判断下数据是否存在异常"
+query = Message(
+ role_name="human",
+ role_type="user",
+ content=query_content,
+)
+# agent.pre_print(query)
+output_message = agent.step(query)
+print("### intput ###\n", output_message.input_text)
+print("### content ###\n", output_message.content)
+print("### step content ###\n", output_message.step_content)
\ No newline at end of file
diff --git a/tests/llm_models/embedding_test.py b/tests/llm_models/embedding_test.py
new file mode 100644
index 0000000..3e3a284
--- /dev/null
+++ b/tests/llm_models/embedding_test.py
@@ -0,0 +1,48 @@
+from loguru import logger
+import os, sys
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.append(src_dir)
+try:
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.models import get_model
+from muagent.schemas.models import ModelConfig
+import json
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+
+for model_type in model_configs.keys():
+ if "_embedding" not in model_type: continue
+ model_config = model_configs[model_type]
+ embed_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+ )
+
+ model = get_model(embed_config)
+
+
+ print(model_type, model_config["model_name"], len(model.embed_query("hello")))
\ No newline at end of file
diff --git a/tests/llm_models/model_test.py b/tests/llm_models/model_test.py
new file mode 100644
index 0000000..f2e4471
--- /dev/null
+++ b/tests/llm_models/model_test.py
@@ -0,0 +1,93 @@
+from loguru import logger
+import os, sys
+import json
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.append(src_dir)
+try:
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent.models import get_model
+from muagent.schemas.models import ModelConfig
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
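+# MODEL_CONFIGS is expected to be a JSON mapping like (illustrative):
+#   {"qwen_chat": {"model_name": "...", "api_key": "..."},
+#    "ollama_chat": {"model_name": "...", "api_key": "..."}}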
+
+# "openai_chat","yi_chat","qwen_chat", "dashscope_chat""moonshot_chat", "ollama_chat"
+
+model_type = "ollama_chat"
+model_config = model_configs[model_type]
+
+model_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+)
+model = get_model(model_config)
+
+# tool definitions for function calling
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "strict": True,
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {"type": "string"},
+ "unit": {"type": "string", "enum": ["c", "f"]},
+ },
+ "required": ["location", "unit"],
+ "additionalProperties": False,
+ },
+ },
+ }
+]
+
+
+# print(model.generate("输出 '今天你好'", stop="你", format_type='str'))
+for i in model.generate_stream("hello", stop="你", format_type='str'):
+ print(i)
+
+# #
+# print(model.generate("hello", format_type='str'))
+
+# #
+# for i in model.generate_stream("hello", format_type='str'):
+# print(i)
+
+# #
+# print(model.chat([{"role": "user", "content":"hello"}], format_type='str'))
+
+# #
+# for i in model.chat_stream([{"role": "user", "content":"hello"}], format_type='str'):
+# print(i)
+
+# #
+# print(model.function_call(tools=tools, prompt="我想查北京的天气"))
+
+# #
+# for i in model.function_call_stream(tools=tools, messages=[{"role": "user", "content":"我想查北京的天气"}]):
+# print(i)
\ No newline at end of file
diff --git a/tests/memory_manager/local_memory_manager_test.py b/tests/memory_manager/local_memory_manager_test.py
new file mode 100644
index 0000000..2c5ac16
--- /dev/null
+++ b/tests/memory_manager/local_memory_manager_test.py
@@ -0,0 +1,161 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ model_engine = os.environ["model_engine"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ model_engine = os.environ["model_engine"]
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.utils.common_utils import getCurrentDatetime
+from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from muagent.schemas import Message
+from muagent.models import ModelConfig, get_model
+
+from muagent.memory_manager import LocalMemoryManager
+
+
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+
+#
+# llm_config = LLMConfig(
+# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3,
+# )
+model_type = "qwen_chat"
+model_config = model_configs[model_type]
+model_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+)
+
+
+# embed_config = EmbedConfig(
+# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path
+# )
+model_type = "qwen_text_embedding"
+embed_config = model_configs[model_type]
+embed_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=embed_config["model_name"],
+ api_key=embed_config["api_key"],
+)
+
+# prepare your message
+message1 = Message(
+ session_index="default",
+ role_name="test1",
+ role_type="user",
+ content="hello",
+ spec_parsed_contents=[{"input": "hello"}],
+)
+
+text = "hi! how can I help you?"
+message2 = Message(
+ session_index="shuimo",
+ role_name="test2",
+ role_type="assistant",
+ content=text,
+    spec_parsed_contents=[{"answer": text}],
+)
+
+text = "they say hello and hi to each other"
+message3 = Message(
+ session_index="shanshi",
+ role_name="test3",
+ role_type="summary",
+ content=text,
+ spec_parsed_contents=[{"summary": text}],
+)
+
+vb_config = VBConfig(vb_type="LocalFaissHandler")
+
+# append or extend test
+print("###"*10 + "append or extend" + "###"*10)
+local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True)
+# append can ignore user_name
+local_memory_manager.append(message=message1)
+local_memory_manager.append(message=message2)
+local_memory_manager.append(message=message3)
+
+# test init_local
+print("###"*10 + "dont load local" + "###"*10)
+local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True)
+print(local_memory_manager.get_memory_pool("default").to_format_messages(
+ content_key="content", format_type='str'))
+print(local_memory_manager.get_memory_pool("shuimo").to_format_messages(
+ content_key="content", format_type='str'))
+print(local_memory_manager.get_memory_pool("shanshi").to_format_messages(
+ content_key="content", format_type='str'))
+
+# test load from local
+print("###"*10 + "load local" + "###"*10)
+local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=False)
+print(local_memory_manager.get_memory_pool("default").to_format_messages(
+ content_key="content", format_type='str'))
+print(local_memory_manager.get_memory_pool("shuimo").to_format_messages(
+ content_key="content", format_type='str'))
+print(local_memory_manager.get_memory_pool("shanshi").to_format_messages(
+ content_key="content", format_type='str'))
+
+
+
+local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=False)
+# embedding retrieval test
+print("###"*10 + "retrieval" + "###"*10)
+text = "say hi to each other,"
+# retrieval_type=datetime => retrieval from datetime and jieba
+print(local_memory_manager.router_retrieval(
+ session_index="shanshi", text=text, datetime=getCurrentDatetime(),
+ n=4, top_k=5, retrieval_type= "datetime"))
+# retrieval_type=embedding => retrieval from embedding
+print(local_memory_manager.router_retrieval(
+ session_index="shanshi", text=text, top_k=5, retrieval_type= "embedding"))
+# retrieval_type=text => retrieval from jieba
+print(local_memory_manager.router_retrieval(
+ session_index="shanshi", text=text, top_k=5, retrieval_type= "text"))
+
+# # recursive_summary test
+print("###"*10 + "recursive_summary" + "###"*10)
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shanshi").messages, split_n=1, session_index="shanshi"))
+
+# print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shuimo").messages, split_n=1, session_index="shanshi"))
+
+# print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("default").messages, split_n=1, session_index="shanshi"))
+
+
+# test after clear local vs and jsonl
+print("###"*10 + "test after clear local vs and jsonl" + "###"*10)
+local_memory_manager.clear_local(re_init=True)
+print(local_memory_manager.get_memory_pool("shanshi").to_format_messages(
+ content_key="content", format_type='str'))
\ No newline at end of file
diff --git a/tests/memory_manager/local_mm_crud_test.py b/tests/memory_manager/local_mm_crud_test.py
new file mode 100644
index 0000000..d5211cb
--- /dev/null
+++ b/tests/memory_manager/local_mm_crud_test.py
@@ -0,0 +1,117 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ model_engine = os.environ["model_engine"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ model_engine = os.environ["model_engine"]
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.utils.common_utils import getCurrentDatetime
+from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from muagent.schemas import Message
+from muagent.models import ModelConfig, get_model
+
+from muagent.memory_manager import LocalMemoryManager, TbaseMemoryManager
+
+
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+
+#
+# llm_config = LLMConfig(
+# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3,
+# )
+model_type = "qwen_chat"
+model_config = model_configs[model_type]
+model_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+)
+
+
+# embed_config = EmbedConfig(
+# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path
+# )
+model_type = "qwen_text_embedding"
+embed_config = model_configs[model_type]
+embed_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=embed_config["model_name"],
+ api_key=embed_config["api_key"],
+)
+
+
+# Initialize the TbaseHandler instance
+tb_config = TBConfig(
+ tb_type="TbaseHandler",
+ index_name="muagent_test",
+ host="127.0.0.1",
+ port=os.environ['tb_port'],
+ username=os.environ['tb_username'],
+ password=os.environ['tb_password'],
+)
+
+vb_config = VBConfig(vb_type="LocalFaissHandler")
+
+# append or extend test
+# memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True)
+memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config)
+
+
+# prepare your message
+message1 = Message(
+ session_index="default",
+ message_index="default",
+ role_name="crud_test",
+ role_type="user",
+ content="hello",
+ role_tags=["shanshi"]
+)
+
+# append can ignore user_name
+memory_manager.append(message=message1)
+print(memory_manager.get_memory_pool("default").to_format_messages(format_type="raw"))
+
+# prepare your message
+message2 = Message(
+ session_index="default",
+ message_index="default",
+ role_name="crud_test",
+ role_type="user",
+ content="hello",
+ role_tags=["test"]
+)
+
+memory_manager.append(message=message2, role_tag="test")
+print(memory_manager.get_memory_pool("default").to_format_messages(format_type="raw"))
\ No newline at end of file
diff --git a/tests/memory_manager/tbase_memory_manager_test.py b/tests/memory_manager/tbase_memory_manager_test.py
new file mode 100644
index 0000000..f899cac
--- /dev/null
+++ b/tests/memory_manager/tbase_memory_manager_test.py
@@ -0,0 +1,153 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ model_engine = os.environ["model_engine"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ model_engine = os.environ["model_engine"]
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from muagent.schemas import Message
+from muagent.models import ModelConfig, get_model
+
+from muagent.memory_manager import TbaseMemoryManager
+from muagent.utils.common_utils import getCurrentDatetime
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+
+#
+# llm_config = LLMConfig(
+# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3,
+# )
+model_type = "qwen_chat"
+model_config = model_configs[model_type]
+model_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+)
+
+
+# embed_config = EmbedConfig(
+# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path
+# )
+model_type = "qwen_text_embedding"
+embed_config = model_configs[model_type]
+embed_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=embed_config["model_name"],
+ api_key=embed_config["api_key"],
+)
+
+
+
+# Initialize the TbaseHandler instance
+tb_config = TBConfig(
+ tb_type="TbaseHandler",
+ index_name="muagent_test",
+ host="127.0.0.1",
+ port=os.environ['tb_port'],
+ username=os.environ['tb_username'],
+ password=os.environ['tb_password'],
+)
+
+# prepare your message
+message1 = Message(
+ session_index="default",
+ role_name="test1",
+ role_type="user",
+ content="hello",
+ spec_parsed_contents=[{"input": "hello"}],
+)
+
+text = "hi! how can I help you?"
+message2 = Message(
+ session_index="shuimo",
+ role_name="test2",
+ role_type="assistant",
+ content=text,
+    spec_parsed_contents=[{"answer": text}],
+)
+
+text = "they say hello and hi to each other"
+message3 = Message(
+ session_index="shanshi",
+ role_name="test3",
+ role_type="summary",
+ content=text,
+ spec_parsed_contents=[{"summary": text}],
+)
+
+
+# # append or extend test
+# print("###"*10 + "append or extend" + "###"*10)
+# local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=True)
+# # append can ignore user_name
+# local_memory_manager.append(message=message1)
+# local_memory_manager.append(message=message2)
+# local_memory_manager.append(message=message3)
+
+
+# # test load from local
+# print("###"*10 + "load local" + "###"*10)
+# local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=False)
+# print(local_memory_manager.get_memory_pool("default").to_format_messages(
+# content_key="content", format_type='str'))
+# print(local_memory_manager.get_memory_pool("shuimo").to_format_messages(
+# content_key="content", format_type='str'))
+# print(local_memory_manager.get_memory_pool("shanshi").to_format_messages(
+# content_key="content", format_type='str'))
+
+
+# embedding retrieval test
+print("###"*10 + "retrieval" + "###"*10)
+local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=False)
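+# The router_retrieval calls below are left commented out for manual runs;
+# only the recursive_summary calls at the bottom are exercised by this script.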
+# text = "say hi to each other,"
+# # retrieval_type=datetime => retrieval from datetime and jieba
+# print(local_memory_manager.router_retrieval(
+# session_index="shanshi", text=text, datetime=getCurrentDatetime(),
+# n=30, top_k=5, retrieval_type= "datetime"))
+# # retrieval_type=embedding => retrieval from embedding
+# print(local_memory_manager.router_retrieval(
+# session_index="shanshi", text=text, top_k=5, retrieval_type= "embedding"))
+# # retrieval_type=text => retrieval from jieba
+# print(local_memory_manager.router_retrieval(
+# session_index="shanshi", text=text, top_k=5, retrieval_type= "text"))
+
+# recursive_summary test
+print("###"*10 + "recursive_summary" + "###"*10)
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shanshi").messages, split_n=1, session_index="shanshi"))
+
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shuimo").messages, split_n=1, session_index="shanshi"))
+
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("default").messages, split_n=1, session_index="shanshi"))
diff --git a/tests/orm/table_test.py b/tests/orm/table_test.py
index da9980a..ce7b1ca 100644
--- a/tests/orm/table_test.py
+++ b/tests/orm/table_test.py
@@ -2,7 +2,8 @@
from loguru import logger
os.environ["do_create_dir"] = "1"
-from muagent.orm import create_tables
+# from muagent.orm import create_tables
+from muagent.db_handler import create_tables
# use to test, don't create some directory
diff --git a/tests/prompt_manager/base_test.py b/tests/prompt_manager/base_test.py
new file mode 100644
index 0000000..b828e91
--- /dev/null
+++ b/tests/prompt_manager/base_test.py
@@ -0,0 +1,116 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Message, Memory
+from muagent.prompt_manager import CommonPromptManager
+
+
+system_prompt = """#### Agent Profile
+As an agent specializing in software quality assurance,
+your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
+This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
+Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
+Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
+
+ATTENTION: respond carefully, following the "Response Output Format" section exactly.
+
+Each test case should include:
+1. A clear description of the test purpose.
+2. The input values or conditions for the test.
+3. The expected outcome or assertion for the test.
+4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
+5. The test code should include the package declaration and the necessary imports.
+
+#### Input Format
+
+**Code Snippet:** the initial code or the objective that the user wants to achieve
+
+**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
+
+#### Response Output Format
+**SaveFileName:** construct a local file name based on Question and Context, such as
+
+```java
+package/class.java
+```
+
+**Test Code:** generate the test code for the current Code Snippet.
+```java
+...
+```
+
+"""
+
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+bpm = CommonPromptManager(
+ system_prompt=system_prompt,
+    input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+)
+
+#
+message1 = Message(
+ role_name="test",
+ role_type="user",
+ content="hello"
+)
+message2 = Message(
+ role_name="test",
+ role_type="assistant",
+ content="hi! can i help you!"
+)
+query = Message(
+ role_name="test",
+ role_type="user",
+ input_text="i want to know the weather of beijing",
+ content="i want to know the weather of beijing",
+ spec_parsed_content={
+ "Retrieval Code Snippets": "hi"
+ },
+ global_kwargs={
+ "Code Snippet": "hello",
+ "Test Code": "nice to meet you."
+ }
+)
+memory = Memory(messages=[message1, message2])
+
+# prompt = bpm.pre_print(
+# query=query, memory=memory, tools=tools,
+# agent_names=agent_names, agent_descs=agent_descs
+# )
+# print(prompt)
+
+prompt = bpm.generate_prompt(
+ query=query, memory=memory, tools=tools,
+ agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
\ No newline at end of file
diff --git a/tests/prompt_manager/extend_common_pm_test.py b/tests/prompt_manager/extend_common_pm_test.py
new file mode 100644
index 0000000..850f07b
--- /dev/null
+++ b/tests/prompt_manager/extend_common_pm_test.py
@@ -0,0 +1,161 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Memory, Message, PromptConfig
+from muagent.prompt_manager import CommonPromptManager
+
+
+from typing import (
+ List,
+ Any,
+ Union,
+ Optional,
+ Dict,
+ Literal
+)
+from pydantic import BaseModel
+
+class NewPromptManager(CommonPromptManager):
+
+ pm_type: str = "NewPromptManager"
+ """The type of prompt manager."""
+
+ def __init__(
+ self,
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ language: Literal["en", "zh"] = "en",
+ *,
+ monitored_agents=[],
+ monitored_fields=[],
+ **kwargs
+ ):
+ # update new titles
+ extra_registry_titles: Dict = {
+ "EXAMPLE": {
+ "description": "这里是一些实例以供参考。",
+ "function": "",
+ "display_type": "title"
+ },
+ "INPUT EXAMPLE": {
+ "description": "this input example",
+ "prompt": "",
+ "function": "",
+ "display_type": "description"
+ },
+ "OUTPUT EXAMPLE": {
+ "description": "this output example",
+ "prompt": "",
+ "function": "handle_empty_key",
+ "display_type": "description"
+ },
+ }
+ #
+ extra_register_edges: List = [
+ ("EXAMPLE", "INPUT EXAMPLE"),
+ ("EXAMPLE", "OUTPUT EXAMPLE"),
+ ]
+
+ #
+ new_dfsindex_to_str_format: Dict = {
+ 0: "#### {}\n{}",
+ 1: "### {}\n{}",
+ 2: "## {}\n{}",
+ 3: "# {}\n{}",
+ }
+ """use {title name} {description/function_value}"""
+
+ super().__init__(
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ language=language,
+ extra_registry_titles=extra_registry_titles,
+ extra_register_edges=extra_register_edges,
+ new_dfsindex_to_str_format=new_dfsindex_to_str_format,
+ monitored_agents=monitored_agents,
+ monitored_fields=monitored_fields,
+ **kwargs
+ )
+
+
+system_prompt = "you are a helpful assistant!\n"
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+
+bpm = NewPromptManager(
+ # system_prompt=system_prompt,
+    # input_template=input_template,
+ # output_template=output_template,
+ # prompt=prompt,
+ language="en",
+)
+
+
+#
+message1 = Message(
+ role_name="test",
+ role_type="user",
+ content="hello"
+)
+message2 = Message(
+ role_name="test",
+ role_type="assistant",
+ content="hi! can i help you!"
+)
+query = Message(
+ role_name="test",
+ role_type="user",
+ input_text="i want to know the weather of beijing",
+ content="i want to know the weather of beijing",
+ spec_parsed_content={
+ "Retrieval Code Snippets": "hi"
+ },
+ global_kwargs={
+ "Code Snippet": "hello",
+ "Test Code": "nice to meet you."
+ }
+)
+memory = Memory(messages=[message1, message2])
+
+# prompt = bpm.pre_print(
+# query=query, memory=memory, tools=tools,
+# agent_names=agent_names, agent_descs=agent_descs
+# )
+# print(prompt)
+
+prompt = bpm.generate_prompt(
+ query=query, memory=memory, tools=tools,
+ agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
\ No newline at end of file
diff --git a/tests/prompt_manager/new_pm_test.py b/tests/prompt_manager/new_pm_test.py
new file mode 100644
index 0000000..6735fa1
--- /dev/null
+++ b/tests/prompt_manager/new_pm_test.py
@@ -0,0 +1,233 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Memory, Message, PromptConfig
+from muagent.prompt_manager import BasePromptManager
+from muagent.agents import BaseAgent
+
+
+from typing import (
+ List,
+ Any,
+ Union,
+ Optional,
+ Dict,
+ Literal
+)
+from pydantic import BaseModel
+
+class NewPromptManager(BasePromptManager):
+
+ pm_type: str = "NewPromptManager"
+ """The type of prompt manager."""
+
+ def __init__(
+ self,
+ system_prompt: str = "you are a helpful assistant!\n",
+ input_template: Union[str, BaseModel] = "",
+ output_template: Union[str, BaseModel] = "",
+ prompt: Optional[str] = None,
+ language: Literal["en", "zh"] = "en",
+ *,
+ monitored_agents=[],
+ monitored_fields=[],
+ **kwargs
+ ):
+ super().__init__(
+ system_prompt=system_prompt,
+ input_template=input_template,
+ output_template=output_template,
+ prompt=prompt,
+ language=language,
+ monitored_agents=monitored_agents,
+ monitored_fields=monitored_fields,
+ **kwargs
+ )
+ # update new titles
+ self.extra_registry_titles: Dict = {
+ "AGENT PROFILE": {
+ "description": "",
+ "function": "handle_agent_profile",
+ "display_type": "title"
+ },
+ "TOOL INFORMATION": {
+ "description": "",
+ "prompt": (
+ 'Below is a list of tools that are available for your use:{formatted_tools}'
+ '\nvalid "tool_name" value is:\n{tool_names}'
+ ),
+ "function": "handle_tool_data",
+ "display_type": "description",
+ "str_template": "**{}\n{}"
+ },
+ "AGENT INFORMATION": {
+ "description": "",
+ "prompt": (
+ 'Please ensure your selection is one of the listed roles. Available roles for selection:\n{agents}'
+                    '\nPlease ensure the Role you select is one of the agent names, such as {agent_names}'
+ ),
+ "function": "handle_agent_data",
+ "display_type": "description"
+ },
+ }
+ #
+ self.extra_register_edges: List = [
+ ("AGENT PROFILE", "AGENT INFORMATION"),
+ ("AGENT PROFILE", "TOOL INFORMATION"),
+ ]
+
+ #
+
+ self.new_dfsindex_to_str_format: Dict = {
+ 0: "#### {}\n{}",
+ 1: "### {}\n{}",
+ 2: "## {}\n{}",
+ 3: "# {}\n{}",
+ }
+ """use {title name} {description/function_value}"""
+ #
+ self.register_graph({}, [], {}, {})
+
+ def register_prompt(self):
+ """register input/output/prompt into titles and edges"""
+ pass
+
+ def handle_agent_profile(self, **kwargs) -> str:
+ return self.system_prompt
+
+ def handle_tool_data(self, **kwargs):
+ import random
+ from textwrap import dedent
+ from muagent.tools import get_tool, BaseToolModel
+
+ if 'tools' not in kwargs: return ""
+
+ tools: List = kwargs.get('tools')
+ prompt: str = kwargs.get('prompt')
+ tools: List[BaseToolModel] = [get_tool(tool) for tool in tools if isinstance(tool, str)]
+
+ if len(tools) == 0: return ""
+
+ tool_strings = []
+ for tool in tools:
+ args_str = f'args: {str(tool.intput_to_json_schema())}' if tool.ToolInputArgs else ""
+ tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
+ formatted_tools = "\n".join(tool_strings)
+
+ tool_names = ", ".join([tool.name for tool in tools])
+
+ tool_prompt = dedent(prompt.format(formatted_tools=formatted_tools, tool_names=tool_names))
+ while "\n " in tool_prompt:
+ tool_prompt = tool_prompt.replace("\n ", "\n")
+
+ return tool_prompt
+
+ def handle_agent_data(self, **kwargs):
+ """"""
+ import random
+ from textwrap import dedent
+ if 'agent_names' not in kwargs or "agent_descs" not in kwargs:
+ return ""
+
+ agent_names: List = kwargs.get('agent_names')
+ agent_descs: List = kwargs.get('agent_descs')
+ prompt: str = kwargs.get('prompt')
+
+ if len(agent_names) == 0: return ""
+
+        # shuffle the name/description pairs together so they stay aligned
+        paired = list(zip(agent_names, agent_descs))
+        random.shuffle(paired)
+        agent_descriptions = []
+        for agent_name, desc in paired:
+ while "\n\n" in desc:
+ desc = desc.replace("\n\n", "\n")
+ desc = desc.replace("\n", ",")
+ agent_descriptions.append(
+ f'"role name: {agent_name}\nrole description: {desc}"'
+ )
+
+ agent_description = "\n".join(agent_descriptions)
+ agent_prompt = dedent(
+ prompt.format(agents=agent_description, agent_names=agent_names)
+ )
+
+ while "\n " in agent_prompt:
+ agent_prompt = agent_prompt.replace("\n ", "\n")
+
+ return agent_prompt
+
+system_prompt = "you are a helpful assistant!\n"
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+
+bpm = NewPromptManager(
+ # system_prompt=system_prompt,
+    # input_template=input_template,
+ # output_template=output_template,
+ # prompt=prompt,
+ language="zh",
+)
+
+#
+message1 = Message(
+ role_name="test",
+ role_type="user",
+ content="hello"
+)
+message2 = Message(
+ role_name="test",
+ role_type="assistant",
+ content="hi! can i help you!"
+)
+query = Message(
+ role_name="test",
+ role_type="user",
+ input_text="i want to know the weather of beijing",
+ content="i want to know the weather of beijing",
+ spec_parsed_content={
+ "Retrieval Code Snippets": "hi"
+ },
+ global_kwargs={
+ "Code Snippet": "hello",
+ "Test Code": "nice to meet you."
+ }
+)
+memory = Memory(messages=[message1, message2])
+
+prompt = bpm.pre_print(
+ query=query, memory=memory, tools=tools,
+ agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
+
+prompt = bpm.generate_prompt(
+ query=query, memory=memory, tools=tools,
+ agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
\ No newline at end of file
diff --git a/tests/retrieval/faiss_test.py b/tests/retrieval/faiss_test.py
new file mode 100644
index 0000000..101822f
--- /dev/null
+++ b/tests/retrieval/faiss_test.py
@@ -0,0 +1,69 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+ api_key = os.environ["OPENAI_API_KEY"]
+ api_base_url= os.environ["API_BASE_URL"]
+ model_name = os.environ["model_name"]
+ model_engine = os.environ["model_engine"]
+ embed_model = os.environ["embed_model"]
+ embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+ # set your config
+ api_key = ""
+ api_base_url= ""
+ model_name = ""
+ model_engine = os.environ["model_engine"]
+ embed_model = ""
+ embed_model_path = ""
+ logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+
+from muagent.db_handler import LocalFaissHandler
+from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.models import ModelConfig
+import numpy as np
+
+llm_config = LLMConfig(
+ model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3,
+)
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+model_type = "ollama_embedding"
+model_config = model_configs[model_type]
+
+embed_config = ModelConfig(
+ config_name="model_test",
+ model_type=model_type,
+ model_name=model_config["model_name"],
+ api_key=model_config["api_key"],
+)
+#
+import random
+embedding = [random.random() for _ in range(768)]
+print(len(embedding), np.mean(embedding))
+
+
+vb_config = VBConfig(vb_type="LocalFaissHandler")
+vb = LocalFaissHandler(embed_config, vb_config)
+
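+# Create a vector store named "shanshi" and query its FAISS index directly with the
+# random 768-dim vector above; printing the raw scores is only a smoke test of the search path.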
+vb.create_vs("shanshi")
+vector = np.array([embedding], dtype=np.float32)
+scores, indices = vb.search_index.index.search(vector, 20)
+print(scores)
diff --git a/tests/sandbox/nbclient_test.py b/tests/sandbox/nbclient_test.py
new file mode 100644
index 0000000..5367b25
--- /dev/null
+++ b/tests/sandbox/nbclient_test.py
@@ -0,0 +1,218 @@
+# # -*- coding: utf-8 -*-
+# # pylint: disable=C0301
+# """Service for executing jupyter notebooks interactively
+# Partially referenced the implementation of
+# https://github.com/geekan/MetaGPT/blob/main/metagpt/actions/di/execute_nb_code.py
+# """
+# import base64
+# import asyncio
+# from loguru import logger
+
+# try:
+# import nbclient
+# import nbformat
+# except ImportError:
+# nbclient = None
+# nbformat = None
+
+
+# class NoteBookExecutor:
+# """
+# Class for executing jupyter notebooks block interactively.
+# To use the service function, you should first init the class, then call the
+# run_code_on_notebook function.
+
+# Example:
+
+# ```ipython
+# from agentscope.service.service_toolkit import *
+# from agentscope.service.execute_code.exec_notebook import *
+# nbe = NoteBookExecutor()
+# code = "print('helloworld')"
+# # calling directly
+# nbe.run_code_on_notebook(code)
+
+# >>> Executing function run_code_on_notebook with arguments:
+# >>> code: print('helloworld')
+# >>> END
+
+# # calling with service toolkit
+# service_toolkit = ServiceToolkit()
+# service_toolkit.add(nbe.run_code_on_notebook)
+# input_obs = [{"name": "run_code_on_notebook", "arguments":{"code": code}}]
+# res_of_string_input = service_toolkit.parse_and_call_func(input_obs)
+
+# "1. Execute function run_code_on_notebook\n [ARGUMENTS]:\n code: print('helloworld')\n [STATUS]: SUCCESS\n [RESULT]: ['helloworld\\n']\n"
+
+# ```
+# """ # noqa
+
+# def __init__(
+# self,
+# timeout: int = 300,
+# ) -> None:
+# """
+# The construct function of the NoteBookExecutor.
+# Args:
+# timeout (Optional`int`):
+# The timeout for each cell execution.
+# Default to 300.
+# """
+
+# if nbclient is None or nbformat is None:
+# raise ImportError(
+# "The package nbclient or nbformat is not found. Please "
+# "install it by `pip install notebook nbclient nbformat`",
+# )
+
+# self.nb = nbformat.v4.new_notebook()
+# self.nb_client = nbclient.NotebookClient(nb=self.nb)
+# self.timeout = timeout
+
+# asyncio.run(self._start_client())
+
+# def _output_parser(self, output: dict) -> str:
+# """Parse the output of the notebook cell and return str"""
+# if output["output_type"] == "stream":
+# return output["text"]
+# elif output["output_type"] == "execute_result":
+# return output["data"]["text/plain"]
+# elif output["output_type"] == "display_data":
+# if "image/png" in output["data"]:
+# file_path = self._save_image(output["data"]["image/png"])
+# return f"Displayed image saved to {file_path}"
+# else:
+# return "Unsupported display type"
+# elif output["output_type"] == "error":
+# return output["traceback"]
+# else:
+# logger.info(f"Unsupported output encountered: {output}")
+# return "Unsupported output encountered"
+
+# async def _start_client(self) -> None:
+# """start notebook client"""
+# if self.nb_client.kc is None or not await self.nb_client.kc.is_alive():
+# self.nb_client.create_kernel_manager()
+# self.nb_client.start_new_kernel()
+# self.nb_client.start_new_kernel_client()
+
+# async def _kill_client(self) -> None:
+# """kill notebook client"""
+# if (
+# self.nb_client.km is not None
+# and await self.nb_client.km.is_alive()
+# ):
+# await self.nb_client.km.shutdown_kernel(now=True)
+# await self.nb_client.km.cleanup_resources()
+
+# self.nb_client.kc.stop_channels()
+# self.nb_client.kc = None
+# self.nb_client.km = None
+
+# async def _restart_client(self) -> None:
+# """Restart the notebook client"""
+# await self._kill_client()
+# self.nb_client = nbclient.NotebookClient(self.nb, timeout=self.timeout)
+# await self._start_client()
+
+# async def _run_cell(self, cell_index: int):
+# """Run a cell in the notebook by its index"""
+# try:
+# self.nb_client.execute_cell(self.nb.cells[cell_index], cell_index)
+# return [self._output_parser(output) for output in self.nb.cells[cell_index].outputs]
+# except nbclient.exceptions.DeadKernelError:
+# await self.reset_notebook()
+# return "DeadKernelError when executing cell, reset kernel"
+# except nbclient.exceptions.CellTimeoutError:
+# assert self.nb_client.km is not None
+# await self.nb_client.km.interrupt_kernel()
+# return (
+# "CellTimeoutError when executing cell"
+# ", code execution timeout"
+# )
+# except Exception as e:
+# return str(e)
+
+# @property
+# def cells_length(self) -> int:
+# """return cell length"""
+# return len(self.nb.cells)
+
+# async def async_run_code_on_notebook(self, code: str):
+# """
+# Run the code on interactive notebook
+# """
+# self.nb.cells.append(nbformat.v4.new_code_cell(code))
+# cell_index = self.cells_length - 1
+# return await self._run_cell(cell_index)
+
+# def run_code_on_notebook(self, code: str):
+# """
+# Run the code on interactive jupyter notebook.
+
+# Args:
+# code (`str`):
+# The Python code to be executed in the interactive notebook.
+
+# Returns:
+# `ServiceResponse`: whether the code execution was successful,
+# and the output of the code execution.
+# """
+# return asyncio.run(self.async_run_code_on_notebook(code))
+
+# def reset_notebook(self) -> str:
+# """
+# Reset the notebook
+# """
+# asyncio.run(self._restart_client())
+# return "Reset notebook"
+
+
+import os
+from loguru import logger
+
+try:
+ import os, sys
+ src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ sys.path.append(src_dir)
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.sandbox import NBClientBox, NoteBookExecutor
+
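+# The two cells below run on the same kernel, so `z` defined in the first cell
+# is still available when the second cell evaluates it.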
+nbe = NoteBookExecutor()
+code = f"""
+x = 1
+y = 1
+z = x+y
+print(z)
+"""
+print(nbe.run_code_on_notebook(code))
+
+
+code = f"""z
+"""
+print(nbe.run_code_on_notebook(code))
+
+
+codebox = NBClientBox()
+
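+# chat() with do_code_exe=True is expected to extract the fenced code block from the
+# message and execute it in the sandbox, returning the execution result.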
+result = codebox.chat("```import os\nos.getcwd()```", do_code_exe=True)
+print(result)
+
+result = codebox.chat("```print('hello world!')```", do_code_exe=True)
+print(result)
+
+with NBClientBox(do_code_exe=True) as codebox:
+ result = codebox.run("'hello world!'")
+ print(result)
\ No newline at end of file
diff --git a/tests/service/ekg_project_test.py b/tests/service/ekg_project_test.py
new file mode 100644
index 0000000..b829da4
--- /dev/null
+++ b/tests/service/ekg_project_test.py
@@ -0,0 +1,423 @@
+# -*- coding: utf-8 -*-
+from loguru import logger
+import os, sys
+import json
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.append(src_dir)
+try:
+ import test_config
+except Exception as e:
+ # set your config
+ logger.error(f"{e}")
+
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+import logging
+# Set the logging level to ERROR, which suppresses WARNING, INFO and DEBUG messages
+logging.basicConfig(level=logging.ERROR)
+
+
+from muagent import EKG, get_ekg_project_config_from_env
+# nodes = [{'id': 'haPvrjEkz4LARZyR7OAuPmVMHMIQPMew',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '需要公司多人参与的事务,以及相关的问题', 'name': '公司事务'}},
+# {'id': 'dicVRAk5rT3y9LxcmBCN2jDi1TjHc5rm',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '与个人有关的事务(如个人贷款),或遇到的个人问题,不涉及公司事务',
+# 'name': '个人事务'}},
+# {'id': 'ClKvwjBRZUJC7ttSZaiT0dh7lhSujNWi',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '公司活动', 'name': '公司活动'}},
+# {'id': 'NyBXAHQckQx1xL5lnSgBGlotbZkkQ9C7',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '金融(如借款、存款、贷款等)', 'name': '金融'}},
+# {'id': '6sa4zJCnVKJxKMtOtypapjZk4sdo93QU',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '医疗(包括预约、挂号、看病、诊断等)', 'name': '医疗'}},
+# {'id': 'a8d85669_141a_4f54_ab8c_209c08d27c35',
+# 'type': 'opsgptkg_schedule',
+# 'attributes': {'description': '组织一次公司活动',
+# 'name': '组织一次公司活动',
+# 'enable': 'False'}},
+# {'id': '2b8df337_f29e_4d49_865f_84088c3a94e7',
+# 'type': 'opsgptkg_schedule',
+# 'attributes': {'description': '在线申请贷款',
+# 'name': '在线申请贷款',
+# 'enable': 'False'}},
+# {'id': 'b9fe38f1_33f6_468b_a1dd_43efdfd8e2d1',
+# 'type': 'opsgptkg_schedule',
+# 'attributes': {'description': '预约医生', 'name': '预约医生', 'enable': 'False'}},
+# {'id': '98234102_4e4a_4997_9b1e_3cda6382b1c7',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '确定活动主题:确定活动的主要目的(如团建、庆祝活动等)',
+# 'name': '确定活动主题:确定活动的主要目的(如团建、庆祝活动等)'}},
+# {'id': '59030678_760d_4a10_8d61_0d4e4cc5fbcb',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '访问贷款平台:输入网址并访问贷款申请网站',
+# 'name': '访问贷款平台:输入网址并访问贷款申请网站'}},
+# {'id': '5afab73b_8f03_422f_856e_386f183bdd71',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择医院/医生:访问医院官网或APP,查找相关科室和医生',
+# 'name': '选择医院/医生:访问医院官网或APP,查找相关科室和医生'}},
+# {'id': '95ec00ef_cc9c_4947_a21c_88eeb9a71af5',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择活动类型', 'name': '选择活动类型'}},
+# {'id': '5504af87_416e_4ee5_bfce_86b969a63433',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '注册/登录:如果你已经注册,输入用户名和密码进行登录。如果你还没有注册,点击“注册”按钮,填写个人信息,创建账户',
+# 'name': '注册/登录:如果你已经注册,输入用户名和密码进行登录。如果你还没有注册,点击“注册”按钮,填写个人信息,创建账户'}},
+# {'id': '3ff8f54a_fa65_4368_86ce_d65058035dd0',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '查看可预约时间:点击医生姓名,查看可预约时段',
+# 'name': '查看可预约时间:点击医生姓名,查看可预约时段'}},
+# {'id': 'd5e760b4_ae82_410d_a73d_4c0c98926ae5',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '室内活动', 'name': '室内活动'}},
+# {'id': '2a37b90a_fd96_4548_989c_7c1e8fa9d881',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '户外活动', 'name': '户外活动'}},
+# {'id': '88d4cf2b_7cf5_4e40_b54e_59268f119f63',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择贷款类型:浏览可用的贷款类型(如个人贷款、汽车贷款、房屋贷款),选择适合自己的贷款类型',
+# 'name': '选择贷款类型:浏览可用的贷款类型(如个人贷款、汽车贷款、房屋贷款),选择适合自己的贷款类型'}},
+# {'id': '39021995_6e63_4907_9d67_26ba50d0cd44',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '填写个人信息:输入姓名、联系方式等,选择预约时间',
+# 'name': '填写个人信息:输入姓名、联系方式等,选择预约时间'}},
+# {'id': '59fe9c1d_0731_403e_936a_2e2bbba4b3ee',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择具体的室内活动(如会议、晚会、游戏),确定场地和时间,准备相关的设备(如投影仪、音响),安排餐饮和娱乐节目,发出邀请通知',
+# 'name': '选择具体的室内活动(如会议、晚会、游戏),确定场地和时间,准备相关的设备(如投影仪、音响),安排餐饮和娱乐节目,发出邀请通知'}},
+# {'id': '60163dc6_87af_4972_b350_6b9275975c83',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择具体的户外活动(如远足、烧烤、运动会),确定地点和时间,安排交通工具和安全措施,联系供应商(如餐饮、设备租赁),发出邀请通知',
+# 'name': '选择具体的户外活动(如远足、烧烤、运动会),确定地点和时间,安排交通工具和安全措施,联系供应商(如餐饮、设备租赁),发出邀请通知'}},
+# {'id': '910f3634_b999_4cf3_94c9_346a67b0d5ed',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '填写申请表:提供个人信息(如姓名、年龄、收入等),提供贷款金额和贷款目的',
+# 'name': '填写申请表:提供个人信息(如姓名、年龄、收入等),提供贷款金额和贷款目的'}},
+# {'id': '1330ad69_dfc3_4538_864e_6867a3fd8dd4',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '确认预约:检查预约信息,点击“确认预约”按钮',
+# 'name': '确认预约:检查预约信息,点击“确认预约”按钮'}},
+# {'id': 'fcbc3e04_ad8c_4aad_9f75_191f8037ced8',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '预算审核:计算活动预估费用,提交预算给管理层审核',
+# 'name': '预算审核:计算活动预估费用,提交预算给管理层审核'}},
+# {'id': '2c7a0d7b_a490_41b9_a6f8_e71b5212e0be',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '提交资料:上传所需文件(如身份证、收入证明等)',
+# 'name': '提交资料:上传所需文件(如身份证、收入证明等)'}},
+# {'id': '3cd46fb7_e11c_4181_8670_2f080a453142',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '接收通知:收到预约确认短信或邮件',
+# 'name': '接收通知:收到预约确认短信或邮件'}},
+# {'id': '0f4610cd_cf6a_475b_8ac0_80166569a292',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '审核资料:系统开始审核申请', 'name': '审核资料:系统开始审核申请'}},
+# {'id': 'b9f81925_b43a_459d_9902_1bc4b024f5a1',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '审核通过', 'name': '审核通过'}},
+# {'id': '191687cd_1b76_4e77_9f2a_e67936dd372e',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '审核失败', 'name': '审核失败'}},
+# {'id': '18c33ec1_08ef_4df8_b938_7244852d19c8',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '用户收到“申请通过”的通知,前往下一步选择贷款期限和还款方式',
+# 'name': '用户收到“申请通过”的通知,前往下一步选择贷款期限和还款方式'}},
+# {'id': 'b73c2551_0890_40fb_b0ca_04912bc21b65',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '提供反馈,建议修改后重新申请', 'name': '提供反馈,建议修改后重新申请'}},
+# {'id': 'e95adaa2_d177_435b_bac7_a8b6047ecc3d',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '确认贷款条件:查看贷款条款和条件',
+# 'name': '确认贷款条件:查看贷款条款和条件'}},
+# {'id': '0c561d68_ee31_49d2_82c1_1dac81e731ff',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '拒绝条款', 'name': '拒绝条款'}},
+# {'id': '81f579ac_851d_4b85_8608_d2732a2612ff',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '接受条款', 'name': '接受条款'}},
+# {'id': '1f0b64aa_5d45_4cf5_bcdd_084b8c125889',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '选择“拒绝”并退出申请流程', 'name': '选择“拒绝”并退出申请流程'}},
+# {'id': '5fd5901a_8adc_4b76_aea2_dcf18884ea0e',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '点击“接受”并继续', 'name': '点击“接受”并继续'}},
+# {'id': '8c999c60_baa7_4e74_903b_f10f148dd12f',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '签署合同:在线签署贷款合同', 'name': '签署合同:在线签署贷款合同'}},
+# {'id': 'e1004c60_5c0c_4f32_b765_a57cc4d39dcc',
+# 'type': 'opsgptkg_analysis',
+# 'attributes': {'summaryswitch': 'False',
+# 'description': '根据提示前往医院就诊',
+# 'name': '根据提示前往医院就诊'}},
+# {'id': 'c50ff5e3_aa01_4a6c_96d7_d8645303846d',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '活动宣传:制作宣传材料(如海报、邮件通知),在公司内部推广活动信息',
+# 'name': '活动宣传:制作宣传材料(如海报、邮件通知),在公司内部推广活动信息'}},
+# {'id': '4f540a57_f73d_451e_aafb_43f1335a18a7',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '活动实施:根据选择的活动类型,执行相关安排,进行现场协调(无论是户外还是室内)',
+# 'name': '活动实施:根据选择的活动类型,执行相关安排,进行现场协调(无论是户外还是室内)'}},
+# {'id': 'c9952fa7_7f82_4737_8cfd_bdbb2dabb20e',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '活动反馈:收集参与者的反馈意见,总结活动的成功之处和改进建议',
+# 'name': '活动反馈:收集参与者的反馈意见,总结活动的成功之处和改进建议'}},
+# {'id': 'ekg_team_default',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '团队起始节点', 'name': '开始'}}]
+
+# nodes = [{'id': '剧本杀/谁是卧底',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '谁是卧底', 'name': '谁是卧底', 'extra': ''}},
+# {'id': '剧本杀/狼人杀',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '狼人杀', 'name': '狼人杀', 'extra': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互',
+# 'type': 'opsgptkg_schedule',
+# 'attributes': {'extra': '',
+# 'description': '智能交互',
+# 'name': '智能交互',
+# 'enable': True}},
+# {'id': '剧本杀/狼人杀/智能交互',
+# 'type': 'opsgptkg_schedule',
+# 'attributes': {'extra': '',
+# 'description': '智能交互',
+# 'name': '智能交互',
+# 'enable': False}},
+# {'id': '剧本杀/谁是卧底/智能交互/分配座位',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '{"dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '分配座位',
+# 'name': '分配座位',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/位置选择',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '位置选择',
+# 'name': '位置选择',
+# 'accesscriteria': '',
+# 'extra': '{"memory_tag": "all"}',
+# 'executetype': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/角色分配和单词分配',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'accesscriteria': '',
+# 'extra': '{"memory_tag": "None","dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '角色分配和单词分配',
+# 'name': '角色分配和单词分配'}},
+# {'id': '剧本杀/狼人杀/智能交互/角色选择',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '角色选择',
+# 'name': '角色选择',
+# 'accesscriteria': '',
+# 'extra': '{"memory_tag": "None"}',
+# 'executetype': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/通知身份',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '{"pattern": "react","dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '##角色##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n目前已经完成 1)位置分配; 2)角色分配和单词分配。\n##任务##\n向所有玩家通知信息他们的 座位信息和单词信息。\n发送格式是: 【身份通知】你是{player_name}, 你的位置是{位置号}号, 你分配的单词是{单词}\n##详细步骤##\nstep1.依次向所有玩家通知信息他们的 座位信息和单词信息。发送格式是: 你是{player_name}, 你的位置是{位置号}号, 你分配的单词是{单词}\nstpe2.所有玩家信息都发送后,结束\n\n##注意##\n1. 每条信息只能发送给对应的玩家,其他人无法看到。\n2. 不要告诉玩家的角色信息,即不要高斯他是平民还是卧底角色\n3. 在将每个人的信息通知到后,本阶段任务结束\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n[{"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n#example#\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}',
+# 'name': '通知身份',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/向玩家通知消息',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '{"pattern": "react"}',
+# 'executetype': '',
+# 'description': '##角色##\n你正在参与狼人杀这个游戏,你的角色是[主持人]。你熟悉狼人杀游戏的完整流程,你需要完成[任务],保证狼人杀游戏的顺利进行。\n目前已经完成位置分配和角色分配。\n##任务##\n向所有玩家通知信息他们的座位信息和角色信息。\n发送格式是: 你是{player_name}, 你的位置是{位置号}号,你的身份是{角色名}\n##注意##\n1. 每条信息只能发送给对应的玩家,其他人无法看到。\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##',
+# 'name': '向玩家通知消息',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/关键信息_1',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'executetype': '',
+# 'description': '关键信息',
+# 'name': '关键信息',
+# 'accesscriteria': '',
+# 'extra': '{"ignorememory":"True","dodisplay":"True"}'}},
+# {'id': '剧本杀/狼人杀/智能交互/狼人时刻',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'accesscriteria': 'OR',
+# 'extra': '{"pattern": "react"}',
+# 'executetype': '',
+# 'description': '##背景##\n在狼人杀游戏中,主持人通知当前存活的狼人玩家指认一位击杀对象,所有狼人玩家给出击杀目标,主持人确定最终结果。\n\n##任务##\n整个流程分为6个步骤:\n1. 存活狼人通知:主持人向所有的狼人玩家广播,告知他们当前存活的狼人玩家有哪些。\n2. 第一轮讨论:主持人告知所有存活的狼人玩家投票,从当前存活的非狼人玩家中,挑选一个想要击杀的玩家。\n3. 第一轮投票:按照座位顺序,每一位存活的狼人为自己想要击杀的玩家投票。\n4. 第一轮结果反馈:主持人统计所有狼人的票数分布,确定他们是否达成一致。若达成一致,告知所有狼人最终被击杀的玩家的player_name,流程结束;否则,告知他们票数的分布情况,并让所有狼人重新投票指定击杀目标,主持人需要提醒他们,若该轮还不能达成一致,则取票数最大的目标为最终击杀对象。\n5. 第二轮投票:按照座位顺序,每一位存活的狼人为自己想要击杀的玩家投票。\n6. 第二轮结果反馈:主持人统计第二轮投票中所有狼人的票数分布,取票数最大的玩家为最终击杀对象,如果存在至少两个对象的票数最大且相同,取座位号最大的作为最终击杀对象。主持人告知所有狼人玩家最终被击杀的玩家的player_name。\n\n该任务的参与者只有狼人玩家和主持人,信息可见对象是所有狼人玩家。\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##',
+# 'name': '狼人时刻'}},
+# {'id': '剧本杀/谁是卧底/智能交互/开始新一轮的讨论',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'accesscriteria': 'OR',
+# 'extra': '{"pattern": "react", "endcheck": "True",\n"memory_tag":"all",\n"dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '###以上为本局游戏记录###\n\n\n##背景##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的讨论环节。 在这一个环节里,所有主持人先宣布目前存活的玩家,然后每位玩家按照座位顺序发言\n\n\n##详细步骤##\nstep1. 主持人根据本局游戏历史记录,感知最开始所有的玩家 以及 在前面轮数中已经被票选死亡的玩家。注意死亡的玩家不能参与本轮游戏。得到当前存活的玩家个数以及其player_name。 并告知所有玩家当前存活的玩家个数以及其player_name。\nstep2. 主持人确定发言规则并告知所有玩家,发言规则步骤如下: 存活的玩家按照座位顺序由小到大进行发言\n(一个例子:假设总共有5个玩家,如果3号位置处玩家死亡,则发言顺序为:1_>2_>4_>5)\nstep3. 存活的的玩家按照顺序依次发言\nstpe4. 在每一位存活的玩家都发言后,结束\n\n \n \n##注意##\n1.之前的游戏轮数可能已经投票选中了某位/某些玩家,被票选中的玩家会立即死亡,不再视为存活玩家,死亡的玩家不能参与本轮游戏 \n2.你要让所有存活玩家都参与发言,不能遗漏任何存活玩家。在本轮所有玩家只发言一次\n3.该任务的参与者为主持人和所有存活的玩家,信息可见对象为所有玩家。\n4.不仅要模拟主持人的发言,还需要模拟玩家的发言\n5.每一位存活的玩家均发完言后,本阶段结束\n\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n[ {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}] }, ...]\n\n\n\n\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断等。 \n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空 ;否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为本条信息的可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 如果是人类玩家发言, 一定要选择类似 agent_人类玩家 这样的agent_name',
+# 'name': '开始新一轮的讨论'}},
+# {'id': '剧本杀/狼人杀/智能交互/天亮讨论',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'executetype': '',
+# 'description': '##角色##\n你正在参与狼人杀这个游戏,你的角色是[主持人]。你熟悉狼人杀游戏的完整流程,你需要完成[任务],保证狼人杀游戏的顺利进行。\n##任务##\n你的任务如下: \n1. 告诉玩家昨晚发生的情况: 首先告诉玩家天亮了,然后你需要根据过往信息,告诉所有玩家,昨晚是否有玩家死亡。如果有,则向所有人宣布死亡玩家的名字,你只能宣布死亡玩家是谁如:"昨晚xx玩家死了",不要透露任何其他信息。如果没有,则宣布昨晚是平安夜。\n2. 确定发言规则并告诉所有玩家:\n确定发言规则步骤如下: \n第一步:确定第一个发言玩家,第一个发言的玩家为死者的座位号加1位置处的玩家(注意:最后一个位置+1的位置号为1号座位),如无人死亡,则从1号玩家开始。\n第二步:告诉所有玩家从第一个发言玩家开始发言,除了死亡玩家,每个人都需要按座位号依次讨论,只讨论一轮,所有人发言完毕后结束。注意不能遗忘指挥任何存活玩家发言!\n以下是一个例子:\n```\n总共有5个玩家,如果3号位置处玩家死亡,则第一个发言玩家为4号位置处玩家,因此从他开始发言,发言顺序为:4_>5_>1_>2\n```\n3. 依次指定存活玩家依次发言\n4. 被指定的玩家依次发言\n##注意##\n1. 你必须根据规则确定第一个发言玩家是谁,然后根据第一个发言玩家的座位号,确定所有人的发言顺序并将具体发言顺序并告知所有玩家,不要做任何多余解释\n2. 你要让所有存活玩家都参与发言,不能遗漏任何存活玩家\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果(请直接在后面输出,如果后面已经有部分结果,请续写。一定要保持续写后的内容结合前者能构成一个合法的 jsonstr)##',
+# 'name': '天亮讨论',
+# 'accesscriteria': '',
+# 'extra': '{"pattern": "react"}'}},
+# {'id': '剧本杀/谁是卧底/智能交互/关键信息_2',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '关键信息',
+# 'name': '关键信息',
+# 'accesscriteria': '',
+# 'extra': '{"ignorememory":"True","dodisplay":"True"}',
+# 'executetype': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/票选凶手',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'accesscriteria': '',
+# 'extra': '{"pattern": "react"}',
+# 'executetype': '',
+# 'description': '##角色##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n\n##任务##\n你的任务如下:\n1. 告诉玩家投票规则,规则步骤如下: \nstep1: 确定讨论阶段第一个发言的玩家A\nstep2: 从A玩家开始,按座位号依次投票,每个玩家只能对一个玩家进行投票,投票这个玩家表示认为该玩家是“卧底”。每个玩家只能投一次票。\nstep3: 将完整投票规则告诉所有玩家\n2. 指挥存活玩家依次投票。\n3. 被指定的玩家进行投票\n4. 主持人统计投票结果,并告知所有玩家,投出的角色是谁。\n\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n```\n{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}\n```\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##\n',
+# 'name': '票选凶手'}},
+# {'id': '剧本杀/谁是卧底/智能交互/票选卧底_1',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'executetype': '',
+# 'description': '##以上为本局游戏历史记录##\n##角色##\n你是一个统计票数大师,你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。 现在是投票阶段。\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的投票环节, 也仅仅只模拟投票环节,投票环节结束后就本阶段就停止了,由后续的阶段继续进行游戏。 在这一个环节里,由主持人先告知大家投票规则,然后组织每位存活玩家按照座位顺序发言投票, 所有人投票后,本阶段结束。 \n##详细步骤##\n你的任务如下:\nstep1. 向所有玩家通知现在进入了票选环节,在这个环节,每个人都一定要投票指定某一个玩家为卧底\nstep2. 主持人确定投票顺序并告知所有玩家。 投票顺序基于如下规则: 1: 存活的玩家按照座位顺序由小到大进行投票(一个例子:假设总共有5个玩家,如果3号位置处玩家死亡,则投票顺序为:1_>2_>4_>5)2: 按座位号依次投票,每个玩家只能对一个玩家进行投票。每个玩家只能投一次票。3:票数最多的玩家会立即死亡\n\nstep3. 存活的的玩家按照顺序进行投票\nstep4. 所有存活玩家发言完毕,主持人宣布投票环节结束\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n##注意##\n\n1.之前的游戏轮数可能已经投票选中了某位/某些玩家,被票选中的玩家会立即死亡,不再视为存活玩家 \n2.你要让所有存活玩家都参与投票,不能遗漏任何存活玩家。在本轮每一位玩家只投票一个人\n3.该任务的参与者为主持人和所有存活的玩家,信息可见对象为所有玩家。\n4.不仅要模拟主持人的发言,还需要模拟玩家的发言\n5.不允许玩家自己投自己,如果出现了这种情况,主持人会提醒玩家重新投票。\n\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n["thought": str, {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断。 \n_ player_name (str): ***的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): ***的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 如果是人类玩家发言, 一定要选择类似 人类agent 这样的agent_name',
+# 'name': '票选卧底',
+# 'accesscriteria': '',
+# 'extra': '{"pattern": "react", "endcheck": "True", "memory_tag":"all","dodisplay":"True"}'}},
+# {'id': '剧本杀/谁是卧底/智能交互/关键信息_4',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '{"ignorememory":"True","dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '关键信息_4',
+# 'name': '关键信息_4',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/统计票数',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'executetype': '',
+# 'description': '##以上为本局游戏历史记录##\n##角色##\n你是一个统计票数大师,你非常擅长计数以及统计信息。你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。 现在是票数统计阶段\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的票数统计阶段, 也仅仅只票数统计阶段环节,票数统计阶段结束后就本阶段就停止了,由后续的阶段继续进行游戏。 在这一个环节里,由主持人根据上一轮存活的玩家投票结果统计票数。 \n##详细步骤##\n你的任务如下:\nstep1. 主持人感知上一轮投票环节每位玩家的发言, 统计投票结果,格式为[{"player_name":票数}]. \nstep2 然后,主持人宣布死亡的玩家,以最大票数为本轮被投票的目标,如果票数相同,则取座位号高的角色死亡。并告知所有玩家本轮被投票玩家的player_name。(格式为【重要通知】本轮死亡的玩家为XXX)同时向所有玩家宣布,被投票中的角色会视为立即死亡(即不再视为存活角色)\nstep3. 在宣布死亡玩家后,本阶段流程结束,由后续阶段继续推进游戏\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n##注意##\n1.如果有2个或者两个以上的被玩家被投的票数相同,则取座位号高的玩家死亡。并告知大家原因:票数相同,取座位号高的玩家死亡\n2.在统计票数时,首先确认存活玩家的数量,再先仔细回忆,谁被投了。 最后统计每位玩家被投的次数。 由于每位玩家只有一票,所以被投次数的总和等于存活玩家的数量 \n3.通知完死亡玩家是谁后,本阶段才结束,由后续阶段继续推进游戏。输出 {"action": "taskend"}即可\n4.主持人只有当通知本轮死亡的玩家时,才使用【重要通知】的前缀,其他情况下不要使用【重要通知】前缀\n5.只统计上一轮投票环节的情况\n##example##\n{"thought": "在上一轮中, 存活玩家有 小北,李光,赵鹤,张良 四个人。 其中 小北投了李光, 赵鹤投了小北, 张良投了李光, 李光投了张良。总结被投票数为: 李光:2票; 小北:1票,张良:1票. Check一下,一共有四个人投票了,被投的票是2(李光)+1(小北)+1(张良)=4,总结被投票数没有问题。 因此李光的票最多", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["all"], "content": "李光:2票; 小北:1票,张良:1票 .因此李光的票最多.【重要通知】本轮死亡玩家是李光",}]}\n\n##example##\n{"thought": "在上一轮中, 存活玩家有 小北,人类玩家,赵鹤,张良 四个人。 其中 小北投了人类玩家, 赵鹤投了小北, 张良投了小北, 人类玩家投了张良。总结被投票数为:小北:2票,人类玩家:1票,张良:0票 .Check一下,一共有四个人投票了,被投的票是2(小北)+1(人类玩家)+张良(0)=3,总结被投票数有问题。 更正总结被投票数为:小北:2票,人类玩家:1票,张良:1票。因此小北的票最多", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["all"], "content": "小北:2票,人类玩家:1票,张良:1票 .因此小北的票最多.【重要通知】本轮死亡玩家是小北",}]}\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n["thought": str, {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断。 \n_ player_name (str): ***的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): ***的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 如果是人类玩家发言, 一定要选择类似 人类agent 这样的agent_name',
+# 'name': '统计票数',
+# 'accesscriteria': '',
+# 'extra': '{"pattern": "react", "endcheck": "True", "memory_tag":"all","model_name":"gpt_4","dodisplay":"True"}'}},
+# {'id': '剧本杀/谁是卧底/智能交互/关键信息_3',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'accesscriteria': '',
+# 'extra': '{"ignorememory":"True","dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '关键信息',
+# 'name': '关键信息'}},
+# {'id': '剧本杀/谁是卧底/智能交互/判断游戏是否结束',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '判断游戏是否结束',
+# 'name': '判断游戏是否结束',
+# 'accesscriteria': '',
+# 'extra': '{"memory_tag": "None","dodisplay":"True"}',
+# 'executetype': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/事实_1',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '是', 'name': '是', 'extra': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/事实_2',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '否', 'name': '否', 'extra': ''}},
+# {'id': '剧本杀/谁是卧底/智能交互/给出每个人的单词以及最终胜利者',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '{"dodisplay":"True"}',
+# 'executetype': '',
+# 'description': '给出每个人的单词以及最终胜利者',
+# 'name': '给出每个人的单词以及最终胜利者',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/判断游戏是否结束',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'description': '判断游戏是否结束 ',
+# 'name': '判断游戏是否结束 ',
+# 'accesscriteria': '',
+# 'extra': '{"memory_tag": "None"}',
+# 'executetype': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/事实_2',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'extra': '', 'description': '否', 'name': '否'}},
+# {'id': '剧本杀/狼人杀/智能交互/事实_1',
+# 'type': 'opsgptkg_phenomenon',
+# 'attributes': {'description': '是', 'name': '是', 'extra': ''}},
+# {'id': '剧本杀/狼人杀/智能交互/宣布游戏胜利者',
+# 'type': 'opsgptkg_task',
+# 'attributes': {'extra': '',
+# 'executetype': '',
+# 'description': '判断游戏是否结束',
+# 'name': '判断游戏是否结束',
+# 'accesscriteria': ''}},
+# {'id': '剧本杀',
+# 'type': 'opsgptkg_intent',
+# 'attributes': {'description': '文本游戏相关(如狼人杀等)', 'name': '剧本杀', 'extra': ''}}]
+
+# edges = [('剧本杀', '剧本杀/谁是卧底'),
+# ('剧本杀', '剧本杀/狼人杀'),
+# ('剧本杀/谁是卧底', '剧本杀/谁是卧底/智能交互'),
+# ('剧本杀/狼人杀', '剧本杀/狼人杀/智能交互'),
+# ('剧本杀/谁是卧底/智能交互', '剧本杀/谁是卧底/智能交互/分配座位'),
+# ('剧本杀/狼人杀/智能交互', '剧本杀/狼人杀/智能交互/位置选择'),
+# ('剧本杀/谁是卧底/智能交互/分配座位', '剧本杀/谁是卧底/智能交互/角色分配和单词分配'),
+# ('剧本杀/狼人杀/智能交互/位置选择', '剧本杀/狼人杀/智能交互/角色选择'),
+# ('剧本杀/谁是卧底/智能交互/角色分配和单词分配', '剧本杀/谁是卧底/智能交互/通知身份'),
+# ('剧本杀/狼人杀/智能交互/角色选择', '剧本杀/狼人杀/智能交互/向玩家通知消息'),
+# ('剧本杀/谁是卧底/智能交互/通知身份', '剧本杀/谁是卧底/智能交互/关键信息_1'),
+# ('剧本杀/狼人杀/智能交互/向玩家通知消息', '剧本杀/狼人杀/智能交互/狼人时刻'),
+# ('剧本杀/谁是卧底/智能交互/关键信息_1', '剧本杀/谁是卧底/智能交互/开始新一轮的讨论'),
+# ('剧本杀/狼人杀/智能交互/狼人时刻', '剧本杀/狼人杀/智能交互/天亮讨论'),
+# ('剧本杀/谁是卧底/智能交互/开始新一轮的讨论', '剧本杀/谁是卧底/智能交互/关键信息_2'),
+# ('剧本杀/狼人杀/智能交互/天亮讨论', '剧本杀/狼人杀/智能交互/票选凶手'),
+# ('剧本杀/谁是卧底/智能交互/关键信息_2', '剧本杀/谁是卧底/智能交互/票选卧底_1'),
+# ('剧本杀/谁是卧底/智能交互/票选卧底_1', '剧本杀/谁是卧底/智能交互/关键信息_4'),
+# ('剧本杀/谁是卧底/智能交互/关键信息_4', '剧本杀/谁是卧底/智能交互/统计票数'),
+# ('剧本杀/谁是卧底/智能交互/统计票数', '剧本杀/谁是卧底/智能交互/关键信息_3'),
+# ('剧本杀/谁是卧底/智能交互/关键信息_3', '剧本杀/谁是卧底/智能交互/判断游戏是否结束'),
+# ('剧本杀/谁是卧底/智能交互/判断游戏是否结束', '剧本杀/谁是卧底/智能交互/事实_1'),
+# ('剧本杀/谁是卧底/智能交互/判断游戏是否结束', '剧本杀/谁是卧底/智能交互/事实_2'),
+# ('剧本杀/谁是卧底/智能交互/事实_1', '剧本杀/谁是卧底/智能交互/给出每个人的单词以及最终胜利者'),
+# ('剧本杀/谁是卧底/智能交互/事实_2', '剧本杀/谁是卧底/智能交互/开始新一轮的讨论'),
+# ('剧本杀/狼人杀/智能交互/票选凶手', '剧本杀/狼人杀/智能交互/判断游戏是否结束'),
+# ('剧本杀/狼人杀/智能交互/判断游戏是否结束', '剧本杀/狼人杀/智能交互/事实_2'),
+# ('剧本杀/狼人杀/智能交互/判断游戏是否结束', '剧本杀/狼人杀/智能交互/事实_1'),
+# ('剧本杀/狼人杀/智能交互/事实_2', '剧本杀/狼人杀/智能交互/狼人时刻'),
+# ('剧本杀/狼人杀/智能交互/事实_1', '剧本杀/狼人杀/智能交互/宣布游戏胜利者'),
+# ('ekg_team_default', '剧本杀')
+# ]
+
+
+
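+# Tool names assumed to be registered in muagent's tool registry for the
+# "谁是卧底" (Who is the Undercover) demo: seat assignment, role assignment,
+# result output, win-condition check, plus one tool per scripted AI player.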
+tools = [
+ "谁是卧底-座位分配", "谁是卧底-角色分配", "谁是卧底-结果输出", "谁是卧底-胜利条件判断",
+ "谁是卧底-张伟", "谁是卧底-李静", "谁是卧底-王鹏",
+]
+
+# tools = [
+# "狼人杀-角色分配工具", "狼人杀-座位分配", "狼人杀-胜利条件判断", "狼人杀-结果输出",
+# '狼人杀-agent_朱丽', '狼人杀-agent_周杰', '狼人杀-agent_沈强', '狼人杀-agent_韩刚',
+# '狼人杀-agent_梁军', '狼人杀-agent_周欣怡', '狼人杀-agent_贺子轩'
+# ]
+
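+# Register a single function-calling agent that may invoke the tools above.
+# llm_config_name refers to the "qwen_chat" entry expected in MODEL_CONFIGS
+# (see tests/test_config.py.example); the dict is handed to the service via
+# the AGENT_CONFIGS environment variable below.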
+AGENT_CONFIGS = {
+ "codefuse_function_caller": {
+ "config_name": "codefuse_function_caller",
+ "agent_type": "FunctioncallAgent",
+ "agent_name": "codefuse_function_caller",
+ "llm_config_name": "qwen_chat",
+ "tools": tools,
+ }
+}
+os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
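+# Assemble the project config from the *_CONFIGS environment variables set
+# above (and in test_config) and create the EKG client. initialize_space=False
+# presumably skips (re)creating the graph space; the node/edge import below is
+# left commented out because the graph is assumed to be in place already.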
+project_config = get_ekg_project_config_from_env()
+ekg = EKG(project_config=project_config, initialize_space=False)
+
+
+# # Add nodes
+# for node in nodes:
+# ekg.add_node(node)
+
+# # Add edges
+# for start_id, end_id in edges:
+# ekg.add_edge(start_id, end_id)
+
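+# Start the flow from the root intent node; ekg.run appears to yield a stream
+# of intermediate results, so the loop below simply drains it.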
+response = ekg.run("我要玩谁是卧底!", rootid="ekg_team_default")
+for i in response:
+ pass
diff --git a/tests/test_config.py.example b/tests/test_config.py.example
index c1347ee..ac03016 100644
--- a/tests/test_config.py.example
+++ b/tests/test_config.py.example
@@ -1,6 +1,8 @@
import os, openai, base64
from loguru import logger
+os.environ["DM_llm_name"] = 'Qwen2_72B_Instruct_OpsGPT' #or gpt_4
+
# 兜底大模型配置
OPENAI_API_BASE = "https://api.openai.com/v1"
os.environ["API_BASE_URL"] = OPENAI_API_BASE
@@ -19,6 +21,78 @@ os.environ["gpt4-llm_temperature"] = "0.0"
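+# Model registry consumed by the muagent services. The "old" entries keep the
+# legacy model_engine/temperature fields, while the "*_chat" / "*_embedding"
+# entries use the newer model_type naming. Fill in api_key (and, where needed,
+# api_base_url) for the providers you actually use.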
+MODEL_CONFIGS = {
+ # old llm config
+ "default": {
+ "model_name": "gpt-3.5-turbo",
+ "model_engine": "qwen",
+ "temperature": "0",
+ "api_key": "",
+ "api_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ },
+ "codefuser":{
+ "model_name": "gpt-4",
+ "model_engine": "openai",
+ "temperature": "0",
+ "api_key": "",
+ "api_base_url": OPENAI_API_BASE,
+ },
+ # new llm config
+ "dashscope_chat": {
+ "model_type": "dashscope_chat",
+ "model_name": "qwen2.5-72b-instruct" ,
+ "api_key": "",
+ },
+ "moonshot_chat": {
+ "model_type": "moonshot_chat",
+ "model_name": "moonshot-v1-8k" ,
+ "api_key": "",
+ },
+ "ollama_chat": {
+ "model_type": "ollama_chat",
+ "model_name": "qwen2.5-0.5b",
+ "api_key": "",
+ },
+ "openai_chat": {
+ "model_type": "openai_chat",
+ "model_name": "gpt-4",
+ "api_key": "",
+ },
+ "qwen_chat": {
+ "model_type": "qwen_chat",
+ "model_name": "qwen2.5-72b-instruct",
+ "api_key": "",
+ },
+ "yi_chat": {
+ "model_type": "yi_chat",
+ "model_name": "yi-lightning" ,
+ "api_key": "",
+ },
+ # embedding configs
+ "dashscope_text_embedding": {
+ "model_type": "dashscope_text_embedding",
+ "model_name": "text-embedding-v3",
+ "api_key": "",
+ },
+ "ollama_embedding": {
+ "model_type": "ollama_embedding",
+ "model_name": "qwen2.5-0.5b",
+ "api_key": "",
+ },
+ "openai_embedding": {
+ "model_type": "openai_embedding",
+ "model_name": "text-embedding-ada-002",
+ "api_key": "",
+ },
+ "qwen_text_embedding": {
+ "model_type": "dashscope_text_embedding",
+ "model_name": "text-embedding-v3",
+ "api_key": "",
+ },
+}
+
+os.environ["MODEL_CONFIGS"] = json.dumps(MODEL_CONFIGS)
+
#### NebulaHandler ####
os.environ['nb_host'] = 'graphd'
os.environ['nb_port'] = '9669'
@@ -42,6 +116,36 @@ os.environ['tb_definition_value'] = 'message_test_new'
os.environ['tb_expire_time'] = '604800' #86400*7
+#################
+## DB_CONFIGS ##
+#################
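+# Storage backends for the EKG service: gb_config selects the graph store
+# (NebulaHandler -> NebulaGraph) and tb_config the Redis-based TBase store.
+# The hosts match the docker-compose service names ("graphd", "redis-stack");
+# adjust them if the stores run elsewhere.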
+DB_CONFIGS = {
+ "gb_config": {
+ "gb_type": "NebulaHandler",
+ "extra_kwargs": {
+ 'host':'graphd',
+ 'port': '9669',
+ 'username': os.environ['nb_username'],
+ 'password': os.environ['nb_password'],
+ 'space': "client"
+ }
+ },
+ "tb_config": {
+ "tb_type": 'TBaseHandler',
+ "index_name": "opsgptkg",
+ "host": 'redis-stack',
+ "port": '6379',
+ "username": os.environ['tb_username'],
+ "password": os.environ['tb_password'],
+ "extra_kwargs": {
+ "definition_value": "opsgptkg",
+ "memory_definition_value": "opsgptkg_message"
+ }
+ }
+}
+os.environ["DB_CONFIGS"] = json.dumps(DB_CONFIGS)
+
+
########################################
########## 以下参数暂不涉及无需配置 ########
diff --git a/tests/tools/get_tool.py b/tests/tools/get_tool.py
new file mode 100644
index 0000000..b9155e9
--- /dev/null
+++ b/tests/tools/get_tool.py
@@ -0,0 +1,30 @@
+import os
+import sys
+from loguru import logger
+
+try:
+    src_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    sys.path.append(src_dir)
+    import test_config
+except Exception as e:
+    # test_config is unavailable; fall back to your own config
+    logger.error(f"{e}")
+
+
+src_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from muagent import get_tool
+from muagent.tools import toLangchainTools
+
+
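+# Look up the registered "Multiplier" tool and wrap it as a LangChain tool;
+# the two schema prints below show its declared input/output models.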
+tools = toLangchainTools([get_tool("Multiplier")])
+
+print(get_tool("Multiplier").intput_to_json_schema())
+print(get_tool("Multiplier").output_to_json_schema())
+# tool run test
+print(tools[0].func(1, 2))
\ No newline at end of file