import { Response } from "express";
import * as http from "http";
import { RawResponseBodyHandler, decodeResponseBody } from ".";

/**
 * Consume the SSE stream and forward events to the client. Once the stream
 * is closed, resolve with the full response body so that subsequent
 * middleware can work with it.
 *
 * Typically we would only need one of the raw response handlers to execute,
 * but in the event a streamed request results in a non-200 response, we need
 * to fall back to the non-streaming response handler so that the error
 * handler can inspect the error response.
 */
export const handleStreamedResponse: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
  if (!req.isStreaming) {
    req.log.error(
      { api: req.api, key: req.key?.hash },
      `handleStreamedResponse called for non-streaming request, which isn't valid.`
    );
    throw new Error("handleStreamedResponse called for non-streaming request.");
  }

  if (proxyRes.statusCode !== 200) {
    // Ensure we use the non-streaming middleware stack since we won't be
    // getting any events.
    req.isStreaming = false;
    req.log.warn(
      `Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. Falling back to non-streaming response handler.`
    );
    return decodeResponseBody(proxyRes, req, res);
  }

  return new Promise((resolve, reject) => {
    req.log.info(
      { api: req.api, key: req.key?.hash },
      `Starting to proxy SSE stream.`
    );

    // Queued streaming requests will already have a connection open and headers
    // sent due to the heartbeat handler. In that case we can just start
    // streaming the response without sending headers.
    if (!res.headersSent) {
      res.setHeader("Content-Type", "text/event-stream");
      res.setHeader("Cache-Control", "no-cache");
      res.setHeader("Connection", "keep-alive");
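      // X-Accel-Buffering: no asks nginx-style reverse proxies not to buffer
      // the response, so events are delivered to the client as they arrive.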
res.setHeader("X-Accel-Buffering", "no");
|
|
copyHeaders(proxyRes, res);
|
|
res.flushHeaders();
|
|
}
|
|
|
|
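    // Keep a copy of every chunk so the full completion can be rebuilt once
    // the stream ends, while forwarding each chunk to the client unchanged.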
    const chunks: Buffer[] = [];
    proxyRes.on("data", (chunk) => {
      chunks.push(chunk);
      res.write(chunk);
    });

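    // When the upstream closes the stream, reassemble the buffered events
    // into a single completion object and resolve with it so subsequent
    // middleware can inspect the full body.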
proxyRes.on("end", () => {
|
|
const finalBody = convertEventsToOpenAiResponse(chunks);
|
|
req.log.info(
|
|
{ api: req.api, key: req.key?.hash },
|
|
`Finished proxying SSE stream.`
|
|
);
|
|
res.end();
|
|
resolve(finalBody);
|
|
});
|
|
proxyRes.on("error", (err) => {
|
|
req.log.error(
|
|
{ error: err, api: req.api, key: req.key?.hash },
|
|
`Error while streaming response.`
|
|
);
|
|
// OAI's spec doesn't allow for error events and clients wouldn't know
|
|
// what to do with them anyway, so we'll just send a completion event
|
|
// with the error message.
|
|
const fakeErrorEvent = {
|
|
id: "chatcmpl-error",
|
|
object: "chat.completion.chunk",
|
|
created: Date.now(),
|
|
model: "",
|
|
choices: [
|
|
{
|
|
delta: { content: "[Proxy streaming error: " + err.message + "]" },
|
|
index: 0,
|
|
finish_reason: "error",
|
|
},
|
|
],
|
|
};
|
|
res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`);
|
|
res.write("data: [DONE]\n\n");
|
|
res.end();
|
|
reject(err);
|
|
});
|
|
});
|
|
};
|
|
|
|
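// Illustrative sketch only (the dispatch below is assumed, not part of this
// module): a raw response handler like this one is typically invoked from the
// proxy's response hook, e.g. an http-proxy-middleware onProxyRes handler
// with selfHandleResponse enabled, using the streaming flag to pick a handler:
//
//   const handler = req.isStreaming ? handleStreamedResponse : decodeResponseBody;
//   const body = await handler(proxyRes, req, res);
//   // ...pass `body` on to the remaining response middleware...
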
/** Copy headers, excluding ones we're already setting for the SSE response. */
const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => {
  const toOmit = [
    "content-length",
    "content-encoding",
    "transfer-encoding",
    "content-type",
    "connection",
    "cache-control",
  ];
  for (const [key, value] of Object.entries(proxyRes.headers)) {
    if (!toOmit.includes(key) && value) {
      res.setHeader(key, value);
    }
  }
};

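/**
 * Shape of a non-streaming OpenAI chat completion, which is what the
 * reassembled body handed to subsequent middleware looks like.
 */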
type OpenAiChatCompletionResponse = {
  id: string;
  object: string;
  created: number;
  model: string;
  choices: {
    message: { role: string; content: string };
    finish_reason: string | null;
    index: number;
  }[];
};

/** Converts the event stream chunks into a single completion response. */
const convertEventsToOpenAiResponse = (chunks: Buffer[]) => {
  let response: OpenAiChatCompletionResponse = {
    id: "",
    object: "",
    created: 0,
    model: "",
    choices: [],
  };
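  // The buffered stream is a series of SSE events separated by blank lines,
  // roughly like this (fields abbreviated for illustration):
  //   data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},"index":0}]}
  //
  //   data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
  //
  //   data: {"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
  //
  //   data: [DONE]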
  const events = Buffer.concat(chunks)
    .toString()
    .trim()
    .split("\n\n")
    .map((line) => line.trim());

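  // Fold the events into one response: the first event seeds the id, model,
  // and role; later events append their delta.content; finish_reason is
  // taken from the last event processed.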
  response = events.reduce((acc, chunk, i) => {
    if (!chunk.startsWith("data: ")) {
      return acc;
    }

    if (chunk === "data: [DONE]") {
      return acc;
    }

    const data = JSON.parse(chunk.slice("data: ".length));
    if (i === 0) {
      return {
        id: data.id,
        object: data.object,
        created: data.created,
        model: data.model,
        choices: [
          {
            message: { role: data.choices[0].delta.role, content: "" },
            index: 0,
            finish_reason: null,
          },
        ],
      };
    }

    if (data.choices[0].delta.content) {
      acc.choices[0].message.content += data.choices[0].delta.content;
    }
    acc.choices[0].finish_reason = data.choices[0].finish_reason;
    return acc;
  }, response);
  return response;
};