
Commit 9cdaede

Merge pull request #41 from 0xturboblitz/master

feat: model selection logic

2 parents: 71ded69 + 22d5d2d

File tree: 7 files changed, +134 additions, -55 deletions

src/cli/commands/estimate/index.ts (6 additions, 0 deletions)

@@ -14,6 +14,9 @@ export const estimate = async ({
   root,
   output,
   llms,
+  priority,
+  maxConcurrentCalls,
+  addQuestions,
   ignore,
   filePrompt,
   folderPrompt,
@@ -37,6 +40,9 @@ export const estimate = async ({
       root,
       output: json,
       llms,
+      priority,
+      maxConcurrentCalls,
+      addQuestions,
       ignore,
       filePrompt,
       folderPrompt,

src/cli/commands/index/index.ts (12 additions, 0 deletions)

@@ -11,6 +11,9 @@ export const index = async ({
   root,
   output,
   llms,
+  priority,
+  maxConcurrentCalls,
+  addQuestions,
   ignore,
   filePrompt,
   folderPrompt,
@@ -35,6 +38,9 @@ export const index = async ({
       root,
       output: json,
       llms,
+      priority,
+      maxConcurrentCalls,
+      addQuestions,
       ignore,
       filePrompt,
       folderPrompt,
@@ -56,6 +62,9 @@ export const index = async ({
       root: json,
       output: markdown,
       llms,
+      priority,
+      maxConcurrentCalls,
+      addQuestions,
       ignore,
       filePrompt,
       folderPrompt,
@@ -73,6 +82,9 @@ export const index = async ({
       root: markdown,
       output: data,
       llms,
+      priority,
+      maxConcurrentCalls,
+      addQuestions,
       ignore,
       filePrompt,
       folderPrompt,

src/cli/commands/index/processRepository.ts (43 additions, 50 deletions)

@@ -30,6 +30,7 @@ import {
   githubFolderUrl,
 } from '../../utils/FileUtil.js';
 import { models } from '../../utils/LLMUtil.js';
+import { selectModel } from './selectModel.js';

 export const processRepository = async (
   {
@@ -38,6 +39,9 @@ export const processRepository = async (
     root: inputRoot,
     output: outputRoot,
     llms,
+    priority,
+    maxConcurrentCalls,
+    addQuestions,
     ignore,
     filePrompt,
     folderPrompt,
@@ -47,8 +51,7 @@ export const processRepository = async (
   }: AutodocRepoConfig,
   dryRun?: boolean,
 ) => {
-  const encoding = encoding_for_model('gpt-3.5-turbo');
-  const rateLimit = new APIRateLimit(25);
+  const rateLimit = new APIRateLimit(maxConcurrentCalls);

   const callLLM = async (
     prompt: string,
@@ -91,6 +94,7 @@ export const processRepository = async (

     const markdownFilePath = path.join(outputRoot, filePath);
     const url = githubFileUrl(repositoryUrl, inputRoot, filePath, linkHosted);
+
     const summaryPrompt = createCodeFileSummary(
       projectName,
       projectName,
@@ -105,51 +109,28 @@ export const processRepository = async (
       contentType,
       targetAudience,
     );
-    const summaryLength = encoding.encode(summaryPrompt).length;
-    const questionLength = encoding.encode(questionsPrompt).length;
-    const max = Math.max(questionLength, summaryLength);

-    /**
-     * TODO: Encapsulate logic for selecting the best model
-     * TODO: Allow for different selection strategies based
-     * TODO: preference for cost/performace
-     * TODO: When this is re-written, it should use the correct
-     * TODO: TikToken encoding for each model
-     */
+    const prompts = addQuestions
+      ? [summaryPrompt, questionsPrompt]
+      : [summaryPrompt];

-    const model: LLMModelDetails | null = (() => {
-      if (
-        models[LLMModels.GPT3].maxLength > max &&
-        llms.includes(LLMModels.GPT3)
-      ) {
-        return models[LLMModels.GPT3];
-      } else if (
-        models[LLMModels.GPT4].maxLength > max &&
-        llms.includes(LLMModels.GPT4)
-      ) {
-        return models[LLMModels.GPT4];
-      } else if (
-        models[LLMModels.GPT432k].maxLength > max &&
-        llms.includes(LLMModels.GPT432k)
-      ) {
-        return models[LLMModels.GPT432k];
-      } else {
-        return null;
-      }
-    })();
+    const model = selectModel(prompts, llms, models, priority);

     if (!isModel(model)) {
       // console.log(`Skipped ${filePath} | Length ${max}`);
       return;
     }

+    const encoding = encoding_for_model(model.name);
+    const summaryLength = encoding.encode(summaryPrompt).length;
+    const questionLength = encoding.encode(questionsPrompt).length;
+
     try {
       if (!dryRun) {
         /** Call LLM */
-        const [summary, questions] = await Promise.all([
-          callLLM(summaryPrompt, model.llm),
-          callLLM(questionsPrompt, model.llm),
-        ]);
+        const response = await Promise.all(
+          prompts.map(async (prompt) => callLLM(prompt, model.llm)),
+        );

         /**
          * Create file and save to disk
@@ -158,8 +139,8 @@ export const processRepository = async (
           fileName,
           filePath,
           url,
-          summary,
-          questions,
+          summary: response[0],
+          questions: addQuestions ? response[1] : '',
           checksum: newChecksum,
         };

@@ -186,7 +167,8 @@ export const processRepository = async (
       /**
        * Track usage for end of run summary
        */
-      model.inputTokens += summaryLength + questionLength;
+      model.inputTokens += summaryLength;
+      if (addQuestions) model.inputTokens += questionLength;
       model.total++;
       model.outputTokens += 1000;
       model.succeeded++;
@@ -236,7 +218,12 @@ export const processRepository = async (
     }

     // eslint-disable-next-line prettier/prettier
-    const url = githubFolderUrl(repositoryUrl, inputRoot, folderPath, linkHosted);
+    const url = githubFolderUrl(
+      repositoryUrl,
+      inputRoot,
+      folderPath,
+      linkHosted,
+    );
     const allFiles: (FileSummary | null)[] = await Promise.all(
       contents.map(async (fileName) => {
         const entryPath = path.join(folderPath, fileName);
@@ -279,18 +266,24 @@ export const processRepository = async (
       (folder): folder is FolderSummary => folder !== null,
     );

-    const summary = await callLLM(
-      folderSummaryPrompt(
-        folderPath,
-        projectName,
-        files,
-        folders,
-        contentType,
-        folderPrompt,
-      ),
-      models[LLMModels.GPT4].llm,
+    const summaryPrompt = folderSummaryPrompt(
+      folderPath,
+      projectName,
+      files,
+      folders,
+      contentType,
+      folderPrompt,
     );

+    const model = selectModel([summaryPrompt], llms, models, priority);
+
+    if (!isModel(model)) {
+      // console.log(`Skipped ${filePath} | Length ${max}`);
+      return;
+    }
+
+    const summary = await callLLM(summaryPrompt, model.llm);
+
     const folderSummary: FolderSummary = {
       folderName,
       folderPath,
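For readers following the per-file logic above, here is a self-contained sketch of how the new prompts array maps back onto summary and questions. The callLLM below is a stand-in and the prompt strings are placeholders; only the mapping mirrors the diff.

```ts
// Standalone sketch (not project code): shows why response[0] is always the summary
// and why questions falls back to '' when addQuestions is disabled.
const addQuestions = true;
const summaryPrompt = '<summary prompt>';
const questionsPrompt = '<questions prompt>';

// Stand-in for callLLM(prompt, model.llm) from processRepository.
const callLLM = async (prompt: string): Promise<string> => `response to: ${prompt}`;

const prompts = addQuestions ? [summaryPrompt, questionsPrompt] : [summaryPrompt];
const response = await Promise.all(prompts.map(async (prompt) => callLLM(prompt)));

const file = {
  summary: response[0],
  // With addQuestions disabled, prompts has a single entry, so questions stays empty.
  questions: addQuestions ? response[1] : '',
};
console.log(file);
```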

src/cli/commands/index/selectModel.ts (new file, 57 additions, 0 deletions)

@@ -0,0 +1,57 @@
+import { encoding_for_model } from '@dqbd/tiktoken';
+import { LLMModelDetails, LLMModels, Priority } from '../../../types.js';
+
+export const selectModel = (
+  prompts: string[],
+  llms: LLMModels[],
+  models: Record<LLMModels, LLMModelDetails>,
+  priority: Priority,
+): LLMModelDetails | null => {
+  if (priority === Priority.COST) {
+    if (
+      llms.includes(LLMModels.GPT3) &&
+      models[LLMModels.GPT3].maxLength >
+        getMaxPromptLength(prompts, LLMModels.GPT3)
+    ) {
+      return models[LLMModels.GPT3];
+    } else if (
+      llms.includes(LLMModels.GPT4) &&
+      models[LLMModels.GPT4].maxLength >
+        getMaxPromptLength(prompts, LLMModels.GPT4)
+    ) {
+      return models[LLMModels.GPT4];
+    } else if (
+      llms.includes(LLMModels.GPT432k) &&
+      models[LLMModels.GPT432k].maxLength >
+        getMaxPromptLength(prompts, LLMModels.GPT432k)
+    ) {
+      return models[LLMModels.GPT432k];
+    } else {
+      return null;
+    }
+  } else {
+    if (llms.includes(LLMModels.GPT4)) {
+      if (
+        models[LLMModels.GPT4].maxLength >
+          getMaxPromptLength(prompts, LLMModels.GPT4)
+      ) {
+        return models[LLMModels.GPT4];
+      } else if (
+        llms.includes(LLMModels.GPT432k) &&
+        models[LLMModels.GPT432k].maxLength >
+          getMaxPromptLength(prompts, LLMModels.GPT432k)
+      ) {
+        return models[LLMModels.GPT432k];
+      } else {
+        return null;
+      }
+    } else {
+      return models[LLMModels.GPT3];
+    }
+  }
+
+  function getMaxPromptLength(prompts: string[], model: LLMModels) {
+    const encoding = encoding_for_model(model);
+    return Math.max(...prompts.map((p) => encoding.encode(p).length));
+  }
+};
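A minimal usage sketch of the new helper follows, assuming the import paths shown in the diff above; the prompt strings and the llms list are illustrative, not taken from the repository.

```ts
import { selectModel } from './selectModel.js';
import { models } from '../../utils/LLMUtil.js';
import { LLMModels, Priority } from '../../../types.js';

// Illustrative prompts; in processRepository these are the generated summary/questions prompts.
const prompts = [
  'Summarize the file src/index.ts for the documentation.',
  'List three questions a developer might ask about src/index.ts.',
];

// Priority.COST walks GPT-3.5 -> GPT-4 -> GPT-4-32k and returns the first configured model
// whose maxLength exceeds the longest prompt, measured with that model's own encoding.
// Priority.PERFORMANCE prefers GPT-4 / GPT-4-32k when they are configured.
const model = selectModel(
  prompts,
  [LLMModels.GPT3, LLMModels.GPT4],
  models,
  Priority.COST,
);

if (model === null) {
  // No configured model can fit the longest prompt; processRepository skips the file.
} else {
  console.log(`Selected ${model.name}`);
}
```

One quirk visible in the new file: under Priority.PERFORMANCE, if neither GPT-4 variant is configured the helper falls back to GPT-3.5 without re-checking llms or the prompt length.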

src/cli/commands/init/index.ts (4 additions, 1 deletion)

@@ -2,7 +2,7 @@ import chalk from 'chalk';
 import inquirer from 'inquirer';
 import fs from 'node:fs';
 import path from 'node:path';
-import { AutodocRepoConfig, LLMModels } from '../../../types.js';
+import { AutodocRepoConfig, LLMModels, Priority } from '../../../types.js';

 export const makeConfigTemplate = (
   config?: AutodocRepoConfig,
@@ -16,6 +16,9 @@ export const makeConfigTemplate = (
       config?.llms?.length ?? 0 > 0
         ? (config as AutodocRepoConfig).llms
         : [LLMModels.GPT3],
+    priority: Priority.COST,
+    maxConcurrentCalls: 25,
+    addQuestions: true,
     ignore: [
       '.*',
       '*package-lock.json',
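For orientation, a sketch of the three new template defaults in isolation; the real makeConfigTemplate output also contains name, repositoryUrl, the ignore patterns, the prompts, and the other existing fields.

```ts
import { AutodocRepoConfig, LLMModels, Priority } from '../../../types.js';

// Only the fields this commit adds; Partial<> is used because the full config
// template has many more required fields.
const newDefaults: Partial<AutodocRepoConfig> = {
  llms: [LLMModels.GPT3],
  priority: Priority.COST, // prefer the cheapest model whose context fits the prompt
  maxConcurrentCalls: 25, // replaces the previously hard-coded APIRateLimit(25)
  addQuestions: true, // also generate the per-file questions content
};

console.log(newDefaults);
```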

src/cli/utils/LLMUtil.ts (3 additions, 3 deletions)

@@ -4,7 +4,7 @@ import { LLMModelDetails, LLMModels } from '../../types.js';
 export const models: Record<LLMModels, LLMModelDetails> = {
   [LLMModels.GPT3]: {
     name: LLMModels.GPT3,
-    inputCostPer1KTokens: 0.002,
+    inputCostPer1KTokens: 0.0015,
     outputCostPer1KTokens: 0.002,
     maxLength: 3050,
     llm: new OpenAIChat({
@@ -61,7 +61,7 @@ export const printModelDetails = (models: LLMModelDetails[]): void => {
       Failed: model.failed,
       Tokens: model.inputTokens + model.outputTokens,
       Cost:
-        (model.total / 1000) * model.inputCostPer1KTokens +
+        (model.inputTokens / 1000) * model.inputCostPer1KTokens +
         (model.outputTokens / 1000) * model.outputCostPer1KTokens,
     };
   });
@@ -95,7 +95,7 @@ export const totalIndexCostEstimate = (models: LLMModelDetails[]): number => {
   const totalCost = models.reduce((cur, model) => {
     return (
       cur +
-      (model.total / 1000) * model.inputCostPer1KTokens +
+      (model.inputTokens / 1000) * model.inputCostPer1KTokens +
       (model.outputTokens / 1000) * model.outputCostPer1KTokens
     );
   }, 0);
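The pricing change is easiest to see with numbers. Previously the input side of the estimate multiplied the call count (model.total) by the per-1K-token rate; now it uses the input tokens actually counted. A worked sketch with made-up token counts:

```ts
// Hypothetical run: 200 files summarized, ~1,500 input tokens each.
const inputTokens = 200 * 1_500; // 300,000
const outputTokens = 200 * 1_000; // processRepository books 1,000 output tokens per call
const total = 200; // number of calls

const inputCostPer1KTokens = 0.0015; // new GPT-3.5 input rate from this commit
const outputCostPer1KTokens = 0.002;

// Old formula: input cost derived from the call count, so it was badly underestimated.
const before =
  (total / 1000) * inputCostPer1KTokens +
  (outputTokens / 1000) * outputCostPer1KTokens; // ≈ 0.0003 + 0.40 ≈ $0.40

// New formula: input cost uses the real input token count.
const after =
  (inputTokens / 1000) * inputCostPer1KTokens +
  (outputTokens / 1000) * outputCostPer1KTokens; // ≈ 0.45 + 0.40 ≈ $0.85

console.log({ before, after });
```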

src/types.ts (9 additions, 1 deletion)

@@ -10,6 +10,9 @@ export type AutodocRepoConfig = {
   root: string;
   output: string;
   llms: LLMModels[];
+  priority: Priority;
+  maxConcurrentCalls: number;
+  addQuestions: boolean;
   ignore: string[];
   filePrompt: string;
   folderPrompt: string;
@@ -24,7 +27,7 @@ export type FileSummary = {
   filePath: string;
   url: string;
   summary: string;
-  questions: string;
+  questions?: string;
   checksum: string;
 };

@@ -96,3 +99,8 @@ export type LLMModelDetails = {
   failed: number;
   total: number;
 };
+
+export enum Priority {
+  COST = 'cost',
+  PERFORMANCE = 'performance',
+}
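Because questions is now optional, consumers of FileSummary need a small guard. The helper below is a hypothetical example (not part of the codebase), written as if it lived in a module alongside src/types.ts; the Priority line just shows that the enum serializes to plain strings.

```ts
import { FileSummary, Priority } from './types.js';

// Priority is a string enum, so a config file can spell it as "cost" or "performance".
const defaultPriority: Priority = Priority.COST; // === 'cost'

// Hypothetical guard for the now-optional questions field.
const renderQuestions = (file: FileSummary): string =>
  file.questions && file.questions.length > 0
    ? `## Questions:\n${file.questions}`
    : '';

console.log(defaultPriority, renderQuestions);
```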
