From 6723cbf6620f9dc2371d82e2dd48992c38794a8b Mon Sep 17 00:00:00 2001 From: khanon <44111-khanon@users.noreply.gitgud.io> Date: Tue, 30 May 2023 03:13:17 +0000 Subject: [PATCH] Anthropic endpoint improvements (khanon/oai-reverse-proxy!16) --- .env.example | 1 + README.md | 2 +- src/proxy/anthropic.ts | 25 ++++--- src/proxy/kobold.ts | 9 ++- src/proxy/middleware/request/add-key.ts | 25 ++++--- .../middleware/request/limit-output-tokens.ts | 6 +- .../request/transform-kobold-payload.ts | 5 +- .../request/transform-outbound-payload.ts | 65 +++++++++++------- .../response/handle-streamed-response.ts | 66 +++++++++---------- src/proxy/middleware/response/index.ts | 15 ++--- src/proxy/middleware/response/log-prompt.ts | 37 +++++++---- src/proxy/openai.ts | 16 +++-- src/proxy/queue.ts | 38 +++++------ src/proxy/routes.ts | 23 +++---- src/types/custom.d.ts | 12 ++-- 15 files changed, 192 insertions(+), 153 deletions(-) diff --git a/.env.example b/.env.example index 6a738fd..982b055 100644 --- a/.env.example +++ b/.env.example @@ -33,6 +33,7 @@ # You can add multiple keys by separating them with a comma. OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # You can require a Bearer token for requests when using proxy_token gatekeeper. # PROXY_KEY=your-secret-key diff --git a/README.md b/README.md index a13a03a..5ca7dec 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # OAI Reverse Proxy -Reverse proxy server for the OpenAI (and soon Anthropic) APIs. Forwards text generation requests while rejecting administrative/billing requests. Includes optional rate limiting and prompt filtering to prevent abuse. +Reverse proxy server for the OpenAI and Anthropic APIs. Forwards text generation requests while rejecting administrative/billing requests. Includes optional rate limiting and prompt filtering to prevent abuse. ### Table of Contents - [What is this?](#what-is-this) diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index 585764f..707606e 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -16,13 +16,13 @@ import { handleInternalError, } from "./middleware/response"; import { createQueueMiddleware } from "./queue"; +import { setApiFormat } from "./routes"; const rewriteAnthropicRequest = ( proxyReq: http.ClientRequest, req: Request, res: http.ServerResponse ) => { - req.api = "anthropic"; const rewriterPipeline = [ addKey, languageFilter, @@ -107,14 +107,14 @@ const anthropicProxy = createProxyMiddleware({ selfHandleResponse: true, logger, pathRewrite: { - // If the user sends a request to /v1/chat/completions (the OpenAI endpoint) - // we will transform the payload and rewrite the path to /v1/complete. + // Send OpenAI-compat requests to the real Anthropic endpoint. "^/v1/chat/completions": "/v1/complete", }, }); const queuedAnthropicProxy = createQueueMiddleware(anthropicProxy); const anthropicRouter = Router(); +// Fix paths because clients don't consistently use the /v1 prefix. anthropicRouter.use((req, _res, next) => { if (!req.path.startsWith("/v1/")) { req.url = `/v1${req.url}`; @@ -124,10 +124,17 @@ anthropicRouter.use((req, _res, next) => { anthropicRouter.get("/v1/models", (req, res) => { res.json(buildFakeModelsResponse()); }); -anthropicRouter.post("/v1/complete", queuedAnthropicProxy); -// This is the OpenAI endpoint, to let users send OpenAI-formatted requests -// to the Anthropic API. We need to rewrite them first. 
-anthropicRouter.post("/v1/chat/completions", queuedAnthropicProxy); +anthropicRouter.post( + "/v1/complete", + setApiFormat({ in: "anthropic", out: "anthropic" }), + queuedAnthropicProxy +); +// OpenAI-to-Anthropic compatibility endpoint. +anthropicRouter.post( + "/v1/chat/completions", + setApiFormat({ in: "openai", out: "anthropic" }), + queuedAnthropicProxy +); // Redirect browser requests to the homepage. anthropicRouter.get("*", (req, res, next) => { const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); @@ -163,9 +170,7 @@ function buildFakeModelsResponse() { parent: null, })); - return { - models, - }; + return { models }; } export const anthropic = anthropicRouter; diff --git a/src/proxy/kobold.ts b/src/proxy/kobold.ts index 2337558..93124fd 100644 --- a/src/proxy/kobold.ts +++ b/src/proxy/kobold.ts @@ -7,6 +7,7 @@ import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { logger } from "../logger"; import { ipLimiter } from "./rate-limit"; +import { setApiFormat } from "./routes"; import { addKey, finalizeBody, @@ -39,7 +40,6 @@ const rewriteRequest = ( return; } - req.api = "kobold"; req.body.stream = false; const rewriterPipeline = [ addKey, @@ -100,7 +100,12 @@ const koboldOaiProxy = createProxyMiddleware({ const koboldRouter = Router(); koboldRouter.get("/api/v1/model", handleModelRequest); koboldRouter.get("/api/v1/config/soft_prompts_list", handleSoftPromptsRequest); -koboldRouter.post("/api/v1/generate", ipLimiter, koboldOaiProxy); +koboldRouter.post( + "/api/v1/generate", + setApiFormat({ in: "kobold", out: "openai" }), + ipLimiter, + koboldOaiProxy +); koboldRouter.use((req, res) => { logger.warn(`Unhandled kobold request: ${req.method} ${req.path}`); res.status(404).json({ error: "Not found" }); diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index b58f0d4..a7fc12f 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -9,9 +9,22 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => { // Horrible, horrible hack to stop the proxy from complaining about clients // not sending a model when they are requesting the list of models (which // requires a key, but obviously not a model). + // TODO: shouldn't even proxy /models to the upstream API, just fake it + // using the models our key pool has available. req.body.model = "gpt-3.5-turbo"; } + if (!req.inboundApi || !req.outboundApi) { + const err = new Error( + "Request API format missing. Did you forget to add the `setApiFormat` middleware to your route?" + ); + req.log.error( + { in: req.inboundApi, out: req.outboundApi, path: req.path }, + err.message + ); + throw err; + } + if (!req.body?.model) { throw new Error("You must specify a model with your request."); } @@ -25,14 +38,8 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => { // the requested model is an OpenAI one even though we're actually sending // an Anthropic request. // For such cases, ignore the requested model entirely. 
-  // Real Anthropic requests come in via /proxy/anthropic/v1/complete
-  // The OpenAI-compatible endpoint is /proxy/anthropic/v1/chat/completions
-
-  const openaiCompatible =
-    req.originalUrl === "/proxy/anthropic/v1/chat/completions";
-  if (openaiCompatible) {
+  if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
     req.log.debug("Using an Anthropic key for an OpenAI-compatible request");
-    req.api = "openai";
     // We don't assign the model here, that will happen when transforming the
     // request body.
     assignedKey = keyPool.get("claude-v1");
@@ -45,8 +52,8 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
     {
       key: assignedKey.hash,
       model: req.body?.model,
-      fromApi: req.api,
-      toApi: assignedKey.service,
+      fromApi: req.inboundApi,
+      toApi: req.outboundApi,
     },
     "Assigned key to request"
   );
diff --git a/src/proxy/middleware/request/limit-output-tokens.ts b/src/proxy/middleware/request/limit-output-tokens.ts
index 2e1c9e1..1d4efe3 100644
--- a/src/proxy/middleware/request/limit-output-tokens.ts
+++ b/src/proxy/middleware/request/limit-output-tokens.ts
@@ -10,7 +10,7 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
   req
 ) => {
   if (isCompletionRequest(req) && req.body?.max_tokens) {
-    const requestedMaxTokens = getMaxTokensFromRequest(req);
+    const requestedMaxTokens = Number.parseInt(getMaxTokensFromRequest(req));
     let maxTokens = requestedMaxTokens;
 
-    if (typeof requestedMaxTokens !== "number") {
+    if (Number.isNaN(requestedMaxTokens)) {
@@ -24,9 +24,9 @@
     // TODO: this is not going to scale well, need to implement a better way
     // of translating request parameters from one API to another.
     maxTokens = Math.min(maxTokens, MAX_TOKENS);
-    if (req.key!.service === "openai") {
+    if (req.outboundApi === "openai") {
       req.body.max_tokens = maxTokens;
-    } else if (req.key!.service === "anthropic") {
+    } else if (req.outboundApi === "anthropic") {
       req.body.max_tokens_to_sample = maxTokens;
     }
diff --git a/src/proxy/middleware/request/transform-kobold-payload.ts b/src/proxy/middleware/request/transform-kobold-payload.ts
index e0c14f7..620caa4 100644
--- a/src/proxy/middleware/request/transform-kobold-payload.ts
+++ b/src/proxy/middleware/request/transform-kobold-payload.ts
@@ -1,7 +1,8 @@
 /**
  * Transforms a KoboldAI payload into an OpenAI payload.
  * @deprecated Kobold input format isn't supported anymore as all popular
- * frontends support reverse proxies or changing their base URL.
+ * frontends support reverse proxies or changing their base URL. It adds too
+ * many edge cases to be worth maintaining and doesn't work with newer features.
*/ import { logger } from "../../../logger"; import type { ExpressHttpProxyReqCallback } from "."; @@ -68,7 +69,7 @@ export const transformKoboldPayload: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { - if (req.api !== "kobold") { + if (req.inboundApi !== "kobold") { throw new Error("transformKoboldPayload called for non-kobold request."); } diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts index 627f762..6ce148b 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -4,33 +4,40 @@ import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; // https://console.anthropic.com/docs/api/reference#-v1-complete const AnthropicV1CompleteSchema = z.object({ - model: z.string().regex(/^claude-/), - prompt: z.string(), - max_tokens_to_sample: z.number(), + model: z.string().regex(/^claude-/, "Model must start with 'claude-'"), + prompt: z.string({ + required_error: + "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?", + }), + max_tokens_to_sample: z.coerce.number(), stop_sequences: z.array(z.string()).optional(), stream: z.boolean().optional().default(false), - temperature: z.number().optional().default(1), - top_k: z.number().optional().default(-1), - top_p: z.number().optional().default(-1), + temperature: z.coerce.number().optional().default(1), + top_k: z.coerce.number().optional().default(-1), + top_p: z.coerce.number().optional().default(-1), metadata: z.any().optional(), }); // https://platform.openai.com/docs/api-reference/chat/create const OpenAIV1ChatCompletionSchema = z.object({ - model: z.string().regex(/^gpt/), + model: z.string().regex(/^gpt/, "Model must start with 'gpt-'"), messages: z.array( z.object({ role: z.enum(["system", "user", "assistant"]), content: z.string(), name: z.string().optional(), - }) + }), + { + required_error: + "No prompt found. Are you sending an Anthropic-formatted request to the OpenAI endpoint?", + } ), temperature: z.number().optional().default(1), top_p: z.number().optional().default(1), n: z.literal(1).optional(), stream: z.boolean().optional().default(false), stop: z.union([z.string(), z.array(z.string())]).optional(), - max_tokens: z.number().optional(), + max_tokens: z.coerce.number().optional(), frequency_penalty: z.number().optional().default(0), presence_penalty: z.number().optional().default(0), logit_bias: z.any().optional(), @@ -42,39 +49,47 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { - if (req.retryCount > 0 || !isCompletionRequest(req)) { + const sameService = req.inboundApi === req.outboundApi; + const alreadyTransformed = req.retryCount > 0; + const notTransformable = !isCompletionRequest(req); + + if (alreadyTransformed || notTransformable) { return; } - const inboundService = req.api; - const outboundService = req.key!.service; - - if (inboundService === outboundService) { + if (sameService) { + // Just validate, don't transform. + const validator = + req.outboundApi === "openai" + ? OpenAIV1ChatCompletionSchema + : AnthropicV1CompleteSchema; + const result = validator.safeParse(req.body); + if (!result.success) { + req.log.error( + { issues: result.error.issues, params: req.body }, + "Request validation failed" + ); + throw result.error; + } return; } - // Not supported yet and unnecessary as everything supports OpenAI. 
-  if (inboundService === "anthropic" && outboundService === "openai") {
-    throw new Error(
-      "Anthropic -> OpenAI request transformation not supported. Provide an OpenAI-compatible payload, or use the /claude endpoint."
-    );
-  }
-
-  if (inboundService === "openai" && outboundService === "anthropic") {
+  if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
     req.body = openaiToAnthropic(req.body, req);
     return;
   }
 
   throw new Error(
-    `Unsupported transformation: ${inboundService} -> ${outboundService}`
+    `'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
   );
 };
 
 function openaiToAnthropic(body: any, req: Request) {
   const result = OpenAIV1ChatCompletionSchema.safeParse(body);
   if (!result.success) {
-    // don't log the prompt
-    const { messages, ...params } = body;
+    // don't log the prompt (usually `messages` but maybe `prompt` if the user
+    // misconfigured their client)
+    const { messages, prompt, ...params } = body;
     req.log.error(
       { issues: result.error.issues, params },
       "Invalid OpenAI-to-Anthropic request"
diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts
index a218356..834dfcf 100644
--- a/src/proxy/middleware/response/handle-streamed-response.ts
+++ b/src/proxy/middleware/response/handle-streamed-response.ts
@@ -48,31 +48,28 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
   // If these differ, the user is using the OpenAI-compatible endpoint, so
   // we need to translate the SSE events into OpenAI completion events for their
   // frontend.
-  const fromApi = req.api;
-  const toApi = req.key!.service;
   if (!req.isStreaming) {
-    req.log.error(
-      { api: req.api, key: req.key?.hash },
-      `handleStreamedResponse called for non-streaming request, which isn't valid.`
+    const err = new Error(
+      "handleStreamedResponse called for non-streaming request."
     );
-    throw new Error("handleStreamedResponse called for non-streaming request.");
+    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
+    throw err;
   }
+  const key = req.key!;
   if (proxyRes.statusCode !== 200) {
     // Ensure we use the non-streaming middleware stack since we won't be
     // getting any events.
     req.isStreaming = false;
     req.log.warn(
-      `Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. Falling back to non-streaming response handler.`
+      { statusCode: proxyRes.statusCode, key: key.hash },
+      `Streaming request returned error status code. Falling back to non-streaming response handler.`
     );
     return decodeResponseBody(proxyRes, req, res);
   }
 
   return new Promise((resolve, reject) => {
-    req.log.info(
-      { api: req.api, key: req.key?.hash },
-      `Starting to proxy SSE stream.`
-    );
+    req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
 
     // Queued streaming requests will already have a connection open and headers
     // sent due to the heartbeat handler. In that case we can just start
@@ -105,9 +102,9 @@
     proxyRes.on(
       "data",
       withErrorHandling((chunk) => {
-        // We may receive multiple (or partial) SSE messages in a single chunk, so
-        // we need to buffer and emit seperate stream events for full messages so
-        // we can parse/transform them properly.
+        // We may receive multiple (or partial) SSE messages in a single chunk,
+        // so we need to buffer and emit separate stream events for full
+        // messages so we can parse/transform them properly.
         const str = chunk.toString();
         chunkBuffer.push(str);
@@ -126,12 +123,12 @@
     proxyRes.on(
       "full-sse-event",
       withErrorHandling((data) => {
-        const { event, position } = transformEvent(
+        const { event, position } = transformEvent({
           data,
-          fromApi,
-          toApi,
-          lastPosition
-        );
+          requestApi: req.inboundApi,
+          responseApi: req.outboundApi,
+          lastPosition,
+        });
         fullChunks.push(event);
         lastPosition = position;
         res.write(event + "\n\n");
@@ -142,20 +139,14 @@
       "end",
       withErrorHandling(() => {
        let finalBody = convertEventsToFinalResponse(fullChunks, req);
-        req.log.info(
-          { api: req.api, key: req.key?.hash },
-          `Finished proxying SSE stream.`
-        );
+        req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
         res.end();
         resolve(finalBody);
       })
     );
 
     proxyRes.on("error", (err) => {
-      req.log.error(
-        { error: err, api: req.api, key: req.key?.hash },
-        `Error while streaming response.`
-      );
+      req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
      const fakeErrorEvent = buildFakeSseMessage(
         "mid-stream-error",
         err.message,
@@ -173,12 +164,17 @@
  * Transforms SSE events from the given response API into events compatible with
  * the API requested by the client.
  */
-function transformEvent(
-  data: string,
-  requestApi: string,
-  responseApi: string,
-  lastPosition: number
-) {
+function transformEvent({
+  data,
+  requestApi,
+  responseApi,
+  lastPosition,
+}: {
+  data: string;
+  requestApi: string;
+  responseApi: string;
+  lastPosition: number;
+}) {
   if (requestApi === responseApi) {
     return { position: -1, event: data };
   }
@@ -236,7 +232,7 @@ function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
 }
 
 function convertEventsToFinalResponse(events: string[], req: Request) {
-  if (req.key!.service === "openai") {
+  if (req.outboundApi === "openai") {
     let response: OpenAiChatCompletionResponse = {
       id: "",
       object: "",
@@ -278,7 +274,7 @@
     }, response);
     return response;
   }
-  if (req.key!.service === "anthropic") {
+  if (req.outboundApi === "anthropic") {
     /*
      * Full complete responses from Anthropic are conveniently just the same as
      * the final SSE event before the "DONE" event, so we can reuse that
diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts
index 8cc1ceb..0980424 100644
--- a/src/proxy/middleware/response/index.ts
+++ b/src/proxy/middleware/response/index.ts
@@ -155,11 +155,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
   res
 ) => {
   if (req.isStreaming) {
-    req.log.error(
-      { api: req.api, key: req.key?.hash },
-      `decodeResponseBody called for a streaming request, which isn't valid.`
-    );
-    throw new Error("decodeResponseBody called for a streaming request.");
+    const err = new Error("decodeResponseBody called for a streaming request.");
+    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
+    throw err;
   }
 
   const promise = new Promise((resolve, reject) => {
@@ -273,14 +271,14 @@
     errorPayload.proxy_note = `API key is invalid or revoked. 
${tryAgainMessage}`;
   } else if (statusCode === 429) {
     // OpenAI uses this for a bunch of different rate-limiting scenarios.
-    if (req.key!.service === "openai") {
+    if (req.outboundApi === "openai") {
       handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
     } else {
       handleAnthropicRateLimitError(req, errorPayload);
     }
   } else if (statusCode === 404) {
     // Most likely model not found
-    if (req.key!.service === "openai") {
+    if (req.outboundApi === "openai") {
       // TODO: this probably doesn't handle GPT-4-32k variants properly if the
       // proxy has keys for both the 8k and 32k context models at the same time.
       if (errorPayload.error?.code === "model_not_found") {
@@ -290,7 +288,7 @@
         errorPayload.proxy_note = `No model was found for this key.`;
       }
     }
-  } else if (req.key!.service === "anthropic") {
+  } else if (req.outboundApi === "anthropic") {
     errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
   }
 } else {
@@ -313,7 +311,6 @@ function handleAnthropicRateLimitError(
   req: Request,
   errorPayload: Record<string, any>
 ) {
-  //{"error":{"type":"rate_limit_error","message":"Number of concurrent connections to Claude exceeds your rate limit. Please try again, or contact sales@anthropic.com to discuss your options for a rate limit increase."}}
   if (errorPayload.error?.type === "rate_limit_error") {
     keyPool.markRateLimited(req.key!);
     if (config.queueMode !== "none") {
diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts
index 751f418..1435b58 100644
--- a/src/proxy/middleware/response/log-prompt.ts
+++ b/src/proxy/middleware/response/log-prompt.ts
@@ -1,3 +1,4 @@
+import { Request } from "express";
 import { config } from "../../../config";
 import { AIService } from "../../../key-management";
 import { logQueue } from "../../../prompt-logging";
@@ -22,19 +23,19 @@
     return;
   }
 
-  const model = req.body.model;
-  const promptFlattened = flattenMessages(req.body.messages);
+  const promptPayload = getPromptForRequest(req);
+  const promptFlattened = flattenMessages(promptPayload);
   const response = getResponseForService({
-    service: req.key!.service,
+    service: req.outboundApi,
     body: responseBody,
   });
 
   logQueue.enqueue({
-    model,
-    endpoint: req.api,
-    promptRaw: JSON.stringify(req.body.messages),
+    endpoint: req.inboundApi,
+    promptRaw: JSON.stringify(promptPayload),
     promptFlattened,
-    response,
+    model: response.model, // may differ from the requested model
+    response: response.completion,
   });
 };
 
@@ -43,8 +44,22 @@
 type OaiMessage = {
   role: string;
   content: string;
 };
 
-const flattenMessages = (messages: OaiMessage[]): string => {
+const getPromptForRequest = (req: Request): string | OaiMessage[] => {
+  // Since the prompt logger only runs after the request has been proxied, we
+  // can assume the body has already been transformed to the target API's
+  // format. 
+  if (req.outboundApi === "anthropic") {
+    return req.body.prompt;
+  } else {
+    return req.body.messages;
+  }
+};
+
+const flattenMessages = (messages: string | OaiMessage[]): string => {
+  if (typeof messages === "string") {
+    return messages;
+  }
   return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
 };
 
@@ -53,10 +68,10 @@
 const getResponseForService = ({
   service,
   body,
 }: {
   service: AIService;
   body: Record<string, any>;
-}) => {
+}): { completion: string; model: string } => {
   if (service === "anthropic") {
-    return body.completion.trim();
+    return { completion: body.completion.trim(), model: body.model };
   } else {
-    return body.choices[0].message.content;
+    return { completion: body.choices[0].message.content, model: body.model };
   }
 };
diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts
index 080677e..a8e663d 100644
--- a/src/proxy/openai.ts
+++ b/src/proxy/openai.ts
@@ -18,13 +18,13 @@ import {
   handleInternalError,
   ProxyResHandlerWithBody,
 } from "./middleware/response";
+import { setApiFormat } from "./routes";
 
 const rewriteRequest = (
   proxyReq: http.ClientRequest,
   req: Request,
   res: http.ServerResponse
 ) => {
-  req.api = "openai";
   const rewriterPipeline = [
     addKey,
     languageFilter,
@@ -76,9 +76,7 @@
 const queuedOpenaiProxy = createQueueMiddleware(openaiProxy);
 
 const openaiRouter = Router();
-// Some clients don't include the /v1/ prefix in their requests and users get
-// confused when they get a 404. Just fix the route for them so I don't have to
-// provide a bunch of different routes for each client's idiosyncrasies.
+// Fix paths because clients don't consistently use the /v1 prefix.
 openaiRouter.use((req, _res, next) => {
   if (!req.path.startsWith("/v1/")) {
     req.url = `/v1${req.url}`;
@@ -86,9 +84,13 @@
   next();
 });
 openaiRouter.get("/v1/models", openaiProxy);
-openaiRouter.post("/v1/chat/completions", ipLimiter, queuedOpenaiProxy);
-// If a browser tries to visit a route that doesn't exist, redirect to the info
-// page to help them find the right URL.
+openaiRouter.post(
+  "/v1/chat/completions",
+  setApiFormat({ in: "openai", out: "openai" }),
+  ipLimiter,
+  queuedOpenaiProxy
+);
+// Redirect browser requests to the homepage.
 openaiRouter.get("*", (req, res, next) => {
   const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
   if (isBrowser) {
diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts
index dc8aa46..88c63e7 100644
--- a/src/proxy/queue.ts
+++ b/src/proxy/queue.ts
@@ -118,23 +118,24 @@ export function enqueue(req: Request) {
   }
 }
 
-export function dequeue(model: SupportedModel): Request | undefined {
+type QueuePartition = "claude" | "turbo" | "gpt-4";
+export function dequeue(partition: QueuePartition): Request | undefined {
+  // There is a single request queue, but it is partitioned by model and API
+  // provider.
+  // - claude: requests for the Anthropic API, regardless of model
+  // - gpt-4: requests for the OpenAI API, specifically for GPT-4 models
+  // - turbo: effectively, all other requests
   const modelQueue = queue.filter((req) => {
-    const reqProvider = req.originalUrl.startsWith("/proxy/anthropic")
-      ? "anthropic"
-      : "openai";
-
-    // This sucks, but the `req.body.model` on Anthropic requests via the
-    // OpenAI-compat endpoint isn't actually claude-*, it's a fake gpt value.
-    // TODO: refactor model/service detection
-
-    if (model.startsWith("claude")) {
-      return reqProvider === "anthropic";
+    const provider = req.outboundApi;
+    const model = (req.body.model as SupportedModel) ?? 
"gpt-3.5-turbo"; + switch (partition) { + case "claude": + return provider === "anthropic"; + case "gpt-4": + return provider === "openai" && model.startsWith("gpt-4"); + case "turbo": + return provider === "openai"; } - if (model.startsWith("gpt-4")) { - return reqProvider === "openai" && req.body.model?.startsWith("gpt-4"); - } - return reqProvider === "openai" && req.body.model?.startsWith("gpt-3"); }); if (modelQueue.length === 0) { @@ -191,10 +192,10 @@ function processQueue() { reqs.push(dequeue("gpt-4")); } if (turboLockout === 0) { - reqs.push(dequeue("gpt-3.5-turbo")); + reqs.push(dequeue("turbo")); } if (claudeLockout === 0) { - reqs.push(dequeue("claude-v1")); + reqs.push(dequeue("claude")); } reqs.filter(Boolean).forEach((req) => { @@ -332,8 +333,7 @@ export function buildFakeSseMessage( ) { let fakeEvent; - if (req.api === "anthropic") { - // data: {"completion": " Here is a paragraph of lorem ipsum text:\n\nLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor inc", "stop_reason": "max_tokens", "truncated": false, "stop": null, "model": "claude-instant-v1", "log_id": "???", "exception": null} + if (req.inboundApi === "anthropic") { fakeEvent = { completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`, stop_reason: type, diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 97e1e51..5d7bc7f 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -5,6 +5,7 @@ subset of the API is supported. Kobold requests must be transformed into equivalent OpenAI requests. */ import * as express from "express"; +import { AIService } from "../key-management"; import { gatekeeper } from "./auth/gatekeeper"; import { kobold } from "./kobold"; import { openai } from "./openai"; @@ -17,19 +18,15 @@ router.use("/kobold", kobold); router.use("/openai", openai); router.use("/anthropic", anthropic); -// Each client handles the endpoints input by the user in slightly different -// ways, eg TavernAI ignores everything after the hostname in Kobold mode -function rewriteTavernRequests( - req: express.Request, - _res: express.Response, - next: express.NextFunction -) { - // Requests coming into /api/v1 are actually requests to /proxy/kobold/api/v1 - if (req.path.startsWith("/api/v1")) { - req.url = req.url.replace("/api/v1", "/proxy/kobold/api/v1"); - } - next(); +export function setApiFormat(api: { + in: express.Request["inboundApi"]; + out: AIService; +}): express.RequestHandler { + return (req, _res, next) => { + req.inboundApi = api.in; + req.outboundApi = api.out; + next(); + }; } -export { rewriteTavernRequests }; export { router as proxyRouter }; diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index 4535ce1..e81bd1b 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -1,17 +1,15 @@ import { Express } from "express-serve-static-core"; -import { Key } from "../key-management/index"; +import { AIService, Key } from "../key-management/index"; import { User } from "../proxy/auth/user-store"; declare global { namespace Express { interface Request { key?: Key; - /** - * Denotes the _inbound_ API format. This is used to determine how the - * user has submitted their request; the proxy will then translate the - * paramaters to the target API format, which is on `key.service`. - */ - api: "kobold" | "openai" | "anthropic"; + /** Denotes the format of the user's submitted request. */ + inboundApi: AIService | "kobold"; + /** Denotes the format of the request being proxied to the API. 
*/ + outboundApi: AIService; user?: User; isStreaming?: boolean; startTime: number;