Skip to content

[bugfix]v0.0.4 chat and codechat bug #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ package.sh
setup_test.py
build
*egg-info
dist
dist
.ipynb_checkpoints
4 changes: 2 additions & 2 deletions docs/overview/o1.muagent.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ aliases:
本项目的Mutli-Agent框架汲取兼容了多个框架的优秀设计,比如metaGPT中的消息池(message pool)、autogen中的代理选择器(agent selector)等。

<div align=center>
<img src="/docs/resources/muagent_framework.png" alt="图片" style="width: 500px; height:auto;">
<img src="/docs/resources/muAgent_framework.png" alt="图片" style="width: 500px; height:auto;">
</div>


Expand All @@ -42,7 +42,7 @@ aliases:
1. BaseAgent:提供基础问答、工具使用、代码执行的功能,根据Prompt格式实现 输入 => 输出

<div align=center>
<img src="/docs/resources/baseagent.png" alt="图片" style="width: 500px; height:auto;">
<img src="/docs/resources/BaseAgent.png" alt="图片" style="width: 500px; height:auto;">
</div>

2. ReactAgent:提供标准React的功能,根据问题实现当前任务
Expand Down
1 change: 0 additions & 1 deletion examples/muagent_examples/baseGroup_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
source_file = 'D://project/gitlab/llm/external/ant_code/Codefuse-chatbot/jupyter_work/employee_data.csv'
shutil.copy(source_file, JUPYTER_WORK_PATH)


# round-1
query_content = "确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型;然后画柱状图"
query = Message(
Expand Down
13 changes: 9 additions & 4 deletions muagent/chat/agent_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def chat(
model_name: str = Body("", description="llm模型名称"),
temperature: float = Body(0.2, description=""),
chat_index: str = "",
local_graph_path: str = "",
**kargs
) -> Message:

Expand Down Expand Up @@ -122,7 +123,8 @@ def chat(
code_engine_name=code_engine_name,
score_threshold=score_threshold, top_k=top_k,
history_node_list=history_node_list,
tools=tools
tools=tools,
local_graph_path=local_graph_path
)
# history memory mangemant
history = Memory(messages=[
Expand Down Expand Up @@ -223,6 +225,7 @@ def achat(
model_name: str = Body("", description="llm模型名称"),
temperature: float = Body(0.2, description=""),
chat_index: str = "",
local_graph_path: str = "",
**kargs
) -> Message:

Expand Down Expand Up @@ -264,7 +267,8 @@ def achat(
cb_search_type=cb_search_type,
score_threshold=score_threshold, top_k=top_k,
history_node_list=history_node_list,
tools=tools
tools=tools,
local_graph_path=local_graph_path
)
# history memory mangemant
history = Memory(messages=[
Expand Down Expand Up @@ -292,7 +296,8 @@ def achat(

def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["human"])
step_content = "\n\n".join([f"{v}" for parsed_output in local_memory.get_parserd_output_list()[1:] for k, v in parsed_output.items() if k not in ["Action Status"]])
step_content = "\n\n".join([f"{v}" for parsed_output in local_memory.get_parserd_output_list() for k, v in parsed_output.items() if k not in ["Action Status", "human", "user"]])
# logger.debug(f"{local_memory.get_parserd_output_list()}")
final_content = message.role_content
result = {
"answer": "",
Expand All @@ -311,7 +316,7 @@ def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
if node not in has_nodes:
related_nodes.append(node)
result["related_nodes"] = related_nodes

# logger.debug(f"{result['figures'].keys()}, isDetailed: {isDetailed}")
message_str = step_content
if self.stream:
Expand Down
16 changes: 9 additions & 7 deletions muagent/chat/code_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def check_service_status(self) -> BaseResponse:
return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")

def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path=""):
'''process'''

codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
Expand All @@ -67,7 +67,8 @@ def _process(self, query: str, history: List[History], model, llm_config: LLMCon
embed_model_path=embed_config.embed_model_path,
embed_engine=embed_config.embed_engine,
model_device=embed_config.model_device,
embed_config=embed_config
embed_config=embed_config,
local_graph_path=local_graph_path
)

context = codes_res['context']
Expand Down Expand Up @@ -115,6 +116,7 @@ def chat(
model_name: str = Body("", ),
temperature: float = Body(0.5, ),
model_device: str = Body("", ),
local_graph_path: str=Body(", "),
**kargs
):
params = locals()
Expand All @@ -127,9 +129,9 @@ def chat(
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
self.request = request
self.cb_search_type = cb_search_type
return self._chat(query, history, llm_config, embed_config, **kargs)
return self._chat(query, history, llm_config, embed_config, local_graph_path, **kargs)

def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path: str, **kargs):
history = [History(**h) if isinstance(h, dict) else h for h in history]

service_status = self.check_service_status()
Expand All @@ -140,7 +142,7 @@ def chat_iterator(query: str, history: List[History]):
# model = getChatModel()
model = getChatModelFromConfig(llm_config)

result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
result, content = self.create_task(query, history, model, llm_config, embed_config, local_graph_path, **kargs)
# logger.info('result={}'.format(result))
# logger.info('content={}'.format(content))

Expand All @@ -156,9 +158,9 @@ def chat_iterator(query: str, history: List[History]):
return StreamingResponse(chat_iterator(query, history),
media_type="text/event-stream")

def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path: str):
'''构建 llm 生成任务'''
chain, context, result = self._process(query, history, model, llm_config, embed_config)
chain, context, result = self._process(query, history, model, llm_config, embed_config, local_graph_path)
logger.info('chain={}'.format(chain))
try:
content = chain({"context": context, "question": query})
Expand Down
6 changes: 4 additions & 2 deletions muagent/codechat/code_analyzer/code_intepreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def get_intepretation_batch(self, code_list):
messages.append(message)

try:
chat_ress = [chat_model(messages) for message in messages]
except:
chat_ress = [chat_model.predict(message) for message in messages]
except Exception as e:
logger.exception(f"{e}")
chat_ress = chat_model.batch(messages)

for chat_res, code in zip(chat_ress, code_list):
try:
res[code] = chat_res.content
Expand Down
34 changes: 21 additions & 13 deletions muagent/connector/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,19 +590,23 @@ def append_tools(self, tool_information: dict, chat_index: str, nodeid: str, use
}

for k, v in tool_map.items():
message = Message(
chat_index=chat_index,
message_index= f"{nodeid}-{uuid.uuid4()}",
user_name=user_name,
role_name = v["role_name"], # agent 名字,
role_type = v["role_type"], # agent 类型,默认assistant,可选observation
## llm output
role_content = tool_information[k], # 输入
customed_kargs = {
**{kk: vv for kk, vv in tool_information.items()
if kk in v.get("customed_keys", [])}
} # 存储docs、tool等信息
)
try:
message = Message(
chat_index=chat_index,
#message_index= f"{nodeid}-{uuid.uuid4()}",
message_index= f"{nodeid}-{k}",
user_name=user_name,
role_name = v["role_name"], # agent 名字,
role_type = v["role_type"], # agent 类型,默认assistant,可选observation
## llm output
role_content = tool_information[k], # 输入
customed_kargs = {
**{kk: vv for kk, vv in tool_information.items()
if kk in v.get("customed_keys", [])}
} # 存储docs、tool等信息
)
except:
pass
self.append(message)

def get_memory_pool(self, chat_index: str = "") -> Memory:
Expand Down Expand Up @@ -802,12 +806,16 @@ def tbasedoc2Memory(self, r_docs) -> Memory:
for doc in r_docs.docs:
tbase_message = {}
for k, v in doc.__dict__.items():
if k in ["role_content", "input_query"]:
tbase_message[k] = v
continue
try:
v = json.loads(v)
except:
pass

tbase_message[k] = v

message = Message(**tbase_message)
memory.append(message)

Expand Down
6 changes: 4 additions & 2 deletions muagent/service/cb_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def create_cb(zip_file,
temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
local_graph_path: str = '',
) -> BaseResponse:
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))

Expand All @@ -74,7 +75,7 @@ async def create_cb(zip_file,

try:
logger.info('start build code base')
cbh = CodeBaseHandler(cb_name, code_path, embed_config=embed_config, llm_config=llm_config)
cbh = CodeBaseHandler(cb_name, code_path, embed_config=embed_config, llm_config=llm_config, local_graph_path=local_graph_path)
vertices_num, edge_num, file_num = cbh.import_code(zip_file=zip_file, do_interpret=do_interpret)
logger.info('build code base done')

Expand All @@ -100,6 +101,7 @@ async def delete_cb(
temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
local_graph_path: str="",
) -> BaseResponse:
logger.info('cb_name={}'.format(cb_name))
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
Expand All @@ -119,7 +121,7 @@ async def delete_cb(
shutil.rmtree(CB_ROOT_PATH + os.sep + cb_name)

# delete from codebase
cbh = CodeBaseHandler(cb_name, embed_config=embed_config, llm_config=llm_config)
cbh = CodeBaseHandler(cb_name, embed_config=embed_config, llm_config=llm_config, local_graph_path=local_graph_path)
cbh.delete_codebase(codebase_name=cb_name)

except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="codefuse-muagent",
version="0.0.2",
version="0.0.4",
author="shanshi",
author_email="wyp311395@antgroup.com",
description="A multi-agent framework that facilitates the rapid construction of collaborative teams of agents.",
Expand Down