add max_input_tokens for history management & add tests #23

Merged: 1 commit, May 21, 2025

52 changes: 52 additions & 0 deletions .github/workflows/python-test.yml
@@ -0,0 +1,52 @@
name: Python Test

on:
  pull_request:
    branches:
      - main
    paths:
      - "llm_coder/**"
      - .github/workflows/python-test.yml

concurrency:
  # ref for branch
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1
        id: setup-python
        with:
          python-version-file: "./pyproject.toml"
      - name: Install uv
        run: pip install uv
      - name: Install dependencies with uv
        run: |
          uv sync
      - name: Run lint with uv
        run: |
          uv run ruff check .
          uv run ruff format --check .

  test:
    name: Test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1
        id: setup-python
        with:
          python-version-file: "./pyproject.toml"
      - name: Install uv
        run: pip install uv
      - name: Install dependencies with uv
        run: |
          uv sync
      - name: Run tests
        run: |
          uv run pytest tests
3 changes: 3 additions & 0 deletions llm-coder-config.example.toml
@@ -15,6 +15,9 @@ max_iterations = 10
# Timeout in seconds for a single LLM API request
request_timeout = 60

# Maximum number of input tokens sent to the LLM (optional; if omitted, the model's own maximum input token count is tried as the default)
# max_input_tokens = 2048

# List of directories in which filesystem operations are allowed.
# By default, the current working directory from which the CLI was run is allowed.
# Settings specified here take precedence over the CLI default.
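As context, a minimal sketch of how an optional key like max_input_tokens can be read from a TOML config; the file name and fallback behaviour shown here are assumptions for illustration, not the CLI's exact code:

```python
import toml

# Illustrative only: "llm-coder-config.toml" is an assumed file name for this sketch.
config_values = toml.load("llm-coder-config.toml")

# Returns None when max_input_tokens is commented out or absent, which signals
# "fall back to the model's own maximum" further downstream.
max_input_tokens = config_values.get("max_input_tokens")
print(max_input_tokens)
```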
94 changes: 89 additions & 5 deletions llm_coder/agent.py
@@ -86,6 +86,7 @@ def __init__(
        final_summary_prompt: str = FINAL_SUMMARY_PROMPT,  # Prompt used for the final summary
        repository_description_prompt: str = None,  # Repository description prompt
        request_timeout: int = 180,  # Timeout in seconds for a single request (adjustable from the CLI, default 180)
        max_input_tokens: int = None,  # Maximum number of input tokens for the LLM
    ):
        self.model = model
        self.temperature = temperature
@@ -94,6 +95,7 @@ def __init__(
        self.final_summary_prompt = final_summary_prompt
        # If repository_description_prompt is None or an empty string, keep it as-is
        self.repository_description_prompt = repository_description_prompt
        self.max_input_tokens = max_input_tokens  # Maximum number of input tokens

        # Set up the available tools
        self.available_tools = available_tools or []
@@ -121,6 +123,86 @@ def __init__(
            else 0,
        )

    async def _get_messages_for_llm(self) -> List[Dict[str, Any]]:
        """
        Build the message list to pass to the LLM, taking the token limit into account.

        Returns:
            The list of message dicts to send to the LLM.
        """
        if not self.conversation_history:
            return []

        messages_to_send = []
        current_tokens = 0

        # 1. The first system message and the first user prompt are always required.
        # First system message
        if self.conversation_history[0].role == "system":
            system_message = self.conversation_history[0].to_dict()
            messages_to_send.append(system_message)
            if self.max_input_tokens is not None:
                current_tokens += litellm.token_counter(
                    model=self.model, messages=[system_message]
                )

        # First user message (assumed to come right after the system message)
        if (
            len(self.conversation_history) > 1
            and self.conversation_history[1].role == "user"
        ):
            user_message = self.conversation_history[1].to_dict()
            # Make sure the message has not already been added
            if not messages_to_send or messages_to_send[-1] != user_message:
                # Token check
                if self.max_input_tokens is not None:
                    user_message_tokens = litellm.token_counter(
                        model=self.model, messages=[user_message]
                    )
                    if current_tokens + user_message_tokens <= self.max_input_tokens:
                        messages_to_send.append(user_message)
                        current_tokens += user_message_tokens
                    else:
                        raise ValueError(
                            f"The first user message exceeds the token limit. Required tokens: {user_message_tokens}, current tokens: {current_tokens}, max tokens: {self.max_input_tokens}"
                        )
                else:
                    messages_to_send.append(user_message)

        # 2. Add the most recent history, as much as fits within the token limit
        # History after the required messages (assumes there are two required messages)
        remaining_history = self.conversation_history[2:]

        temp_recent_messages: list[Dict[str, Any]] = []
        for msg in reversed(remaining_history):
            msg_dict = msg.to_dict()
            if self.max_input_tokens is not None:
                msg_tokens = litellm.token_counter(
                    model=self.model, messages=[msg_dict]
                )
                if current_tokens + msg_tokens <= self.max_input_tokens:
                    temp_recent_messages.insert(0, msg_dict)  # iterating in reverse, so prepend
                    current_tokens += msg_tokens
                else:
                    # Stop once the token limit is reached
                    logger.debug(
                        "Token limit reached; older messages will not be included.",
                        message_content=msg_dict.get("content", "")[:50],
                        required_tokens=msg_tokens,
                        current_tokens=current_tokens,
                        max_tokens=self.max_input_tokens,
                    )
                    break
            else:
                temp_recent_messages.insert(0, msg_dict)

        messages_to_send.extend(temp_recent_messages)

        logger.debug(
            f"Messages sent to the LLM: {len(messages_to_send)}, tokens: {current_tokens if self.max_input_tokens is not None else 'N/A'}"
        )
        return messages_to_send

    async def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Execute the specified tool and return its result"""
        logger.debug("Executing tool", tool_name=tool_name, arguments=arguments)
@@ -209,7 +291,9 @@ async def _planning_phase(self, prompt: str) -> None:
        try:
            response = await litellm.acompletion(
                model=self.model,
                messages=[msg.to_dict() for msg in self.conversation_history],
                messages=[
                    msg.to_dict() for msg in self.conversation_history
                ],  # the planning phase typically uses the full history
                temperature=self.temperature,
                tools=self.tools,  # use the updated tool list
                timeout=self.request_timeout,  # timeout for a single request
@@ -287,7 +371,7 @@ async def _execution_phase(self) -> bool:

            response = await litellm.acompletion(
                model=self.model,
                messages=[msg.to_dict() for msg in self.conversation_history],
                messages=await self._get_messages_for_llm(),  # argument removed
                temperature=self.temperature,
                tools=self.tools,  # use the updated tool list
                timeout=self.request_timeout,  # timeout for a single request
@@ -390,7 +474,7 @@ async def _execution_phase(self) -> bool:
            logger.debug("Getting next actions from LLM after tool executions")
            response = await litellm.acompletion(
                model=self.model,
                messages=[msg.to_dict() for msg in self.conversation_history],
                messages=await self._get_messages_for_llm(),  # argument removed
                temperature=self.temperature,
                tools=self.tools,  # use the updated tool list
                timeout=self.request_timeout,  # timeout for a single request
@@ -474,9 +558,9 @@ async def run(self, prompt: str) -> str:

        final_response = await litellm.acompletion(
            model=self.model,
            messages=[msg.to_dict() for msg in self.conversation_history],
            messages=await self._get_messages_for_llm(),
            temperature=self.temperature,
            tools=self.tools,  # added the tools parameter
            tools=self.tools,  # not used, but the tool list is provided to satisfy Anthropic's requirements
            timeout=self.request_timeout,  # timeout for a single request
        )
        logger.debug(
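For reviewers, a standalone sketch of the selection strategy _get_messages_for_llm implements: always keep the system prompt and the first user prompt, then walk the remaining history from newest to oldest and stop once the budget is exhausted. The model name, budget, and messages below are illustrative assumptions; this is not code from the diff:

```python
import litellm

MODEL = "gpt-4o-mini"    # illustrative model name
MAX_INPUT_TOKENS = 300   # illustrative budget

history = [
    {"role": "system", "content": "You are a coding agent."},
    {"role": "user", "content": "Fix the failing test in the playground server."},
] + [
    {"role": "assistant", "content": f"step {i}: " + "details " * 40}
    for i in range(10)
]

# The system prompt and the first user prompt are always kept.
selected = history[:2]
used = litellm.token_counter(model=MODEL, messages=selected)

# Walk the rest of the history from newest to oldest and keep what still fits.
recent: list[dict] = []
for msg in reversed(history[2:]):
    cost = litellm.token_counter(model=MODEL, messages=[msg])
    if used + cost > MAX_INPUT_TOKENS:
        break  # everything older than this message is dropped
    recent.insert(0, msg)
    used += cost

selected.extend(recent)
print(f"kept {len(selected)} of {len(history)} messages, ~{used} tokens")
```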
19 changes: 19 additions & 0 deletions llm_coder/cli.py
@@ -4,6 +4,7 @@
import os  # import the os module
import sys  # import the sys module
import toml  # import toml
from litellm import get_model_info  # import get_model_info

# Import the agent and filesystem modules
from llm_coder.agent import Agent
@@ -148,6 +149,15 @@ def parse_args():
        help=f"Timeout in seconds for a single LLM API request (default: {request_timeout_default})",
    )

    # Option for the maximum number of input tokens
    max_input_tokens_default = config_values.get("max_input_tokens", None)
    parser.add_argument(
        "--max-input-tokens",
        type=int,
        default=max_input_tokens_default,
        help="Maximum number of input tokens for the LLM (default: the model's own maximum)",
    )

    # Parse all arguments other than --config using remaining_argv
    return parser.parse_args(remaining_argv)

@@ -199,13 +209,22 @@ async def run_agent_from_cli(args):
    logger.debug("Total available tools", tool_count=len(all_available_tools))

    logger.debug("Initializing agent from CLI")

    # Determine the maximum number of input tokens
    max_input_tokens = args.max_input_tokens
    if max_input_tokens is None:
        model_info = get_model_info(args.model)
        if model_info and "max_input_tokens" in model_info:
            max_input_tokens = model_info["max_input_tokens"]

    agent_instance = Agent(  # renamed Agent class instance
        model=args.model,
        temperature=args.temperature,
        max_iterations=args.max_iterations,
        available_tools=all_available_tools,  # use the updated tool list
        repository_description_prompt=args.repository_description_prompt,  # pass the repository description prompt
        request_timeout=args.request_timeout,  # pass the LLM API request timeout
        max_input_tokens=max_input_tokens,  # pass the maximum number of input tokens
    )

    logger.info("Starting agent run from CLI", prompt_length=len(prompt))
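A small sketch of the fallback used above: when --max-input-tokens is not given, litellm's get_model_info supplies the model's own limit. The model name is an example, and the lookup is guarded here because unknown models may raise; this guard is an assumption, not part of the diff:

```python
from litellm import get_model_info


def default_max_input_tokens(model: str) -> int | None:
    """Return the model's max input tokens if litellm knows the model, else None."""
    try:
        model_info = get_model_info(model)  # may raise for models litellm has no mapping for
    except Exception:
        return None
    if model_info and "max_input_tokens" in model_info:
        return model_info["max_input_tokens"]
    return None


print(default_max_input_tokens("gpt-4o-mini"))  # example model name
```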
2 changes: 1 addition & 1 deletion playground/server/main.py
@@ -14,7 +14,7 @@ class Item(BaseModel):

@app.get("/")
def read_root():
return "Hello USA"
return "Hello USA!"


@app.get("/items/{item_id}")
2 changes: 1 addition & 1 deletion playground/server/test_main.py
@@ -10,7 +10,7 @@ def test_read_root():
"""
response = client.get("/")
assert response.status_code == 200
assert response.json() == "Hello USA"
assert response.json() == "Hello USA!"


def test_read_item():
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -27,7 +27,11 @@ dependencies = [
"Bug Tracker" = "https://github.com/igtm/llm-coder/issues"

[dependency-groups]
dev = ["pytest>=8.3.5", "pytest-asyncio>=0.26.0", "ruff>=0.11.9"]
dev = [
    "pytest>=8.3.5",
    "pytest-asyncio>=0.26.0",
    "ruff>=0.11.9",
]

[tool.setuptools.packages.find]
include = ["llm_coder*"]
1 change: 1 addition & 0 deletions pytest.ini
@@ -1,2 +1,3 @@
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
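With asyncio_mode = auto, pytest-asyncio collects async def test functions without an explicit @pytest.mark.asyncio marker. A minimal illustrative test; the file name and assertion are examples, not part of this diff:

```python
# tests/test_async_example.py (illustrative only)
import asyncio


async def test_event_loop_is_available():
    # Collected and run by pytest-asyncio because asyncio_mode = auto;
    # no @pytest.mark.asyncio marker is required.
    assert await asyncio.sleep(0, result="ok") == "ok"
```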