diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts
index 439766b..0fa4d51 100644
--- a/src/proxy/middleware/response/handle-streamed-response.ts
+++ b/src/proxy/middleware/response/handle-streamed-response.ts
@@ -83,7 +83,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
   // streaming JSON, etc).
   const decoder = getDecoder({ ...streamOptions, input: proxyRes });
   // Adapter consumes the decoded events and produces server-sent events so we
-  // have a standard transport for the client and to translate between API
+  // have a standard event format for the client and to translate between API
   // message formats.
   const adapter = new SSEStreamAdapter(streamOptions);
   // Transformer converts server-sent events from one vendor's API message
diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
index 800b286..cc13cd7 100644
--- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts
+++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
@@ -14,6 +14,7 @@ import {
   passthroughToOpenAI,
   StreamingCompletionTransformer,
 } from "./index";
+import { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
 
 type SSEMessageTransformerOptions = TransformOptions & {
   requestedModel: string;
@@ -130,7 +131,6 @@ function getTransformer(
 > {
   switch (responseApi) {
     case "openai":
-    case "mistral-ai":
       return passthroughToOpenAI;
     case "openai-text":
       return openAITextToOpenAIChat;
@@ -144,6 +144,8 @@
         : anthropicChatToOpenAI;
     case "google-ai":
       return googleAIToOpenAI;
+    case "mistral-ai":
+      return mistralAIToOpenAI;
    case "openai-image":
      throw new Error(`SSE transformation not supported for ${responseApi}`);
    default:
diff --git a/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts
new file mode 100644
index 0000000..bcbc757
--- /dev/null
+++ b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts
@@ -0,0 +1,70 @@
+import { logger } from "../../../../../logger";
+import { SSEResponseTransformArgs } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "mistral-ai-to-openai",
+});
+
+type MistralAIStreamEvent = {
+  choices: {
+    index: number;
+    message: { role: string; content: string };
+    stop_reason: string | null;
+  }[];
+  "amazon-bedrock-invocationMetrics"?: {
+    inputTokenCount: number;
+    outputTokenCount: number;
+    invocationLatency: number;
+    firstByteLatency: number;
+  };
+};
+
+export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  const newEvent = {
+    id: params.fallbackId,
+    object: "chat.completion.chunk" as const,
+    created: Date.now(),
+    model: params.fallbackModel,
+    choices: [
+      {
+        index: completionEvent.choices[0].index,
+        delta: { content: completionEvent.choices[0].message.content },
+        finish_reason: completionEvent.choices[0].stop_reason,
+      },
+    ],
+  };
+
+  return { position: -1, event: newEvent };
+};
+
+function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
+  try {
+    const parsed = JSON.parse(event.data);
+    if (
+      Array.isArray(parsed.choices) &&
+      parsed.choices[0].message !== undefined
+    ) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid data event");
+  }
+  return null;
+}
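
For reference, a rough sketch of the translation the new transformer performs. The incoming payload shape is assumed from the MistralAIStreamEvent type declared in the patch (not taken from Bedrock documentation), and the fallbackId/fallbackModel placeholders stand in for values supplied by the caller; other SSEResponseTransformArgs fields are omitted.

// Illustrative only: one decoded SSE data line as the transformer might receive it.
const incoming =
  '{"choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"stop_reason":null}]}';

// The OpenAI-style chunk the transformer would emit for that event.
const outgoing = {
  id: "<fallbackId>",
  object: "chat.completion.chunk",
  created: Date.now(),
  model: "<fallbackModel>",
  choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: null }],
};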