From 37c421bb45f866e748962689ee6c5dc728253842 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Tue, 13 Aug 2024 20:27:31 -0500 Subject: [PATCH] fixes token counting for streaming Mistral Text prompts --- .../transformers/mistral-ai-to-openai.ts | 62 +++++++++++++------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts index bcbc757..817d3cf 100644 --- a/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts +++ b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts @@ -7,19 +7,25 @@ const log = logger.child({ transformer: "mistral-ai-to-openai", }); -type MistralAIStreamEvent = { +type MistralChatCompletionEvent = { choices: { index: number; message: { role: string; content: string }; stop_reason: string | null; }[]; +}; +type MistralTextCompletionEvent = { + outputs: { text: string; stop_reason: string | null }[]; +}; + +type MistralAIStreamEvent = { "amazon-bedrock-invocationMetrics"?: { inputTokenCount: number; outputTokenCount: number; invocationLatency: number; firstByteLatency: number; }; -}; +} & (MistralChatCompletionEvent | MistralTextCompletionEvent); export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => { const { data } = params; @@ -34,29 +40,49 @@ export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => { return { position: -1 }; } - const newEvent = { - id: params.fallbackId, - object: "chat.completion.chunk" as const, - created: Date.now(), - model: params.fallbackModel, - choices: [ - { - index: completionEvent.choices[0].index, - delta: { content: completionEvent.choices[0].message.content }, - finish_reason: completionEvent.choices[0].stop_reason, - }, - ], - }; + if ("choices" in completionEvent) { + const newChatEvent = { + id: params.fallbackId, + object: "chat.completion.chunk" as const, + created: Date.now(), + model: params.fallbackModel, + choices: [ + { + index: completionEvent.choices[0].index, + delta: { content: completionEvent.choices[0].message.content }, + finish_reason: completionEvent.choices[0].stop_reason, + }, + ], + }; + return { position: -1, event: newChatEvent }; + } else if ("outputs" in completionEvent) { + const newTextEvent = { + id: params.fallbackId, + object: "chat.completion.chunk" as const, + created: Date.now(), + model: params.fallbackModel, + choices: [ + { + index: 0, + delta: { content: completionEvent.outputs[0].text }, + finish_reason: completionEvent.outputs[0].stop_reason, + }, + ], + }; + return { position: -1, event: newTextEvent }; + } - return { position: -1, event: newEvent }; + // should never happen + return { position: -1 }; }; function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null { try { const parsed = JSON.parse(event.data); if ( - Array.isArray(parsed.choices) && - parsed.choices[0].message !== undefined + (Array.isArray(parsed.choices) && + parsed.choices[0].message !== undefined) || + (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) ) { return parsed; } else {