Implement support for streamed OpenAI responses (khanon/oai-reverse-proxy!4)
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
import { config } from "../../../config";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
|
||||
/**
|
||||
* If a stream is requested, mark the request as such so the response middleware
|
||||
* knows to use the alternate EventSource response handler.
|
||||
* Kobold requests can't currently be streamed as they use a different event
|
||||
* format than the OpenAI API and we need to rewrite the events as they come in,
|
||||
* which I have not yet implemented.
|
||||
*/
|
||||
export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
|
||||
const streamableApi = req.api !== "kobold";
|
||||
if (isCompletionRequest(req) && req.body?.stream) {
|
||||
if (!streamableApi) {
|
||||
req.log.warn(
|
||||
{ api: req.api, key: req.key?.hash },
|
||||
`Streaming requested, but ${req.api} streaming is not supported.`
|
||||
);
|
||||
req.body.stream = false;
|
||||
return;
|
||||
}
|
||||
req.body.stream = config.allowStreaming;
|
||||
req.isStreaming = config.allowStreaming;
|
||||
}
|
||||
};
|
||||
@@ -1,8 +0,0 @@
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
|
||||
/** Disable token streaming as the proxy middleware doesn't support it. */
|
||||
export const disableStream: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
|
||||
if (req.method === "POST" && req.body && req.body.stream) {
|
||||
req.body.stream = false;
|
||||
}
|
||||
};
|
||||
@@ -3,13 +3,23 @@ import type { ClientRequest } from "http";
|
||||
import type { ProxyReqCallback } from "http-proxy";
|
||||
|
||||
export { addKey } from "./add-key";
|
||||
export { disableStream } from "./disable-stream";
|
||||
export { checkStreaming } from "./check-streaming";
|
||||
export { finalizeBody } from "./finalize-body";
|
||||
export { languageFilter } from "./language-filter";
|
||||
export { limitCompletions } from "./limit-completions";
|
||||
export { limitOutputTokens } from "./limit-output-tokens";
|
||||
export { transformKoboldPayload } from "./transform-kobold-payload";
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
|
||||
/** Returns true if we're making a chat completion request. */
|
||||
export function isCompletionRequest(req: Request) {
|
||||
return (
|
||||
req.method === "POST" &&
|
||||
req.path.startsWith(OPENAI_CHAT_COMPLETION_ENDPOINT)
|
||||
);
|
||||
}
|
||||
|
||||
export type ExpressHttpProxyReqCallback = ProxyReqCallback<
|
||||
ClientRequest,
|
||||
Request
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
|
||||
/** Don't allow multiple completions to be requested to prevent abuse. */
|
||||
export const limitCompletions: ExpressHttpProxyReqCallback = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
if (req.method === "POST" && req.path === OPENAI_CHAT_COMPLETION_ENDPOINT) {
|
||||
if (isCompletionRequest(req)) {
|
||||
const originalN = req.body?.n || 1;
|
||||
req.body.n = 1;
|
||||
if (originalN !== req.body.n) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
|
||||
const MAX_TOKENS = config.maxOutputTokens;
|
||||
|
||||
@@ -9,7 +9,7 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
if (req.method === "POST" && req.body?.max_tokens) {
|
||||
if (isCompletionRequest(req) && req.body?.max_tokens) {
|
||||
// convert bad or missing input to a MAX_TOKENS
|
||||
if (typeof req.body.max_tokens !== "number") {
|
||||
logger.warn(
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
import { Response } from "express";
|
||||
import * as http from "http";
|
||||
import { RawResponseBodyHandler, decodeResponseBody } from ".";
|
||||
|
||||
/**
|
||||
* Consume the SSE stream and forward events to the client. Once the stream is
|
||||
* stream is closed, resolve with the full response body so that subsequent
|
||||
* middleware can work with it.
|
||||
*
|
||||
* Typically we would only need of the raw response handlers to execute, but
|
||||
* in the event a streamed request results in a non-200 response, we need to
|
||||
* fall back to the non-streaming response handler so that the error handler
|
||||
* can inspect the error response.
|
||||
*/
|
||||
export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
proxyRes,
|
||||
req,
|
||||
res
|
||||
) => {
|
||||
if (!req.isStreaming) {
|
||||
req.log.error(
|
||||
{ api: req.api, key: req.key?.hash },
|
||||
`handleEventSource called for non-streaming request, which isn't valid.`
|
||||
);
|
||||
throw new Error("handleEventSource called for non-streaming request.");
|
||||
}
|
||||
|
||||
if (proxyRes.statusCode !== 200) {
|
||||
// Ensure we use the non-streaming middleware stack since we won't be
|
||||
// getting any events.
|
||||
req.isStreaming = false;
|
||||
req.log.warn(
|
||||
`Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. Falling back to non-streaming response handler.`
|
||||
);
|
||||
return decodeResponseBody(proxyRes, req, res);
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
req.log.info(
|
||||
{ api: req.api, key: req.key?.hash },
|
||||
`Starting to proxy SSE stream.`
|
||||
);
|
||||
res.setHeader("Content-Type", "text/event-stream");
|
||||
res.setHeader("Cache-Control", "no-cache");
|
||||
res.setHeader("Connection", "keep-alive");
|
||||
copyHeaders(proxyRes, res);
|
||||
|
||||
const chunks: Buffer[] = [];
|
||||
proxyRes.on("data", (chunk) => {
|
||||
chunks.push(chunk);
|
||||
res.write(chunk);
|
||||
});
|
||||
|
||||
proxyRes.on("end", () => {
|
||||
const finalBody = convertEventsToOpenAiResponse(chunks);
|
||||
req.log.info(
|
||||
{ api: req.api, key: req.key?.hash },
|
||||
`Finished proxying SSE stream.`
|
||||
);
|
||||
res.end();
|
||||
resolve(finalBody);
|
||||
});
|
||||
proxyRes.on("error", (err) => {
|
||||
req.log.error(
|
||||
{ error: err, api: req.api, key: req.key?.hash },
|
||||
`Error while streaming response.`
|
||||
);
|
||||
res.end();
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
/** Copy headers, excluding ones we're already setting for the SSE response. */
|
||||
const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => {
|
||||
const toOmit = [
|
||||
"content-length",
|
||||
"content-encoding",
|
||||
"transfer-encoding",
|
||||
"content-type",
|
||||
"connection",
|
||||
"cache-control",
|
||||
];
|
||||
for (const [key, value] of Object.entries(proxyRes.headers)) {
|
||||
if (!toOmit.includes(key) && value) {
|
||||
res.setHeader(key, value);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
type OpenAiChatCompletionResponse = {
|
||||
id: string;
|
||||
object: string;
|
||||
created: number;
|
||||
model: string;
|
||||
choices: {
|
||||
message: { role: string; content: string };
|
||||
finish_reason: string | null;
|
||||
index: number;
|
||||
}[];
|
||||
};
|
||||
|
||||
/** Converts the event stream chunks into a single completion response. */
|
||||
const convertEventsToOpenAiResponse = (chunks: Buffer[]) => {
|
||||
let response: OpenAiChatCompletionResponse = {
|
||||
id: "",
|
||||
object: "",
|
||||
created: 0,
|
||||
model: "",
|
||||
choices: [],
|
||||
};
|
||||
const events = Buffer.concat(chunks)
|
||||
.toString()
|
||||
.trim()
|
||||
.split("\n\n")
|
||||
.map((line) => line.trim());
|
||||
|
||||
response = events.reduce((acc, chunk, i) => {
|
||||
if (!chunk.startsWith("data: ")) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
if (chunk === "data: [DONE]") {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const data = JSON.parse(chunk.slice("data: ".length));
|
||||
if (i === 0) {
|
||||
return {
|
||||
id: data.id,
|
||||
object: data.object,
|
||||
created: data.created,
|
||||
model: data.model,
|
||||
choices: [
|
||||
{
|
||||
message: { role: data.choices[0].delta.role, content: "" },
|
||||
index: 0,
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
if (data.choices[0].delta.content) {
|
||||
acc.choices[0].message.content += data.choices[0].delta.content;
|
||||
}
|
||||
acc.choices[0].finish_reason = data.choices[0].finish_reason;
|
||||
return acc;
|
||||
}, response);
|
||||
return response;
|
||||
};
|
||||
@@ -6,6 +6,7 @@ import * as httpProxy from "http-proxy";
|
||||
import { logger } from "../../../logger";
|
||||
import { keyPool } from "../../../key-management";
|
||||
import { logPrompt } from "./log-prompt";
|
||||
import { handleStreamedResponse } from "./handle-streamed-response";
|
||||
|
||||
export const QUOTA_ROUTES = ["/v1/chat/completions"];
|
||||
const DECODER_MAP = {
|
||||
@@ -20,7 +21,11 @@ const isSupportedContentEncoding = (
|
||||
return contentEncoding in DECODER_MAP;
|
||||
};
|
||||
|
||||
type DecodeResponseBodyHandler = (
|
||||
/**
|
||||
* Either decodes or streams the entire response body and then passes it as the
|
||||
* last argument to the rest of the middleware stack.
|
||||
*/
|
||||
export type RawResponseBodyHandler = (
|
||||
proxyRes: http.IncomingMessage,
|
||||
req: Request,
|
||||
res: Response
|
||||
@@ -31,7 +36,7 @@ export type ProxyResHandlerWithBody = (
|
||||
res: Response,
|
||||
/**
|
||||
* This will be an object if the response content-type is application/json,
|
||||
* otherwise it will be a string.
|
||||
* or if the response is a streaming response. Otherwise it will be a string.
|
||||
*/
|
||||
body: string | Record<string, any>
|
||||
) => Promise<void>;
|
||||
@@ -43,6 +48,11 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[];
|
||||
* the body. Custom middleware won't execute if the response is determined to
|
||||
* be an error from the downstream service as the response will be taken over
|
||||
* by the common error handler.
|
||||
*
|
||||
* For streaming responses, the handleStream middleware will block remaining
|
||||
* middleware from executing as it consumes the stream and forwards events to
|
||||
* the client. Once the stream is closed, the finalized body will be attached
|
||||
* to res.body and the remaining middleware will execute.
|
||||
*/
|
||||
export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
|
||||
return async (
|
||||
@@ -50,25 +60,63 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
|
||||
req: Request,
|
||||
res: Response
|
||||
) => {
|
||||
let lastMiddlewareName = decodeResponseBody.name;
|
||||
try {
|
||||
const body = await decodeResponseBody(proxyRes, req, res);
|
||||
const initialHandler = req.isStreaming
|
||||
? handleStreamedResponse
|
||||
: decodeResponseBody;
|
||||
|
||||
const middlewareStack: ProxyResMiddleware = [
|
||||
handleDownstreamErrors,
|
||||
incrementKeyUsage,
|
||||
copyHttpHeaders,
|
||||
logPrompt,
|
||||
...middleware,
|
||||
];
|
||||
let lastMiddlewareName = initialHandler.name;
|
||||
|
||||
req.log.debug(
|
||||
{
|
||||
api: req.api,
|
||||
route: req.path,
|
||||
method: req.method,
|
||||
stream: req.isStreaming,
|
||||
middleware: lastMiddlewareName,
|
||||
},
|
||||
"Handling proxy response"
|
||||
);
|
||||
|
||||
try {
|
||||
const body = await initialHandler(proxyRes, req, res);
|
||||
|
||||
const middlewareStack: ProxyResMiddleware = [];
|
||||
|
||||
if (req.isStreaming) {
|
||||
// Anything that touches the response will break streaming requests so
|
||||
// certain middleware can't be used. This includes whatever API-specific
|
||||
// middleware is passed in, which isn't ideal but it's what we've got
|
||||
// for now.
|
||||
// Streamed requests will be treated as non-streaming if the upstream
|
||||
// service returns a non-200 status code, so no need to include the
|
||||
// error handler here.
|
||||
|
||||
// This is a little too easy to accidentally screw up so I need to add a
|
||||
// better way to differentiate between middleware that can be used for
|
||||
// streaming requests and those that can't. Probably a separate type
|
||||
// or function signature for streaming-compatible middleware.
|
||||
middlewareStack.push(incrementKeyUsage, logPrompt);
|
||||
} else {
|
||||
middlewareStack.push(
|
||||
handleDownstreamErrors,
|
||||
incrementKeyUsage,
|
||||
copyHttpHeaders,
|
||||
logPrompt,
|
||||
...middleware
|
||||
);
|
||||
}
|
||||
|
||||
for (const middleware of middlewareStack) {
|
||||
lastMiddlewareName = middleware.name;
|
||||
await middleware(proxyRes, req, res, body);
|
||||
}
|
||||
} catch (error: any) {
|
||||
// downstream errors will have already been responded to
|
||||
if (res.headersSent) {
|
||||
req.log.error(
|
||||
`Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`
|
||||
);
|
||||
// Either the downstream error handler got to it first, or we're mid-
|
||||
// stream and we can't do anything about it.
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -94,11 +142,19 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
|
||||
* object. Otherwise, it will be returned as a string.
|
||||
* @throws {Error} Unsupported content-encoding or invalid application/json body
|
||||
*/
|
||||
const decodeResponseBody: DecodeResponseBodyHandler = async (
|
||||
export const decodeResponseBody: RawResponseBodyHandler = async (
|
||||
proxyRes,
|
||||
req,
|
||||
res
|
||||
) => {
|
||||
if (req.isStreaming) {
|
||||
req.log.error(
|
||||
{ api: req.api, key: req.key?.hash },
|
||||
`decodeResponseBody called for a streaming request, which isn't valid.`
|
||||
);
|
||||
throw new Error("decodeResponseBody called for a streaming request.");
|
||||
}
|
||||
|
||||
const promise = new Promise<string>((resolve, reject) => {
|
||||
let chunks: Buffer[] = [];
|
||||
proxyRes.on("data", (chunk) => chunks.push(chunk));
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { config } from "../../../config";
|
||||
import { logQueue } from "../../../prompt-logging";
|
||||
import { isCompletionRequest } from "../request";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
|
||||
const COMPLETE_ENDPOINT = "/v1/chat/completions";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
@@ -18,9 +17,8 @@ export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
|
||||
// Only log prompts if we're making a request to a completion endpoint
|
||||
if (!req.path.startsWith(COMPLETE_ENDPOINT)) {
|
||||
if (!isCompletionRequest(req)) {
|
||||
// Remove this once we're confident that we're not missing any prompts
|
||||
req.log.info(
|
||||
`Not logging prompt for ${req.path} because it's not a completion endpoint`
|
||||
|
||||
Reference in New Issue
Block a user