Skip to content

Commit 37e0f7d

Browse files
authored
Merge branch 'main' into main
2 parents 3139be6 + 1909bbc commit 37e0f7d

File tree

3 files changed

+68
-5
lines changed

3 files changed

+68
-5
lines changed

src/shared/protocol.test.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,44 @@ describe("protocol tests", () => {
7171
jest.useRealTimers();
7272
});
7373

74+
test("should not reset timeout when resetTimeoutOnProgress is false", async () => {
75+
await protocol.connect(transport);
76+
const request = { method: "example", params: {} };
77+
const mockSchema: ZodType<{ result: string }> = z.object({
78+
result: z.string(),
79+
});
80+
const onProgressMock = jest.fn();
81+
const requestPromise = protocol.request(request, mockSchema, {
82+
timeout: 1000,
83+
resetTimeoutOnProgress: false,
84+
onprogress: onProgressMock,
85+
});
86+
87+
jest.advanceTimersByTime(800);
88+
89+
if (transport.onmessage) {
90+
transport.onmessage({
91+
jsonrpc: "2.0",
92+
method: "notifications/progress",
93+
params: {
94+
progressToken: 0,
95+
progress: 50,
96+
total: 100,
97+
},
98+
});
99+
}
100+
await Promise.resolve();
101+
102+
expect(onProgressMock).toHaveBeenCalledWith({
103+
progress: 50,
104+
total: 100,
105+
});
106+
107+
jest.advanceTimersByTime(201);
108+
109+
await expect(requestPromise).rejects.toThrow("Request timed out");
110+
});
111+
74112
test("should reset timeout when progress notification is received", async () => {
75113
await protocol.connect(transport);
76114
const request = { method: "example", params: {} };

src/shared/protocol.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ type TimeoutInfo = {
103103
startTime: number;
104104
timeout: number;
105105
maxTotalTimeout?: number;
106+
resetTimeoutOnProgress: boolean;
106107
onTimeout: () => void;
107108
};
108109

@@ -184,13 +185,15 @@ export abstract class Protocol<
184185
messageId: number,
185186
timeout: number,
186187
maxTotalTimeout: number | undefined,
187-
onTimeout: () => void
188+
onTimeout: () => void,
189+
resetTimeoutOnProgress: boolean = false
188190
) {
189191
this._timeoutInfo.set(messageId, {
190192
timeoutId: setTimeout(onTimeout, timeout),
191193
startTime: Date.now(),
192194
timeout,
193195
maxTotalTimeout,
196+
resetTimeoutOnProgress,
194197
onTimeout
195198
});
196199
}
@@ -369,7 +372,9 @@ export abstract class Protocol<
369372
}
370373

371374
const responseHandler = this._responseHandlers.get(messageId);
372-
if (this._timeoutInfo.has(messageId) && responseHandler) {
375+
const timeoutInfo = this._timeoutInfo.get(messageId);
376+
377+
if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) {
373378
try {
374379
this._resetTimeout(messageId);
375380
} catch (error) {
@@ -531,7 +536,7 @@ export abstract class Protocol<
531536
{ timeout }
532537
));
533538

534-
this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler);
539+
this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false);
535540

536541
this._transport.send(jsonrpcRequest).catch((error) => {
537542
this._cleanupTimeout(messageId);

src/types.ts

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,23 @@ export const ImageContentSchema = z
666666
})
667667
.passthrough();
668668

669+
/**
670+
* An Audio provided to or from an LLM.
671+
*/
672+
export const AudioContentSchema = z
673+
.object({
674+
type: z.literal("audio"),
675+
/**
676+
* The base64-encoded audio data.
677+
*/
678+
data: z.string().base64(),
679+
/**
680+
* The MIME type of the audio. Different providers may support different audio types.
681+
*/
682+
mimeType: z.string(),
683+
})
684+
.passthrough();
685+
669686
/**
670687
* The contents of a resource, embedded into a prompt or tool call result.
671688
*/
@@ -685,6 +702,7 @@ export const PromptMessageSchema = z
685702
content: z.union([
686703
TextContentSchema,
687704
ImageContentSchema,
705+
AudioContentSchema,
688706
EmbeddedResourceSchema,
689707
]),
690708
})
@@ -753,7 +771,7 @@ export const ListToolsResultSchema = PaginatedResultSchema.extend({
753771
*/
754772
export const CallToolResultSchema = ResultSchema.extend({
755773
content: z.array(
756-
z.union([TextContentSchema, ImageContentSchema, EmbeddedResourceSchema]),
774+
z.union([TextContentSchema, ImageContentSchema, AudioContentSchema, EmbeddedResourceSchema]),
757775
),
758776
isError: z.boolean().default(false).optional(),
759777
});
@@ -877,7 +895,7 @@ export const ModelPreferencesSchema = z
877895
export const SamplingMessageSchema = z
878896
.object({
879897
role: z.enum(["user", "assistant"]),
880-
content: z.union([TextContentSchema, ImageContentSchema]),
898+
content: z.union([TextContentSchema, ImageContentSchema, AudioContentSchema]),
881899
})
882900
.passthrough();
883901

@@ -931,6 +949,7 @@ export const CreateMessageResultSchema = ResultSchema.extend({
931949
content: z.discriminatedUnion("type", [
932950
TextContentSchema,
933951
ImageContentSchema,
952+
AudioContentSchema
934953
]),
935954
});
936955

@@ -1195,6 +1214,7 @@ export type ListPromptsResult = Infer<typeof ListPromptsResultSchema>;
11951214
export type GetPromptRequest = Infer<typeof GetPromptRequestSchema>;
11961215
export type TextContent = Infer<typeof TextContentSchema>;
11971216
export type ImageContent = Infer<typeof ImageContentSchema>;
1217+
export type AudioContent = Infer<typeof AudioContentSchema>;
11981218
export type EmbeddedResource = Infer<typeof EmbeddedResourceSchema>;
11991219
export type PromptMessage = Infer<typeof PromptMessageSchema>;
12001220
export type GetPromptResult = Infer<typeof GetPromptResultSchema>;

0 commit comments

Comments
 (0)