implements aws mistral streaming

This commit is contained in:
nai-degen
2024-08-13 20:04:02 -05:00
parent 2fe6e07cf5
commit e145f5757e
3 changed files with 74 additions and 2 deletions
@@ -83,7 +83,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
// streaming JSON, etc).
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
// Adapter consumes the decoded events and produces server-sent events so we
// have a standard transport for the client and to translate between API
// have a standard event format for the client and to translate between API
// message formats.
const adapter = new SSEStreamAdapter(streamOptions);
// Transformer converts server-sent events from one vendor's API message
@@ -14,6 +14,7 @@ import {
passthroughToOpenAI,
StreamingCompletionTransformer,
} from "./index";
import { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
type SSEMessageTransformerOptions = TransformOptions & {
requestedModel: string;
@@ -130,7 +131,6 @@ function getTransformer(
> {
switch (responseApi) {
case "openai":
case "mistral-ai":
return passthroughToOpenAI;
case "openai-text":
return openAITextToOpenAIChat;
@@ -144,6 +144,8 @@ function getTransformer(
: anthropicChatToOpenAI;
case "google-ai":
return googleAIToOpenAI;
case "mistral-ai":
return mistralAIToOpenAI;
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
@@ -0,0 +1,70 @@
import { logger } from "../../../../../logger";
import { SSEResponseTransformArgs } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
const log = logger.child({
module: "sse-transformer",
transformer: "mistral-ai-to-openai",
});
type MistralAIStreamEvent = {
choices: {
index: number;
message: { role: string; content: string };
stop_reason: string | null;
}[];
"amazon-bedrock-invocationMetrics"?: {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
};
export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const newEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: completionEvent.choices[0].index,
delta: { content: completionEvent.choices[0].message.content },
finish_reason: completionEvent.choices[0].stop_reason,
},
],
};
return { position: -1, event: newEvent };
};
function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data);
if (
Array.isArray(parsed.choices) &&
parsed.choices[0].message !== undefined
) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}