diff --git a/package-lock.json b/package-lock.json index c991bcc..4657596 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,7 +18,6 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "showdown": "^2.1.0", - "simple-git": "^3.17.0", "zlib": "^1.0.5" }, "devDependencies": { @@ -426,40 +425,6 @@ "@jridgewell/sourcemap-codec": "^1.4.10" } }, - "node_modules/@kwsites/file-exists": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@kwsites/file-exists/-/file-exists-1.1.1.tgz", - "integrity": "sha512-m9/5YGR18lIwxSFDwfE3oA7bWuq9kdau6ugN4H2rJeyhFQZcG9AgSHkQtSD15a8WvTgfz9aikZMrKPHvbpqFiw==", - "dependencies": { - "debug": "^4.1.1" - } - }, - "node_modules/@kwsites/file-exists/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", - "dependencies": { - "ms": "2.1.2" - }, - "engines": { - "node": ">=6.0" - }, - "peerDependenciesMeta": { - "supports-color": { - "optional": true - } - } - }, - "node_modules/@kwsites/file-exists/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" - }, - "node_modules/@kwsites/promise-deferred": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/@kwsites/promise-deferred/-/promise-deferred-1.1.1.tgz", - "integrity": "sha512-GaHYm+c0O9MjZRu0ongGBRbinu8gVAMd2UZjji6jVmqKtZluZnptXGWhz1E8j8D2HJ3f/yMxKAUC0b+57wncIw==" - }, "node_modules/@tsconfig/node10": { "version": "1.0.9", "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz", @@ -2438,41 +2403,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/simple-git": { - "version": "3.17.0", - "resolved": "https://registry.npmjs.org/simple-git/-/simple-git-3.17.0.tgz", - "integrity": 
"sha512-JozI/s8jr3nvLd9yn2jzPVHnhVzt7t7QWfcIoDcqRIGN+f1IINGv52xoZti2kkYfoRhhRvzMSNPfogHMp97rlw==", - "dependencies": { - "@kwsites/file-exists": "^1.1.1", - "@kwsites/promise-deferred": "^1.1.1", - "debug": "^4.3.4" - }, - "funding": { - "type": "github", - "url": "https://github.com/steveukx/git-js?sponsor=1" - } - }, - "node_modules/simple-git/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", - "dependencies": { - "ms": "2.1.2" - }, - "engines": { - "node": ">=6.0" - }, - "peerDependenciesMeta": { - "supports-color": { - "optional": true - } - } - }, - "node_modules/simple-git/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" - }, "node_modules/simple-update-notifier": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-1.1.0.tgz", diff --git a/package.json b/package.json index d027f1f..6134de1 100644 --- a/package.json +++ b/package.json @@ -27,7 +27,6 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "showdown": "^2.1.0", - "simple-git": "^3.17.0", "zlib": "^1.0.5" }, "devDependencies": { diff --git a/src/config.ts b/src/config.ts index 49e48ba..144807c 100644 --- a/src/config.ts +++ b/src/config.ts @@ -34,6 +34,8 @@ type Config = { googleSheetsSpreadsheetId?: string; /** Whether to periodically check keys for usage and validity. */ checkKeys?: boolean; + /** Whether to allow streaming completions. This is usually fine but can cause issues on some deployments. */ + allowStreaming?: boolean; }; // To change configs, create a file called .env in the root directory. 
@@ -59,6 +61,7 @@ export const config: Config = { "GOOGLE_SHEETS_SPREADSHEET_ID", undefined ), + allowStreaming: getEnvWithDefault("ALLOW_STREAMING", true), } as const; export const SENSITIVE_KEYS: (keyof Config)[] = [ diff --git a/src/info-page.ts b/src/info-page.ts index 4a2f659..008107b 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -52,7 +52,7 @@ function getInfoPageHtml(host: string) { ...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}), keyInfo, config: listConfig(), - sha: process.env.COMMIT_SHA?.slice(0, 7) || "dev", + commitSha: process.env.COMMIT_SHA || "dev", }; const title = process.env.SPACE_ID diff --git a/src/proxy/kobold.ts b/src/proxy/kobold.ts index 0a9e743..732b924 100644 --- a/src/proxy/kobold.ts +++ b/src/proxy/kobold.ts @@ -9,7 +9,7 @@ import { logger } from "../logger"; import { ipLimiter } from "./rate-limit"; import { addKey, - disableStream, + checkStreaming, finalizeBody, languageFilter, limitOutputTokens, @@ -39,7 +39,7 @@ const rewriteRequest = ( addKey, transformKoboldPayload, languageFilter, - disableStream, + checkStreaming, limitOutputTokens, finalizeBody, ]; diff --git a/src/proxy/middleware/request/check-streaming.ts b/src/proxy/middleware/request/check-streaming.ts new file mode 100644 index 0000000..858d4f8 --- /dev/null +++ b/src/proxy/middleware/request/check-streaming.ts @@ -0,0 +1,25 @@ +import { config } from "../../../config"; +import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; + +/** + * If a stream is requested, mark the request as such so the response middleware + * knows to use the alternate EventSource response handler. + * Kobold requests can't currently be streamed as they use a different event + * format than the OpenAI API and we need to rewrite the events as they come in, + * which I have not yet implemented. 
+ */ +export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => { + const streamableApi = req.api !== "kobold"; + if (isCompletionRequest(req) && req.body?.stream) { + if (!streamableApi) { + req.log.warn( + { api: req.api, key: req.key?.hash }, + `Streaming requested, but ${req.api} streaming is not supported.` + ); + req.body.stream = false; + return; + } + req.body.stream = config.allowStreaming; + req.isStreaming = config.allowStreaming; + } +}; diff --git a/src/proxy/middleware/request/disable-stream.ts b/src/proxy/middleware/request/disable-stream.ts deleted file mode 100644 index 8b49a8c..0000000 --- a/src/proxy/middleware/request/disable-stream.ts +++ /dev/null @@ -1,8 +0,0 @@ -import type { ExpressHttpProxyReqCallback } from "."; - -/** Disable token streaming as the proxy middleware doesn't support it. */ -export const disableStream: ExpressHttpProxyReqCallback = (_proxyReq, req) => { - if (req.method === "POST" && req.body && req.body.stream) { - req.body.stream = false; - } -}; diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index ffd01b5..a8a4c00 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -3,13 +3,23 @@ import type { ClientRequest } from "http"; import type { ProxyReqCallback } from "http-proxy"; export { addKey } from "./add-key"; -export { disableStream } from "./disable-stream"; +export { checkStreaming } from "./check-streaming"; export { finalizeBody } from "./finalize-body"; export { languageFilter } from "./language-filter"; export { limitCompletions } from "./limit-completions"; export { limitOutputTokens } from "./limit-output-tokens"; export { transformKoboldPayload } from "./transform-kobold-payload"; +const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"; + +/** Returns true if we're making a chat completion request. 
*/ +export function isCompletionRequest(req: Request) { + return ( + req.method === "POST" && + req.path.startsWith(OPENAI_CHAT_COMPLETION_ENDPOINT) + ); +} + export type ExpressHttpProxyReqCallback = ProxyReqCallback< ClientRequest, Request diff --git a/src/proxy/middleware/request/limit-completions.ts b/src/proxy/middleware/request/limit-completions.ts index 32883b4..0261b06 100644 --- a/src/proxy/middleware/request/limit-completions.ts +++ b/src/proxy/middleware/request/limit-completions.ts @@ -1,13 +1,11 @@ -import type { ExpressHttpProxyReqCallback } from "."; - -const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"; +import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; /** Don't allow multiple completions to be requested to prevent abuse. */ export const limitCompletions: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { - if (req.method === "POST" && req.path === OPENAI_CHAT_COMPLETION_ENDPOINT) { + if (isCompletionRequest(req)) { const originalN = req.body?.n || 1; req.body.n = 1; if (originalN !== req.body.n) { diff --git a/src/proxy/middleware/request/limit-output-tokens.ts b/src/proxy/middleware/request/limit-output-tokens.ts index 8329ac2..91f91e1 100644 --- a/src/proxy/middleware/request/limit-output-tokens.ts +++ b/src/proxy/middleware/request/limit-output-tokens.ts @@ -1,6 +1,6 @@ import { config } from "../../../config"; import { logger } from "../../../logger"; -import type { ExpressHttpProxyReqCallback } from "."; +import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; const MAX_TOKENS = config.maxOutputTokens; @@ -9,7 +9,7 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { - if (req.method === "POST" && req.body?.max_tokens) { + if (isCompletionRequest(req) && req.body?.max_tokens) { // convert bad or missing input to a MAX_TOKENS if (typeof req.body.max_tokens !== "number") { logger.warn( diff --git a/src/proxy/middleware/response/handle-streamed-response.ts 
b/src/proxy/middleware/response/handle-streamed-response.ts new file mode 100644 index 0000000..c08c335 --- /dev/null +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -0,0 +1,151 @@ +import { Response } from "express"; +import * as http from "http"; +import { RawResponseBodyHandler, decodeResponseBody } from "."; + +/** + * Consume the SSE stream and forward events to the client. Once the stream + * is closed, resolve with the full response body so that subsequent + * middleware can work with it. + * + * Typically we would only need one of the raw response handlers to execute, but + * in the event a streamed request results in a non-200 response, we need to + * fall back to the non-streaming response handler so that the error handler + * can inspect the error response. + */ +export const handleStreamedResponse: RawResponseBodyHandler = async ( + proxyRes, + req, + res +) => { + if (!req.isStreaming) { + req.log.error( + { api: req.api, key: req.key?.hash }, + `handleEventSource called for non-streaming request, which isn't valid.` + ); + throw new Error("handleEventSource called for non-streaming request."); + } + + if (proxyRes.statusCode !== 200) { + // Ensure we use the non-streaming middleware stack since we won't be + // getting any events. + req.isStreaming = false; + req.log.warn( + `Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. 
Falling back to non-streaming response handler.` + ); + return decodeResponseBody(proxyRes, req, res); + } + + return new Promise((resolve, reject) => { + req.log.info( + { api: req.api, key: req.key?.hash }, + `Starting to proxy SSE stream.` + ); + res.setHeader("Content-Type", "text/event-stream"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + copyHeaders(proxyRes, res); + + const chunks: Buffer[] = []; + proxyRes.on("data", (chunk) => { + chunks.push(chunk); + res.write(chunk); + }); + + proxyRes.on("end", () => { + const finalBody = convertEventsToOpenAiResponse(chunks); + req.log.info( + { api: req.api, key: req.key?.hash }, + `Finished proxying SSE stream.` + ); + res.end(); + resolve(finalBody); + }); + proxyRes.on("error", (err) => { + req.log.error( + { error: err, api: req.api, key: req.key?.hash }, + `Error while streaming response.` + ); + res.end(); + reject(err); + }); + }); +}; + +/** Copy headers, excluding ones we're already setting for the SSE response. */ +const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => { + const toOmit = [ + "content-length", + "content-encoding", + "transfer-encoding", + "content-type", + "connection", + "cache-control", + ]; + for (const [key, value] of Object.entries(proxyRes.headers)) { + if (!toOmit.includes(key) && value) { + res.setHeader(key, value); + } + } +}; + +type OpenAiChatCompletionResponse = { + id: string; + object: string; + created: number; + model: string; + choices: { + message: { role: string; content: string }; + finish_reason: string | null; + index: number; + }[]; +}; + +/** Converts the event stream chunks into a single completion response. 
*/ +const convertEventsToOpenAiResponse = (chunks: Buffer[]) => { + let response: OpenAiChatCompletionResponse = { + id: "", + object: "", + created: 0, + model: "", + choices: [], + }; + const events = Buffer.concat(chunks) + .toString() + .trim() + .split("\n\n") + .map((line) => line.trim()); + + response = events.reduce((acc, chunk, i) => { + if (!chunk.startsWith("data: ")) { + return acc; + } + + if (chunk === "data: [DONE]") { + return acc; + } + + const data = JSON.parse(chunk.slice("data: ".length)); + if (i === 0) { + return { + id: data.id, + object: data.object, + created: data.created, + model: data.model, + choices: [ + { + message: { role: data.choices[0].delta.role, content: "" }, + index: 0, + finish_reason: null, + }, + ], + }; + } + + if (data.choices[0].delta.content) { + acc.choices[0].message.content += data.choices[0].delta.content; + } + acc.choices[0].finish_reason = data.choices[0].finish_reason; + return acc; + }, response); + return response; +}; diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 48e5e85..547fe71 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -6,6 +6,7 @@ import * as httpProxy from "http-proxy"; import { logger } from "../../../logger"; import { keyPool } from "../../../key-management"; import { logPrompt } from "./log-prompt"; +import { handleStreamedResponse } from "./handle-streamed-response"; export const QUOTA_ROUTES = ["/v1/chat/completions"]; const DECODER_MAP = { @@ -20,7 +21,11 @@ const isSupportedContentEncoding = ( return contentEncoding in DECODER_MAP; }; -type DecodeResponseBodyHandler = ( +/** + * Either decodes or streams the entire response body and then passes it as the + * last argument to the rest of the middleware stack. 
+ */ +export type RawResponseBodyHandler = ( proxyRes: http.IncomingMessage, req: Request, res: Response @@ -31,7 +36,7 @@ export type ProxyResHandlerWithBody = ( res: Response, /** * This will be an object if the response content-type is application/json, - * otherwise it will be a string. + * or if the response is a streaming response. Otherwise it will be a string. */ body: string | Record ) => Promise; @@ -43,6 +48,11 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[]; * the body. Custom middleware won't execute if the response is determined to * be an error from the downstream service as the response will be taken over * by the common error handler. + * + * For streaming responses, the handleStream middleware will block remaining + * middleware from executing as it consumes the stream and forwards events to + * the client. Once the stream is closed, the finalized body will be attached + * to res.body and the remaining middleware will execute. */ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { return async ( @@ -50,25 +60,63 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { req: Request, res: Response ) => { - let lastMiddlewareName = decodeResponseBody.name; - try { - const body = await decodeResponseBody(proxyRes, req, res); + const initialHandler = req.isStreaming + ? 
handleStreamedResponse + : decodeResponseBody; - const middlewareStack: ProxyResMiddleware = [ - handleDownstreamErrors, - incrementKeyUsage, - copyHttpHeaders, - logPrompt, - ...middleware, - ]; + let lastMiddlewareName = initialHandler.name; + + req.log.debug( + { + api: req.api, + route: req.path, + method: req.method, + stream: req.isStreaming, + middleware: lastMiddlewareName, + }, + "Handling proxy response" + ); + + try { + const body = await initialHandler(proxyRes, req, res); + + const middlewareStack: ProxyResMiddleware = []; + + if (req.isStreaming) { + // Anything that touches the response will break streaming requests so + // certain middleware can't be used. This includes whatever API-specific + // middleware is passed in, which isn't ideal but it's what we've got + // for now. + // Streamed requests will be treated as non-streaming if the upstream + // service returns a non-200 status code, so no need to include the + // error handler here. + + // This is a little too easy to accidentally screw up so I need to add a + // better way to differentiate between middleware that can be used for + // streaming requests and those that can't. Probably a separate type + // or function signature for streaming-compatible middleware. + middlewareStack.push(incrementKeyUsage, logPrompt); + } else { + middlewareStack.push( + handleDownstreamErrors, + incrementKeyUsage, + copyHttpHeaders, + logPrompt, + ...middleware + ); + } for (const middleware of middlewareStack) { lastMiddlewareName = middleware.name; await middleware(proxyRes, req, res, body); } } catch (error: any) { - // downstream errors will have already been responded to if (res.headersSent) { + req.log.error( + `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})` + ); + // Either the downstream error handler got to it first, or we're mid- + // stream and we can't do anything about it. 
return; } @@ -94,11 +142,19 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { * object. Otherwise, it will be returned as a string. * @throws {Error} Unsupported content-encoding or invalid application/json body */ -const decodeResponseBody: DecodeResponseBodyHandler = async ( +export const decodeResponseBody: RawResponseBodyHandler = async ( proxyRes, req, res ) => { + if (req.isStreaming) { + req.log.error( + { api: req.api, key: req.key?.hash }, + `decodeResponseBody called for a streaming request, which isn't valid.` + ); + throw new Error("decodeResponseBody called for a streaming request."); + } + const promise = new Promise((resolve, reject) => { let chunks: Buffer[] = []; proxyRes.on("data", (chunk) => chunks.push(chunk)); diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts index 3d07c09..32ec1af 100644 --- a/src/proxy/middleware/response/log-prompt.ts +++ b/src/proxy/middleware/response/log-prompt.ts @@ -1,9 +1,8 @@ import { config } from "../../../config"; import { logQueue } from "../../../prompt-logging"; +import { isCompletionRequest } from "../request"; import { ProxyResHandlerWithBody } from "."; -const COMPLETE_ENDPOINT = "/v1/chat/completions"; - /** If prompt logging is enabled, enqueues the prompt for logging. 
*/ export const logPrompt: ProxyResHandlerWithBody = async ( _proxyRes, @@ -18,9 +17,8 @@ export const logPrompt: ProxyResHandlerWithBody = async ( throw new Error("Expected body to be an object"); } - // Only log prompts if we're making a request to a completion endpoint - if (!req.path.startsWith(COMPLETE_ENDPOINT)) { + if (!isCompletionRequest(req)) { // Remove this once we're confident that we're not missing any prompts req.log.info( `Not logging prompt for ${req.path} because it's not a completion endpoint` diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 346550d..660b36c 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -7,7 +7,7 @@ import { ipLimiter } from "./rate-limit"; import { addKey, languageFilter, - disableStream, + checkStreaming, finalizeBody, limitOutputTokens, limitCompletions, @@ -27,7 +27,7 @@ const rewriteRequest = ( const rewriterPipeline = [ addKey, languageFilter, - disableStream, + checkStreaming, limitOutputTokens, limitCompletions, finalizeBody, diff --git a/src/server.ts b/src/server.ts index 1ca892c..8cd0bec 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,7 +3,7 @@ import "source-map-support/register"; import express from "express"; import cors from "cors"; import pinoHttp from "pino-http"; -import { simpleGit } from "simple-git"; +import childProcess from "child_process"; import { logger } from "./logger"; import { keyPool } from "./key-management"; import { proxyRouter, rewriteTavernRequests } from "./proxy/routes"; @@ -57,11 +57,43 @@ app.use((_req: unknown, res: express.Response) => { // start server and load keys app.listen(PORT, async () => { try { - const git = simpleGit(); - const log = git.log({ n: 1 }); - const sha = (await log).latest!.hash; - process.env.COMMIT_SHA = sha; - } catch (error) { + // Huggingface seems to have changed something about how they deploy Spaces + // and git commands fail because of some ownership issue with the .git + // directory. 
This is a hacky workaround, but we only want to run it on + // deployed instances. + + if (process.env.NODE_ENV === "production") { + childProcess.execSync("git config --global --add safe.directory /app"); + } + + const sha = childProcess + .execSync("git rev-parse --short HEAD") + .toString() + .trim(); + + const status = childProcess + .execSync("git status --porcelain") + .toString() + .trim() + // ignore Dockerfile changes since that's how the user deploys the app + .split("\n") + .filter((line: string) => !line.endsWith("Dockerfile")); + + const changes = status.length > 0; + + logger.info({ sha, status, changes }, "Got commit SHA and status."); + + process.env.COMMIT_SHA = `${sha}${changes ? " (modified)" : ""}`; + } catch (error: any) { + logger.error( + { + error, + stdout: error.stdout.toString(), + stderr: error.stderr.toString(), + }, + "Failed to get commit SHA.", + error + ); process.env.COMMIT_SHA = "unknown"; } diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index 4be7b28..76f133d 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -6,6 +6,7 @@ declare global { interface Request { key?: Key; api: "kobold" | "openai" | "anthropic"; + isStreaming?: boolean; } } }