Implement support for streamed OpenAI responses (khanon/oai-reverse-proxy!4)

This commit is contained in:
nai-degen
2023-05-01 22:01:47 +00:00
parent 2b783e0f2b
commit 176a37928d
16 changed files with 310 additions and 115 deletions
-70
View File
@@ -18,7 +18,6 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"showdown": "^2.1.0",
"simple-git": "^3.17.0",
"zlib": "^1.0.5"
},
"devDependencies": {
@@ -426,40 +425,6 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@kwsites/file-exists": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@kwsites/file-exists/-/file-exists-1.1.1.tgz",
"integrity": "sha512-m9/5YGR18lIwxSFDwfE3oA7bWuq9kdau6ugN4H2rJeyhFQZcG9AgSHkQtSD15a8WvTgfz9aikZMrKPHvbpqFiw==",
"dependencies": {
"debug": "^4.1.1"
}
},
"node_modules/@kwsites/file-exists/node_modules/debug": {
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"dependencies": {
"ms": "2.1.2"
},
"engines": {
"node": ">=6.0"
},
"peerDependenciesMeta": {
"supports-color": {
"optional": true
}
}
},
"node_modules/@kwsites/file-exists/node_modules/ms": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
"integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w=="
},
"node_modules/@kwsites/promise-deferred": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz",
"integrity": "sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw=="
},
"node_modules/@tsconfig/node10": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
@@ -2438,41 +2403,6 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/simple-git": {
"version": "3.17.0",
"resolved": "https://registry.npmjs.org/simple-git/-/simple-git-3.17.0.tgz",
"integrity": "sha512-JozI/s8jr3nvLd9yn2jzPVHnhVzt7t7QWfcIoDcqRIGN+f1IINGv52xoZti2kkYfoRhhRvzMSNPfogHMp97rlw==",
"dependencies": {
"@kwsites/file-exists": "^1.1.1",
"@kwsites/promise-deferred": "^1.1.1",
"debug": "^4.3.4"
},
"funding": {
"type": "github",
"url": "https://github.com/steveukx/git-js?sponsor=1"
}
},
"node_modules/simple-git/node_modules/debug": {
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"dependencies": {
"ms": "2.1.2"
},
"engines": {
"node": ">=6.0"
},
"peerDependenciesMeta": {
"supports-color": {
"optional": true
}
}
},
"node_modules/simple-git/node_modules/ms": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
"integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w=="
},
"node_modules/simple-update-notifier": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-1.1.0.tgz",
-1
View File
@@ -27,7 +27,6 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"showdown": "^2.1.0",
"simple-git": "^3.17.0",
"zlib": "^1.0.5"
},
"devDependencies": {
+3
View File
@@ -34,6 +34,8 @@ type Config = {
googleSheetsSpreadsheetId?: string;
/** Whether to periodically check keys for usage and validity. */
checkKeys?: boolean;
/** Whether to allow streaming completions. This is usually fine but can cause issues on some deployments. */
allowStreaming?: boolean;
};
// To change configs, create a file called .env in the root directory.
@@ -59,6 +61,7 @@ export const config: Config = {
"GOOGLE_SHEETS_SPREADSHEET_ID",
undefined
),
allowStreaming: getEnvWithDefault("ALLOW_STREAMING", true),
} as const;
export const SENSITIVE_KEYS: (keyof Config)[] = [
+1 -1
View File
@@ -52,7 +52,7 @@ function getInfoPageHtml(host: string) {
...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
keyInfo,
config: listConfig(),
sha: process.env.COMMIT_SHA?.slice(0, 7) || "dev",
commitSha: process.env.COMMIT_SHA || "dev",
};
const title = process.env.SPACE_ID
+2 -2
View File
@@ -9,7 +9,7 @@ import { logger } from "../logger";
import { ipLimiter } from "./rate-limit";
import {
addKey,
disableStream,
checkStreaming,
finalizeBody,
languageFilter,
limitOutputTokens,
@@ -39,7 +39,7 @@ const rewriteRequest = (
addKey,
transformKoboldPayload,
languageFilter,
disableStream,
checkStreaming,
limitOutputTokens,
finalizeBody,
];
@@ -0,0 +1,25 @@
import { config } from "../../../config";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
/**
 * If a stream is requested, mark the request as such so the response middleware
 * knows to use the alternate EventSource response handler.
 * Kobold requests can't currently be streamed as they use a different event
 * format than the OpenAI API and we need to rewrite the events as they come in,
 * which I have not yet implemented.
 */
export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
  // Only completion requests that explicitly asked for streaming need work.
  if (!isCompletionRequest(req) || !req.body?.stream) {
    return;
  }

  // Kobold's event format differs from OpenAI's, so streaming is
  // force-disabled for that API until event rewriting is implemented.
  if (req.api === "kobold") {
    req.log.warn(
      { api: req.api, key: req.key?.hash },
      `Streaming requested, but ${req.api} streaming is not supported.`
    );
    req.body.stream = false;
    return;
  }

  // Honor the stream request only if this deployment allows streaming.
  req.body.stream = config.allowStreaming;
  req.isStreaming = config.allowStreaming;
};
@@ -1,8 +0,0 @@
import type { ExpressHttpProxyReqCallback } from ".";
/** Disable token streaming as the proxy middleware doesn't support it. */
export const disableStream: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
  // Streaming only matters for POST bodies that actually set `stream`.
  const wantsStream = req.method === "POST" && Boolean(req.body?.stream);
  if (wantsStream) {
    req.body.stream = false;
  }
};
+11 -1
View File
@@ -3,13 +3,23 @@ import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
export { addKey } from "./add-key";
export { disableStream } from "./disable-stream";
export { checkStreaming } from "./check-streaming";
export { finalizeBody } from "./finalize-body";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { limitOutputTokens } from "./limit-output-tokens";
export { transformKoboldPayload } from "./transform-kobold-payload";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";

/** Returns true if we're making a chat completion request. */
export function isCompletionRequest(req: Request) {
  // Completion calls are always POSTs; anything else is a key check,
  // model list, etc.
  if (req.method !== "POST") {
    return false;
  }
  return req.path.startsWith(OPENAI_CHAT_COMPLETION_ENDPOINT);
}
export type ExpressHttpProxyReqCallback = ProxyReqCallback<
ClientRequest,
Request
@@ -1,13 +1,11 @@
import type { ExpressHttpProxyReqCallback } from ".";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
/** Don't allow multiple completions to be requested to prevent abuse. */
export const limitCompletions: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.method === "POST" && req.path === OPENAI_CHAT_COMPLETION_ENDPOINT) {
if (isCompletionRequest(req)) {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {
@@ -1,6 +1,6 @@
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
const MAX_TOKENS = config.maxOutputTokens;
@@ -9,7 +9,7 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.method === "POST" && req.body?.max_tokens) {
if (isCompletionRequest(req) && req.body?.max_tokens) {
// convert bad or missing input to a MAX_TOKENS
if (typeof req.body.max_tokens !== "number") {
logger.warn(
@@ -0,0 +1,151 @@
import { Response } from "express";
import * as http from "http";
import { RawResponseBodyHandler, decodeResponseBody } from ".";
/**
 * Consume the SSE stream and forward events to the client. Once the stream
 * is closed, resolve with the full response body so that subsequent
 * middleware can work with it.
 *
 * Typically we would only need one of the raw response handlers to execute,
 * but in the event a streamed request results in a non-200 response, we need
 * to fall back to the non-streaming response handler so that the error
 * handler can inspect the error response.
 */
export const handleStreamedResponse: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
  // This handler should only ever be selected for requests that the request
  // middleware flagged as streaming; anything else is a wiring bug.
  if (!req.isStreaming) {
    req.log.error(
      { api: req.api, key: req.key?.hash },
      `handleEventSource called for non-streaming request, which isn't valid.`
    );
    throw new Error("handleEventSource called for non-streaming request.");
  }

  if (proxyRes.statusCode !== 200) {
    // Ensure we use the non-streaming middleware stack since we won't be
    // getting any events.
    req.isStreaming = false;
    req.log.warn(
      `Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. Falling back to non-streaming response handler.`
    );
    return decodeResponseBody(proxyRes, req, res);
  }

  return new Promise((resolve, reject) => {
    req.log.info(
      { api: req.api, key: req.key?.hash },
      `Starting to proxy SSE stream.`
    );

    // Standard SSE response headers; remaining upstream headers are copied
    // below, minus the ones that would conflict with these (see copyHeaders).
    res.setHeader("Content-Type", "text/event-stream");
    res.setHeader("Cache-Control", "no-cache");
    res.setHeader("Connection", "keep-alive");
    copyHeaders(proxyRes, res);

    const chunks: Buffer[] = [];
    proxyRes.on("data", (chunk) => {
      // Buffer each chunk for the final body while forwarding it to the
      // client unmodified.
      chunks.push(chunk);
      res.write(chunk);
    });

    proxyRes.on("end", () => {
      // Reassemble the buffered events into a standard (non-streaming)
      // completion body so the rest of the middleware can inspect it.
      const finalBody = convertEventsToOpenAiResponse(chunks);
      req.log.info(
        { api: req.api, key: req.key?.hash },
        `Finished proxying SSE stream.`
      );
      res.end();
      resolve(finalBody);
    });

    proxyRes.on("error", (err) => {
      // Headers are likely already sent mid-stream, so all we can do is
      // terminate the client response and propagate the error.
      req.log.error(
        { error: err, api: req.api, key: req.key?.hash },
        `Error while streaming response.`
      );
      res.end();
      reject(err);
    });
  });
};
/** Copy headers, excluding ones we're already setting for the SSE response. */
const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => {
  // Headers the streaming handler sets itself, or that no longer apply once
  // the body is re-streamed to the client, must not be forwarded upstream.
  const reserved = new Set([
    "content-length",
    "content-encoding",
    "transfer-encoding",
    "content-type",
    "connection",
    "cache-control",
  ]);
  Object.entries(proxyRes.headers).forEach(([name, value]) => {
    if (value && !reserved.has(name)) {
      res.setHeader(name, value);
    }
  });
};
type OpenAiChatCompletionResponse = {
  id: string;
  object: string;
  created: number;
  model: string;
  choices: {
    message: { role: string; content: string };
    finish_reason: string | null;
    index: number;
  }[];
};

/**
 * Converts the buffered event stream chunks into a single, non-streaming
 * chat completion response body.
 *
 * The original implementation took the response metadata (id/model/role) from
 * the event at index 0; if the stream started with anything other than a data
 * event (an SSE comment, keep-alive, or leading blank element), the choices
 * array was never initialized and the next data event crashed on
 * `choices[0].message`. We now initialize on the first *data* event instead.
 *
 * @param chunks - Raw SSE body chunks as received from the upstream response.
 * @returns The reassembled completion response; all-empty fields if the
 *   stream contained no data events.
 */
const convertEventsToOpenAiResponse = (
  chunks: Buffer[]
): OpenAiChatCompletionResponse => {
  const response: OpenAiChatCompletionResponse = {
    id: "",
    object: "",
    created: 0,
    model: "",
    choices: [],
  };

  // SSE events are separated by blank lines.
  const events = Buffer.concat(chunks)
    .toString()
    .trim()
    .split("\n\n")
    .map((event) => event.trim());

  for (const event of events) {
    // Skip SSE comments/keep-alives and the stream terminator.
    if (!event.startsWith("data: ") || event === "data: [DONE]") {
      continue;
    }
    const data = JSON.parse(event.slice("data: ".length));

    // The first data event carries the response metadata and the role delta.
    if (response.choices.length === 0) {
      response.id = data.id;
      response.object = data.object;
      response.created = data.created;
      response.model = data.model;
      response.choices.push({
        message: { role: data.choices[0].delta.role ?? "assistant", content: "" },
        index: 0,
        finish_reason: null,
      });
    }

    // Accumulate content deltas; the last event's finish_reason wins.
    if (data.choices[0].delta.content) {
      response.choices[0].message.content += data.choices[0].delta.content;
    }
    response.choices[0].finish_reason = data.choices[0].finish_reason;
  }

  return response;
};
+70 -14
View File
@@ -6,6 +6,7 @@ import * as httpProxy from "http-proxy";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { logPrompt } from "./log-prompt";
import { handleStreamedResponse } from "./handle-streamed-response";
export const QUOTA_ROUTES = ["/v1/chat/completions"];
const DECODER_MAP = {
@@ -20,7 +21,11 @@ const isSupportedContentEncoding = (
return contentEncoding in DECODER_MAP;
};
type DecodeResponseBodyHandler = (
/**
* Either decodes or streams the entire response body and then passes it as the
* last argument to the rest of the middleware stack.
*/
export type RawResponseBodyHandler = (
proxyRes: http.IncomingMessage,
req: Request,
res: Response
@@ -31,7 +36,7 @@ export type ProxyResHandlerWithBody = (
res: Response,
/**
* This will be an object if the response content-type is application/json,
* otherwise it will be a string.
* or if the response is a streaming response. Otherwise it will be a string.
*/
body: string | Record<string, any>
) => Promise<void>;
@@ -43,6 +48,11 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[];
* the body. Custom middleware won't execute if the response is determined to
* be an error from the downstream service as the response will be taken over
* by the common error handler.
*
* For streaming responses, the handleStream middleware will block remaining
* middleware from executing as it consumes the stream and forwards events to
* the client. Once the stream is closed, the finalized body will be attached
* to res.body and the remaining middleware will execute.
*/
export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
return async (
@@ -50,25 +60,63 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
req: Request,
res: Response
) => {
let lastMiddlewareName = decodeResponseBody.name;
try {
const body = await decodeResponseBody(proxyRes, req, res);
const initialHandler = req.isStreaming
? handleStreamedResponse
: decodeResponseBody;
const middlewareStack: ProxyResMiddleware = [
handleDownstreamErrors,
incrementKeyUsage,
copyHttpHeaders,
logPrompt,
...middleware,
];
let lastMiddlewareName = initialHandler.name;
req.log.debug(
{
api: req.api,
route: req.path,
method: req.method,
stream: req.isStreaming,
middleware: lastMiddlewareName,
},
"Handling proxy response"
);
try {
const body = await initialHandler(proxyRes, req, res);
const middlewareStack: ProxyResMiddleware = [];
if (req.isStreaming) {
// Anything that touches the response will break streaming requests so
// certain middleware can't be used. This includes whatever API-specific
// middleware is passed in, which isn't ideal but it's what we've got
// for now.
// Streamed requests will be treated as non-streaming if the upstream
// service returns a non-200 status code, so no need to include the
// error handler here.
// This is a little too easy to accidentally screw up so I need to add a
// better way to differentiate between middleware that can be used for
// streaming requests and those that can't. Probably a separate type
// or function signature for streaming-compatible middleware.
middlewareStack.push(incrementKeyUsage, logPrompt);
} else {
middlewareStack.push(
handleDownstreamErrors,
incrementKeyUsage,
copyHttpHeaders,
logPrompt,
...middleware
);
}
for (const middleware of middlewareStack) {
lastMiddlewareName = middleware.name;
await middleware(proxyRes, req, res, body);
}
} catch (error: any) {
// downstream errors will have already been responded to
if (res.headersSent) {
req.log.error(
`Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`
);
// Either the downstream error handler got to it first, or we're mid-
// stream and we can't do anything about it.
return;
}
@@ -94,11 +142,19 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
* object. Otherwise, it will be returned as a string.
* @throws {Error} Unsupported content-encoding or invalid application/json body
*/
const decodeResponseBody: DecodeResponseBodyHandler = async (
export const decodeResponseBody: RawResponseBodyHandler = async (
proxyRes,
req,
res
) => {
if (req.isStreaming) {
req.log.error(
{ api: req.api, key: req.key?.hash },
`decodeResponseBody called for a streaming request, which isn't valid.`
);
throw new Error("decodeResponseBody called for a streaming request.");
}
const promise = new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
+2 -4
View File
@@ -1,9 +1,8 @@
import { config } from "../../../config";
import { logQueue } from "../../../prompt-logging";
import { isCompletionRequest } from "../request";
import { ProxyResHandlerWithBody } from ".";
const COMPLETE_ENDPOINT = "/v1/chat/completions";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
_proxyRes,
@@ -18,9 +17,8 @@ export const logPrompt: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
// Only log prompts if we're making a request to a completion endpoint
if (!req.path.startsWith(COMPLETE_ENDPOINT)) {
if (!isCompletionRequest(req)) {
// Remove this once we're confident that we're not missing any prompts
req.log.info(
`Not logging prompt for ${req.path} because it's not a completion endpoint`
+2 -2
View File
@@ -7,7 +7,7 @@ import { ipLimiter } from "./rate-limit";
import {
addKey,
languageFilter,
disableStream,
checkStreaming,
finalizeBody,
limitOutputTokens,
limitCompletions,
@@ -27,7 +27,7 @@ const rewriteRequest = (
const rewriterPipeline = [
addKey,
languageFilter,
disableStream,
checkStreaming,
limitOutputTokens,
limitCompletions,
finalizeBody,
+38 -6
View File
@@ -3,7 +3,7 @@ import "source-map-support/register";
import express from "express";
import cors from "cors";
import pinoHttp from "pino-http";
import { simpleGit } from "simple-git";
import childProcess from "child_process";
import { logger } from "./logger";
import { keyPool } from "./key-management";
import { proxyRouter, rewriteTavernRequests } from "./proxy/routes";
@@ -57,11 +57,43 @@ app.use((_req: unknown, res: express.Response) => {
// start server and load keys
app.listen(PORT, async () => {
try {
const git = simpleGit();
const log = git.log({ n: 1 });
const sha = (await log).latest!.hash;
process.env.COMMIT_SHA = sha;
} catch (error) {
// Huggingface seems to have changed something about how they deploy Spaces
// and git commands fail because of some ownership issue with the .git
// directory. This is a hacky workaround, but we only want to run it on
// deployed instances.
if (process.env.NODE_ENV === "production") {
childProcess.execSync("git config --global --add safe.directory /app");
}
const sha = childProcess
.execSync("git rev-parse --short HEAD")
.toString()
.trim();
const status = childProcess
.execSync("git status --porcelain")
.toString()
.trim()
// ignore Dockerfile changes since that's how the user deploys the app
.split("\n")
.filter((line: string) => !line.endsWith("Dockerfile"));
const changes = status.length > 0;
logger.info({ sha, status, changes }, "Got commit SHA and status.");
process.env.COMMIT_SHA = `${sha}${changes ? " (modified)" : ""}`;
} catch (error: any) {
logger.error(
{
error,
stdout: error.stdout.toString(),
stderr: error.stderr.toString(),
},
"Failed to get commit SHA.",
error
);
process.env.COMMIT_SHA = "unknown";
}
+1
View File
@@ -6,6 +6,7 @@ declare global {
interface Request {
key?: Key;
api: "kobold" | "openai" | "anthropic";
isStreaming?: boolean;
}
}
}