Skip to content

Commit cd48ad4

Browse files
authored
Merge pull request #23 from igtm/feature/igtm-dgif
add max_input_tokens for history management & add tests
2 parents 7d817c6 + b7283d5 commit cd48ad4

File tree

9 files changed

+299
-15
lines changed

9 files changed

+299
-15
lines changed

.github/workflows/python-test.yml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
name: Python Test
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
paths:
8+
- "llm_coder/**"
9+
- .github/workflows/python-test.yml
10+
11+
concurrency:
12+
# ref for branch
13+
group: ${{ github.workflow }}-${{ github.ref }}
14+
cancel-in-progress: true
15+
16+
jobs:
17+
lint:
18+
name: Lint
19+
runs-on: ubuntu-latest
20+
steps:
21+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
22+
- uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1
23+
id: setup-python
24+
with:
25+
python-version-file: "./pyproject.toml"
26+
- name: Install uv
27+
run: pip install uv
28+
- name: Install dependencies with uv
29+
run: |
30+
uv sync
31+
- name: Run lint with uv
32+
run: |
33+
uv run ruff check .
34+
uv run ruff format --check .
35+
36+
test:
37+
name: Test
38+
runs-on: ubuntu-latest
39+
steps:
40+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
41+
- uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1
42+
id: setup-python
43+
with:
44+
python-version-file: "./pyproject.toml"
45+
- name: Install uv
46+
run: pip install uv
47+
- name: Install dependencies with uv
48+
run: |
49+
uv sync
50+
- name: Run tests
51+
run: |
52+
uv run pytest tests

llm-coder-config.example.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ max_iterations = 10
1515
# LLM APIリクエスト1回あたりのタイムアウト秒数
1616
request_timeout = 60
1717

18+
# LLMへの入力の最大トークン数 (省略可能。指定しない場合、モデルの最大入力トークン数がデフォルトとして試行されます)
19+
# max_input_tokens = 2048
20+
1821
# ファイルシステム操作を許可するディレクトリのリスト
1922
# デフォルトでは、CLIを実行したカレントワーキングディレクトリが許可されます。
2023
# ここで指定すると、その設定がCLIのデフォルトよりも優先されます。

llm_coder/agent.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
final_summary_prompt: str = FINAL_SUMMARY_PROMPT, # 最終要約用プロンプト
8787
repository_description_prompt: str = None, # リポジトリ説明プロンプト
8888
request_timeout: int = 180, # 1回のリクエストに対するタイムアウト秒数(CLIから調整可能、デフォルト180)
89+
max_input_tokens: int = None, # LLMの最大入力トークン数
8990
):
9091
self.model = model
9192
self.temperature = temperature
@@ -94,6 +95,7 @@ def __init__(
9495
self.final_summary_prompt = final_summary_prompt
9596
# repository_description_prompt が None または空文字列の場合はそのまま None または空文字列を保持
9697
self.repository_description_prompt = repository_description_prompt
98+
self.max_input_tokens = max_input_tokens # 最大生成トークン数を設定
9799

98100
# 利用可能なツールを設定
99101
self.available_tools = available_tools or []
@@ -121,6 +123,86 @@ def __init__(
121123
else 0,
122124
)
123125

126+
async def _get_messages_for_llm(self) -> List[Dict[str, Any]]:
127+
"""
128+
LLMに渡すメッセージリストを作成する。トークン数制限を考慮する。
129+
130+
Returns:
131+
LLMに渡すメッセージの辞書リスト。
132+
"""
133+
if not self.conversation_history:
134+
return []
135+
136+
messages_to_send = []
137+
current_tokens = 0
138+
139+
# 1. 最初のシステムメッセージと最初のユーザープロンプトは必須
140+
# 最初のシステムメッセージ
141+
if self.conversation_history[0].role == "system":
142+
system_message = self.conversation_history[0].to_dict()
143+
messages_to_send.append(system_message)
144+
if self.max_input_tokens is not None:
145+
current_tokens += litellm.token_counter(
146+
model=self.model, messages=[system_message]
147+
)
148+
149+
# 最初のユーザーメッセージ (システムメッセージの次にあると仮定)
150+
if (
151+
len(self.conversation_history) > 1
152+
and self.conversation_history[1].role == "user"
153+
):
154+
user_message = self.conversation_history[1].to_dict()
155+
# 既にシステムメッセージが追加されているか確認
156+
if not messages_to_send or messages_to_send[-1] != user_message:
157+
# トークンチェック
158+
if self.max_input_tokens is not None:
159+
user_message_tokens = litellm.token_counter(
160+
model=self.model, messages=[user_message]
161+
)
162+
if current_tokens + user_message_tokens <= self.max_input_tokens:
163+
messages_to_send.append(user_message)
164+
current_tokens += user_message_tokens
165+
else:
166+
raise ValueError(
167+
f"最初のユーザーメッセージがトークン制限を超えています。必要なトークン数: {user_message_tokens}, 現在のトークン数: {current_tokens}, 最大トークン数: {self.max_input_tokens}"
168+
)
169+
else:
170+
messages_to_send.append(user_message)
171+
172+
# 2. 最新の会話履歴からトークン制限を超えない範囲で追加
173+
# 必須メッセージ以降の履歴を取得 (必須メッセージが2つと仮定)
174+
remaining_history = self.conversation_history[2:]
175+
176+
temp_recent_messages: list[Dict[str, Any]] = []
177+
for msg in reversed(remaining_history):
178+
msg_dict = msg.to_dict()
179+
if self.max_input_tokens is not None:
180+
msg_tokens = litellm.token_counter(
181+
model=self.model, messages=[msg_dict]
182+
)
183+
if current_tokens + msg_tokens <= self.max_input_tokens:
184+
temp_recent_messages.insert(0, msg_dict) # 逆順なので先頭に追加
185+
current_tokens += msg_tokens
186+
else:
187+
# トークン制限に達したらループを抜ける
188+
logger.debug(
189+
"トークン制限に達したため、これ以上過去のメッセージは含めません。",
190+
message_content=msg_dict.get("content", "")[:50],
191+
required_tokens=msg_tokens,
192+
current_tokens=current_tokens,
193+
max_tokens=self.max_input_tokens,
194+
)
195+
break
196+
else:
197+
temp_recent_messages.insert(0, msg_dict)
198+
199+
messages_to_send.extend(temp_recent_messages)
200+
201+
logger.debug(
202+
f"LLMに渡すメッセージ数: {len(messages_to_send)}, トークン数: {current_tokens if self.max_input_tokens is not None else 'N/A'}"
203+
)
204+
return messages_to_send
205+
124206
async def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
125207
"""指定されたツールを実行してその結果を返す"""
126208
logger.debug("Executing tool", tool_name=tool_name, arguments=arguments)
@@ -209,7 +291,9 @@ async def _planning_phase(self, prompt: str) -> None:
209291
try:
210292
response = await litellm.acompletion(
211293
model=self.model,
212-
messages=[msg.to_dict() for msg in self.conversation_history],
294+
messages=[
295+
msg.to_dict() for msg in self.conversation_history
296+
], # プランニングフェーズでは全履歴を使用することが多い
213297
temperature=self.temperature,
214298
tools=self.tools, # 更新されたツールリストを使用
215299
timeout=self.request_timeout, # 1回のリクエスト用タイムアウト
@@ -287,7 +371,7 @@ async def _execution_phase(self) -> bool:
287371

288372
response = await litellm.acompletion(
289373
model=self.model,
290-
messages=[msg.to_dict() for msg in self.conversation_history],
374+
messages=await self._get_messages_for_llm(), # 引数を削除
291375
temperature=self.temperature,
292376
tools=self.tools, # 更新されたツールリストを使用
293377
timeout=self.request_timeout, # 1回のリクエスト用タイムアウト
@@ -390,7 +474,7 @@ async def _execution_phase(self) -> bool:
390474
logger.debug("Getting next actions from LLM after tool executions")
391475
response = await litellm.acompletion(
392476
model=self.model,
393-
messages=[msg.to_dict() for msg in self.conversation_history],
477+
messages=await self._get_messages_for_llm(), # 引数を削除
394478
temperature=self.temperature,
395479
tools=self.tools, # 更新されたツールリストを使用
396480
timeout=self.request_timeout, # 1回のリクエスト用タイムアウト
@@ -474,9 +558,9 @@ async def run(self, prompt: str) -> str:
474558

475559
final_response = await litellm.acompletion(
476560
model=self.model,
477-
messages=[msg.to_dict() for msg in self.conversation_history],
561+
messages=await self._get_messages_for_llm(),
478562
temperature=self.temperature,
479-
tools=self.tools, # ツールパラメータを追加
563+
tools=self.tools, # 使わないけど、ツールリストを提供して、Anthropicの要件を満たす
480564
timeout=self.request_timeout, # 1回のリクエスト用タイムアウト
481565
)
482566
logger.debug(

llm_coder/cli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os # os モジュールをインポート
55
import sys # sys モジュールをインポート
66
import toml # toml をインポート
7+
from litellm import get_model_info # get_model_info をインポート
78

89
# agent と filesystem モジュールをインポート
910
from llm_coder.agent import Agent
@@ -148,6 +149,15 @@ def parse_args():
148149
help=f"LLM APIリクエスト1回あたりのタイムアウト秒数 (デフォルト: {request_timeout_default})",
149150
)
150151

152+
# 最大入力トークン数のオプションを追加
153+
max_input_tokens_default = config_values.get("max_input_tokens", None)
154+
parser.add_argument(
155+
"--max-input-tokens",
156+
type=int,
157+
default=max_input_tokens_default,
158+
help="LLMの最大入力トークン数 (デフォルト: モデル固有の最大値)",
159+
)
160+
151161
# remaining_argv を使って、--config 以外の引数を解析
152162
return parser.parse_args(remaining_argv)
153163

@@ -199,13 +209,22 @@ async def run_agent_from_cli(args):
199209
logger.debug("Total available tools", tool_count=len(all_available_tools))
200210

201211
logger.debug("Initializing agent from CLI")
212+
213+
# 最大入力トークン数を決定
214+
max_input_tokens = args.max_input_tokens
215+
if max_input_tokens is None:
216+
model_info = get_model_info(args.model)
217+
if model_info and "max_input_tokens" in model_info:
218+
max_input_tokens = model_info["max_input_tokens"]
219+
202220
agent_instance = Agent( # Agent クラスのインスタンス名変更
203221
model=args.model,
204222
temperature=args.temperature,
205223
max_iterations=args.max_iterations,
206224
available_tools=all_available_tools, # 更新されたツールリストを使用
207225
repository_description_prompt=args.repository_description_prompt, # リポジトリ説明プロンプトを渡す
208226
request_timeout=args.request_timeout, # LLM APIリクエストのタイムアウトを渡す
227+
max_input_tokens=max_input_tokens, # 最大入力トークン数を渡す
209228
)
210229

211230
logger.info("Starting agent run from CLI", prompt_length=len(prompt))

playground/server/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Item(BaseModel):
1414

1515
@app.get("/")
1616
def read_root():
17-
return "Hello USA"
17+
return "Hello USA!"
1818

1919

2020
@app.get("/items/{item_id}")

playground/server/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_read_root():
1010
"""
1111
response = client.get("/")
1212
assert response.status_code == 200
13-
assert response.json() == "Hello USA"
13+
assert response.json() == "Hello USA!"
1414

1515

1616
def test_read_item():

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ dependencies = [
2727
"Bug Tracker" = "https://github.com/igtm/llm-coder/issues"
2828

2929
[dependency-groups]
30-
dev = ["pytest>=8.3.5", "pytest-asyncio>=0.26.0", "ruff>=0.11.9"]
30+
dev = [
31+
"pytest>=8.3.5",
32+
"pytest-asyncio>=0.26.0",
33+
"ruff>=0.11.9",
34+
]
3135

3236
[tool.setuptools.packages.find]
3337
include = ["llm_coder*"]

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[pytest]
2+
asyncio_mode = auto
23
asyncio_default_fixture_loop_scope = function

0 commit comments

Comments
 (0)