From 03c5c473e19dc4474f79dac2845eb5cbf9b2defc Mon Sep 17 00:00:00 2001 From: nai-degen Date: Mon, 4 Mar 2024 22:54:21 -0600 Subject: [PATCH] improves error handling for sillytavern --- src/proxy/anthropic.ts | 44 ++- src/proxy/middleware/common.ts | 33 +- .../request/preprocessor-factory.ts | 5 + .../request/preprocessors/language-filter.ts | 4 +- .../middleware/response/error-generator.ts | 352 ++++++++++++++++++ .../response/handle-streamed-response.ts | 5 +- src/proxy/middleware/response/index.ts | 30 +- .../response/streaming/sse-stream-adapter.ts | 6 +- src/proxy/queue.ts | 56 +-- src/proxy/routes.ts | 19 +- src/server.ts | 31 +- src/shared/errors.ts | 8 +- src/shared/streaming.ts | 142 ------- src/user/web/self-service.ts | 4 +- 14 files changed, 499 insertions(+), 240 deletions(-) create mode 100644 src/proxy/middleware/response/error-generator.ts diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index aac52e7..7f96df5 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -1,4 +1,4 @@ -import { Request, RequestHandler, Router } from "express"; +import { Request, Response, RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { logger } from "../logger"; @@ -17,6 +17,7 @@ import { createOnProxyResHandler, } from "./middleware/response"; import { HttpError } from "../shared/errors"; +import { sendErrorToClient } from "./middleware/response/error-generator"; const CLAUDE_3_COMPAT_MODEL = process.env.CLAUDE_3_COMPAT_MODEL || "claude-3-sonnet-20240229"; @@ -251,16 +252,19 @@ anthropicRouter.post( "/v1/claude-3/complete", ipLimiter, handleCompatibilityRequest, - createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-chat", service: "anthropic" }, - { - beforeTransform: [(req) => void (req.body.model = CLAUDE_3_COMPAT_MODEL)], - } - ), + createPreprocessorMiddleware({ + inApi: "anthropic-text", + outApi: "anthropic-chat", + service: "anthropic", + }), anthropicProxy ); -export function handleCompatibilityRequest(req: Request, res: any, next: any) { +export function handleCompatibilityRequest( + req: Request, + res: Response, + next: any +) { const alreadyInChatFormat = Boolean(req.body.messages); const alreadyUsingClaude3 = req.body.model?.includes("claude-3"); if (!alreadyInChatFormat && !alreadyUsingClaude3) { @@ -268,18 +272,24 @@ export function handleCompatibilityRequest(req: Request, res: any, next: any) { } if (alreadyInChatFormat) { - throw new HttpError( - 400, - "Your request is already using the new API format and does not need the compatibility endpoint. Use the /proxy/anthropic endpoint instead." - ); + sendErrorToClient({ + req, + res, + options: { + title: "Proxy error (incompatible request for endpoint)", + message: + "Your request is already using the new API format and does not need to use the compatibility endpoint.\n\nUse the /proxy/anthropic endpoint instead.", + format: "unknown", + statusCode: 400, + reqId: req.id, + }, + }); } - if (alreadyUsingClaude3) { - throw new HttpError( - 400, - "Your request already includes the new model identifier and does not need the compatibility endpoint. Use the /proxy/anthropic endpoint instead." - ); + if (!alreadyUsingClaude3) { + req.body.model = CLAUDE_3_COMPAT_MODEL; } + next(); } function maybeReassignModel(req: Request) { diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 168dbe9..e3c0be1 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -2,9 +2,9 @@ import { Request, Response } from "express"; import httpProxy from "http-proxy"; import { ZodError } from "zod"; import { generateErrorMessage } from "zod-error"; -import { makeCompletionSSE } from "../../shared/streaming"; import { assertNever } from "../../shared/utils"; import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits"; +import { buildSpoofedSSE, sendErrorToClient } from "./response/error-generator"; const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"; const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions"; @@ -40,7 +40,7 @@ export function isEmbeddingsRequest(req: Request) { ); } -export function writeErrorResponse( +export function sendProxyError( req: Request, res: Response, statusCode: number, @@ -52,29 +52,22 @@ export function writeErrorResponse( ? `The proxy encountered an error while trying to process your prompt.` : `The proxy encountered an error while trying to send your prompt to the upstream service.`; - // If we're mid-SSE stream, send a data event with the error payload and end - // the stream. Otherwise just send a normal error response. - if ( - res.headersSent || - String(res.getHeader("content-type")).startsWith("text/event-stream") - ) { - const event = makeCompletionSSE({ + if (req.tokenizerInfo && typeof errorPayload.error === "object") { + errorPayload.error.proxy_tokenizer = req.tokenizerInfo; + } + + sendErrorToClient({ + options: { format: req.inboundApi, title: `Proxy error (HTTP ${statusCode} ${statusMessage})`, message: `${msg} Further technical details are provided below.`, obj: errorPayload, reqId: req.id, model: req.body?.model, - }); - res.write(event); - res.write(`data: [DONE]\n\n`); - res.end(); - } else { - if (req.tokenizerInfo && typeof errorPayload.error === "object") { - errorPayload.error.proxy_tokenizer = req.tokenizerInfo; - } - res.status(statusCode).json(errorPayload); - } + }, + req, + res, + }); } export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => { @@ -90,7 +83,7 @@ export const classifyErrorAndSend = ( try { const { statusCode, statusMessage, userMessage, ...errorDetails } = classifyError(err); - writeErrorResponse(req, res, statusCode, statusMessage, { + sendProxyError(req, res, statusCode, statusMessage, { error: { message: userMessage, ...errorDetails }, }); } catch (error) { diff --git a/src/proxy/middleware/request/preprocessor-factory.ts b/src/proxy/middleware/request/preprocessor-factory.ts index 9182fbc..027456f 100644 --- a/src/proxy/middleware/request/preprocessor-factory.ts +++ b/src/proxy/middleware/request/preprocessor-factory.ts @@ -122,6 +122,7 @@ const handleTestMessage: RequestHandler = (req, res) => { object: "chat.completion", created: Date.now(), model: body.model, + // openai chat choices: [ { message: { role: "assistant", content: "Hello!" }, @@ -129,6 +130,10 @@ const handleTestMessage: RequestHandler = (req, res) => { index: 0, }, ], + // anthropic text + completion: "Hello!", + // anthropic chat + content: [{ type: "text", text: "Hello!" }], proxy_note: "This response was generated by the proxy's test message handler and did not go to the API.", }); diff --git a/src/proxy/middleware/request/preprocessors/language-filter.ts b/src/proxy/middleware/request/preprocessors/language-filter.ts index 27c9d01..9610cb4 100644 --- a/src/proxy/middleware/request/preprocessors/language-filter.ts +++ b/src/proxy/middleware/request/preprocessors/language-filter.ts @@ -2,7 +2,7 @@ import { Request } from "express"; import { config } from "../../../../config"; import { assertNever } from "../../../../shared/utils"; import { RequestPreprocessor } from "../index"; -import { UserInputError } from "../../../../shared/errors"; +import { BadRequestError } from "../../../../shared/errors"; import { MistralAIChatMessage, OpenAIChatMessage, @@ -46,7 +46,7 @@ export const languageFilter: RequestPreprocessor = async (req) => { req.res!.once("close", resolve); setTimeout(resolve, delay); }); - throw new UserInputError(config.rejectMessage); + throw new BadRequestError(config.rejectMessage); } }; diff --git a/src/proxy/middleware/response/error-generator.ts b/src/proxy/middleware/response/error-generator.ts new file mode 100644 index 0000000..1ba038b --- /dev/null +++ b/src/proxy/middleware/response/error-generator.ts @@ -0,0 +1,352 @@ +import express from "express"; +import { APIFormat } from "../../../shared/key-management"; +import { assertNever } from "../../../shared/utils"; +import { initializeSseStream } from "../../../shared/streaming"; + +function getMessageContent({ + title, + message, + obj, +}: { + title: string; + message: string; + obj?: Record; +}) { + /* + Constructs a Markdown-formatted message that renders semi-nicely in most chat + frontends. For example: + + **Proxy error (HTTP 404 Not Found)** + The proxy encountered an error while trying to send your prompt to the upstream service. Further technical details are provided below. + *** + *The requested Claude model might not exist, or the key might not be provisioned for it.* + ``` + { + "type": "error", + "error": { + "type": "not_found_error", + "message": "model: some-invalid-model-id", + "proxy_tokenizer": { + "tokenizer": "@anthropic-ai/tokenizer", + "token_count": 6104, + "tokenization_duration_ms": 4.0765, + "prompt_tokens": 6104, + "completion_tokens": 30, + "max_model_tokens": 200000, + "max_proxy_tokens": 9007199254740991 + } + }, + "proxy_note": "The requested Claude model might not exist, or the key might not be provisioned for it." + } + ``` + */ + + const friendlyMessage = obj?.proxy_note + ? `${message}\n\n***\n\n*${obj.proxy_note}*` + : message; + const details = JSON.parse(JSON.stringify(obj ?? {})); + let stack = ""; + if (details.stack) { + stack = `\n\nInclude this trace when reporting an issue.\n\`\`\`\n${details.stack}\n\`\`\``; + delete details.stack; + } + return `\n\n**${title}**\n${friendlyMessage}${ + obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n${stack}` : "" + }`; +} + +type ErrorGeneratorOptions = { + format: APIFormat | "unknown"; + title: string; + message: string; + obj?: object; + reqId: string | number | object; + model?: string; + statusCode?: number; +}; + +export function tryInferFormat(body: any): APIFormat | "unknown" { + if (typeof body !== "object" || !body.model) { + return "unknown"; + } + + if (body.model.includes("gpt")) { + return "openai"; + } + + if (body.model.includes("mistral")) { + return "mistral-ai"; + } + + if (body.model.includes("claude")) { + return body.messages?.length ? "anthropic-chat" : "anthropic-text"; + } + + if (body.model.includes("gemini")) { + return "google-ai"; + } + + return "unknown"; +} + +export function sendErrorToClient({ + options, + req, + res, +}: { + options: ErrorGeneratorOptions; + req: express.Request; + res: express.Response; +}) { + const { format: inputFormat } = options; + + // This is an error thrown before we know the format of the request, so we + // can't send a response in the format the client expects. + const format = + inputFormat === "unknown" ? tryInferFormat(req.body) : inputFormat; + if (format === "unknown") { + return res.status(options.statusCode || 400).json({ + error: options.message, + details: options.obj, + }); + } + + const completion = buildSpoofedCompletion({ ...options, format }); + const event = buildSpoofedSSE({ ...options, format }); + const isStreaming = + req.isStreaming || req.body.stream === true || req.body.stream === "true"; + + if (isStreaming) { + if (!res.headersSent) { + initializeSseStream(res); + } + res.write(event); + res.write(`data: [DONE]\n\n`); + res.end(); + } else { + res.status(200).json(completion); + } +} + +/** + * Returns a non-streaming completion object that looks like it came from the + * service that the request is being proxied to. Used to send error messages to + * the client and have them look like normal responses, for clients with poor + * error handling. + */ +export function buildSpoofedCompletion({ + format, + title, + message, + obj, + reqId, + model = "unknown", +}: ErrorGeneratorOptions & { format: Exclude }) { + const id = String(reqId); + const content = getMessageContent({ title, message, obj }); + + switch (format) { + case "openai": + case "mistral-ai": + return { + id: "error-" + id, + object: "chat.completion", + created: Date.now(), + model, + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + choices: [ + { + message: { role: "assistant", content }, + finish_reason: title, + index: 0, + }, + ], + }; + case "openai-text": + return { + id: "error-" + id, + object: "text_completion", + created: Date.now(), + model, + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + choices: [ + { text: content, index: 0, logprobs: null, finish_reason: title }, + ], + }; + case "anthropic-text": + return { + id: "error-" + id, + type: "completion", + completion: content, + stop_reason: title, + stop: null, + model, + }; + case "anthropic-chat": + return { + id: "error-" + id, + type: "message", + role: "assistant", + content: [{ type: "text", text: content }], + model, + stop_reason: title, + stop_sequence: null, + }; + case "google-ai": + // TODO: Native Google AI non-streaming responses are not supported, this + // is an untested guess at what the response should look like. + return { + id: "error-" + id, + object: "chat.completion", + created: Date.now(), + model, + candidates: [ + { + content: { parts: [{ text: content }], role: "model" }, + finishReason: title, + index: 0, + tokenCount: null, + safetyRatings: [], + }, + ], + }; + case "openai-image": + throw new Error( + `Spoofed completions not supported for ${format} requests` + ); + default: + assertNever(format); + } +} + +/** + * Returns an SSE message that looks like a completion event for the service + * that the request is being proxied to. Used to send error messages to the + * client in the middle of a streaming request. + */ +export function buildSpoofedSSE({ + format, + title, + message, + obj, + reqId, + model = "unknown", +}: ErrorGeneratorOptions & { format: Exclude }) { + const id = String(reqId); + const content = getMessageContent({ title, message, obj }); + + let event; + + switch (format) { + case "openai": + case "mistral-ai": + event = { + id: "chatcmpl-" + id, + object: "chat.completion.chunk", + created: Date.now(), + model, + choices: [{ delta: { content }, index: 0, finish_reason: title }], + }; + break; + case "openai-text": + event = { + id: "cmpl-" + id, + object: "text_completion", + created: Date.now(), + choices: [ + { text: content, index: 0, logprobs: null, finish_reason: title }, + ], + model, + }; + break; + case "anthropic-text": + event = { + completion: content, + stop_reason: title, + truncated: false, + stop: null, + model, + log_id: "proxy-req-" + id, + }; + break; + case "anthropic-chat": + event = { + type: "content_block_delta", + index: 0, + delta: { type: "text_delta", text: content }, + }; + break; + case "google-ai": + return JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: content }], role: "model" }, + finishReason: title, + index: 0, + tokenCount: null, + safetyRatings: [], + }, + ], + }); + case "openai-image": + throw new Error(`SSE not supported for ${format} requests`); + default: + assertNever(format); + } + + if (format === "anthropic-text") { + return ( + ["event: completion", `data: ${JSON.stringify(event)}`].join("\n") + + "\n\n" + ); + } + + // ugh. + if (format === "anthropic-chat") { + return ( + [ + [ + "event: message_start", + `data: ${JSON.stringify({ + type: "message_start", + message: { + id: "error-" + id, + type: "message", + role: "assistant", + content: [], + model, + }, + })}`, + ].join("\n"), + [ + "event: content_block_start", + `data: ${JSON.stringify({ + type: "content_block_start", + index: 0, + content_block: { type: "text", text: "" }, + })}`, + ].join("\n"), + ["event: content_block_delta", `data: ${JSON.stringify(event)}`].join( + "\n" + ), + [ + "event: content_block_stop", + `data: ${JSON.stringify({ type: "content_block_stop", index: 0 })}`, + ].join("\n"), + [ + "event: message_delta", + `data: ${JSON.stringify({ + type: "message_delta", + delta: { stop_reason: title, stop_sequence: null, usage: null }, + })}`, + ], + [ + "event: message_stop", + `data: ${JSON.stringify({ type: "message_stop" })}`, + ].join("\n"), + ].join("\n\n") + "\n\n" + ); + } + + return `data: ${JSON.stringify(event)}\n\n`; +} diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index 46b0f2a..6fa6c7c 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -6,7 +6,7 @@ import { APIFormat, keyPool } from "../../../shared/key-management"; import { copySseResponseHeaders, initializeSseStream, - makeCompletionSSE, + } from "../../../shared/streaming"; import type { logger } from "../../../logger"; import { enqueue } from "../../queue"; @@ -15,6 +15,7 @@ import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder"; import { EventAggregator } from "./streaming/event-aggregator"; import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; +import { buildSpoofedSSE } from "./error-generator"; const pipelineAsync = promisify(pipeline); @@ -111,7 +112,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( } else { const { message, stack, lastEvent } = err; const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined"; - const errorEvent = makeCompletionSSE({ + const errorEvent = buildSpoofedSSE({ format: req.inboundApi, title: "Proxy stream error", message: "An unexpected error occurred while streaming the response.", diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index cb8044b..e1b7c40 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -18,7 +18,7 @@ import { getCompletionFromBody, isImageGenerationRequest, isTextGenerationRequest, - writeErrorResponse, + sendProxyError, } from "../common"; import { handleStreamedResponse } from "./handle-streamed-response"; import { logPrompt } from "./log-prompt"; @@ -192,13 +192,13 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( // as it was never a problem. body = await decoder(body); } else { - const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`; - req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage); - writeErrorResponse(req, res, 500, "Internal Server Error", { - error: errorMessage, + const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`; + req.log.warn({ contentEncoding, key: req.key?.hash }, error); + sendProxyError(req, res, 500, "Internal Server Error", { + error, contentEncoding, }); - return reject(errorMessage); + return reject(error); } } @@ -208,13 +208,11 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( return resolve(json); } return resolve(body.toString()); - } catch (error: any) { - const errorMessage = `Proxy received response with invalid JSON: ${error.message}`; - req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage); - writeErrorResponse(req, res, 500, "Internal Server Error", { - error: errorMessage, - }); - return reject(errorMessage); + } catch (e) { + const msg = `Proxy received response with invalid JSON: ${e.message}`; + req.log.warn({ error: e.stack, key: req.key?.hash }, msg); + sendProxyError(req, res, 500, "Internal Server Error", { error: msg }); + return reject(msg); } }); }); @@ -267,7 +265,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`, }; - writeErrorResponse(req, res, statusCode, statusMessage, errorObject); + sendProxyError(req, res, statusCode, statusMessage, errorObject); throw new HttpError(statusCode, parseError.message); } @@ -412,7 +410,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( ); } - writeErrorResponse(req, res, statusCode, statusMessage, errorPayload); + sendProxyError(req, res, statusCode, statusMessage, errorPayload); + // This is bubbled up to onProxyRes's handler for logging but will not trigger + // a write to the response as `sendProxyError` has just done that. throw new HttpError(statusCode, errorPayload.error?.message); }; diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts index 83c3f1e..487be16 100644 --- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -2,8 +2,8 @@ import pino from "pino"; import { Transform, TransformOptions } from "stream"; import { Message } from "@smithy/eventstream-codec"; import { APIFormat } from "../../../../shared/key-management"; -import { makeCompletionSSE } from "../../../../shared/streaming"; import { RetryableError } from "../index"; +import { buildSpoofedSSE } from "../error-generator"; type SSEStreamAdapterOptions = TransformOptions & { contentType?: string; @@ -75,7 +75,7 @@ export class SSEStreamAdapter extends Transform { throw new RetryableError("AWS request throttled mid-stream"); default: this.log.error({ message, type }, "Received bad AWS stream event"); - return makeCompletionSSE({ + return buildSpoofedSSE({ format: "anthropic-text", title: "Proxy stream error", message: @@ -103,7 +103,7 @@ export class SSEStreamAdapter extends Transform { return `data: ${JSON.stringify(data)}`; } else { this.log.error({ event: data }, "Received bad Google AI event"); - return `data: ${makeCompletionSSE({ + return `data: ${buildSpoofedSSE({ format: "google-ai", title: "Proxy stream error", message: diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index b980ff4..c8bb0a1 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -13,17 +13,19 @@ import crypto from "crypto"; import type { Handler, Request } from "express"; +import { BadRequestError, TooManyRequestsError } from "../shared/errors"; import { keyPool } from "../shared/key-management"; import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily, } from "../shared/models"; -import { makeCompletionSSE, initializeSseStream } from "../shared/streaming"; +import { initializeSseStream } from "../shared/streaming"; import { logger } from "../logger"; import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit"; import { RequestPreprocessor } from "./middleware/request"; import { handleProxyError } from "./middleware/common"; +import { sendErrorToClient } from "./middleware/response/error-generator"; const queue: Request[] = []; const log = logger.child({ module: "request-queue" }); @@ -80,10 +82,14 @@ export async function enqueue(req: Request) { // Re-enqueued requests are not counted towards the limit since they // already made it through the queue once. if (req.retryCount === 0) { - throw new Error("Too many agnai.chat requests are already queued"); + throw new TooManyRequestsError( + "Too many agnai.chat requests are already queued" + ); } } else { - throw new Error("Your IP or token already has a request in the queue"); + throw new TooManyRequestsError( + "Your IP or user token already has another request in the queue." + ); } } @@ -101,8 +107,8 @@ export async function enqueue(req: Request) { } registerHeartbeat(req); } else if (getProxyLoad() > LOAD_THRESHOLD) { - throw new Error( - "Due to heavy traffic on this proxy, you must enable streaming for your request." + throw new BadRequestError( + "Due to heavy traffic on this proxy, you must enable streaming in your chat client to use this endpoint." ); } @@ -354,11 +360,20 @@ export function createQueueMiddleware({ try { await enqueue(req); } catch (err: any) { - req.res!.status(429).json({ - type: "proxy_error", - message: err.message, - stack: err.stack, - proxy_note: `Only one request can be queued at a time. If you don't have another request queued, your IP or user token might be in use by another request.`, + const title = + err.status === 429 + ? "Proxy queue error (too many concurrent requests)" + : "Proxy queue error (streaming required)"; + sendErrorToClient({ + options: { + title, + message: err.message, + format: req.inboundApi, + reqId: req.id, + model: req.body?.model, + }, + req, + res, }); } }; @@ -373,20 +388,17 @@ function killQueuedRequest(req: Request) { const res = req.res; try { const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`; - if (res.headersSent) { - const event = makeCompletionSSE({ - format: req.inboundApi, - title: "Proxy queue error", + sendErrorToClient({ + options: { + title: "Proxy queue error (request killed)", message, - reqId: String(req.id), + format: req.inboundApi, + reqId: req.id, model: req.body?.model, - }); - res.write(event); - res.write(`data: [DONE]\n\n`); - res.end(); - } else { - res.status(500).json({ error: message }); - } + }, + req, + res, + }); } catch (e) { req.log.error(e, `Error killing stalled request.`); } diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 0d18eee..910237a 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -8,6 +8,7 @@ import { googleAI } from "./google-ai"; import { mistralAI } from "./mistral-ai"; import { aws } from "./aws"; import { azure } from "./azure"; +import { sendErrorToClient } from "./middleware/response/error-generator"; const proxyRouter = express.Router(); proxyRouter.use((req, _res, next) => { @@ -46,8 +47,22 @@ proxyRouter.get("*", (req, res, next) => { } }); // Handle 404s. -proxyRouter.use((_req, res) => { - res.status(404).json({ error: "Not found" }); +proxyRouter.use((req, res) => { + sendErrorToClient({ + req, + res, + options: { + title: "Proxy error (HTTP 404 Not Found)", + message: "The requested proxy endpoint does not exist.", + model: req.body?.model, + reqId: req.id, + format: "unknown", + obj: { + proxy_note: "Your chat client is using the wrong endpoint. Please check your configuration.", + requested_url: req.url, + }, + }, + }); }); export { proxyRouter as proxyRouter }; diff --git a/src/server.ts b/src/server.ts index 2c3ac19..916f430 100644 --- a/src/server.ts +++ b/src/server.ts @@ -19,6 +19,7 @@ import { start as startRequestQueue } from "./proxy/queue"; import { init as initUserStore } from "./shared/users/user-store"; import { init as initTokenizers } from "./shared/tokenization"; import { checkOrigin } from "./proxy/check-origin"; +import { sendErrorToClient } from "./proxy/middleware/response/error-generator"; const PORT = config.port; const BIND_ADDRESS = config.bindAddress; @@ -74,21 +75,27 @@ if (config.staticServiceInfo) { app.use("/", infoPageRouter); } -app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => { - if (err.status) { - res.status(err.status).json({ error: err.message }); - } else { - logger.error(err); - res.status(500).json({ - error: { - type: "proxy_error", - message: err.message, - stack: err.stack, - proxy_note: `Reverse proxy encountered an internal server error.`, +app.use( + (err: any, req: express.Request, res: express.Response, _next: unknown) => { + if (!err.status) { + logger.error(err, "Unhandled error in request"); + } + + sendErrorToClient({ + req, + res, + options: { + title: `Proxy error (HTTP ${err.status})`, + message: + "Reverse proxy encountered an unexpected error while processing your request.", + reqId: req.id, + statusCode: err.status, + obj: { error: err.message, stack: err.stack }, + format: "unknown", }, }); } -}); +); app.use((_req: unknown, res: express.Response) => { res.status(404).json({ error: "Not found" }); }); diff --git a/src/shared/errors.ts b/src/shared/errors.ts index d9efbbb..ec6e122 100644 --- a/src/shared/errors.ts +++ b/src/shared/errors.ts @@ -4,7 +4,7 @@ export class HttpError extends Error { } } -export class UserInputError extends HttpError { +export class BadRequestError extends HttpError { constructor(message: string) { super(400, message); } @@ -21,3 +21,9 @@ export class NotFoundError extends HttpError { super(404, message); } } + +export class TooManyRequestsError extends HttpError { + constructor(message: string) { + super(429, message); + } +} diff --git a/src/shared/streaming.ts b/src/shared/streaming.ts index d04a009..bc4f4b8 100644 --- a/src/shared/streaming.ts +++ b/src/shared/streaming.ts @@ -1,7 +1,5 @@ import { Response } from "express"; import { IncomingMessage } from "http"; -import { assertNever } from "./utils"; -import { APIFormat } from "./key-management"; export function initializeSseStream(res: Response) { res.statusCode = 200; @@ -35,143 +33,3 @@ export function copySseResponseHeaders( } } -/** - * Returns an SSE message that looks like a completion event for the service - * that the request is being proxied to. Used to send error messages to the - * client in the middle of a streaming request. - */ -export function makeCompletionSSE({ - format, - title, - message, - obj, - reqId, - model = "unknown", -}: { - format: APIFormat; - title: string; - message: string; - obj?: object; - reqId: string | number | object; - model?: string; -}) { - const id = String(reqId); - const content = `\n\n**${title}**\n${message}${ - obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n` : "" - }`; - - let event; - - switch (format) { - case "openai": - case "mistral-ai": - event = { - id: "chatcmpl-" + id, - object: "chat.completion.chunk", - created: Date.now(), - model, - choices: [{ delta: { content }, index: 0, finish_reason: title }], - }; - break; - case "openai-text": - event = { - id: "cmpl-" + id, - object: "text_completion", - created: Date.now(), - choices: [ - { text: content, index: 0, logprobs: null, finish_reason: title }, - ], - model, - }; - break; - case "anthropic-text": - event = { - completion: content, - stop_reason: title, - truncated: false, - stop: null, - model, - log_id: "proxy-req-" + id, - }; - break; - case "anthropic-chat": - event = { - type: "content_block_delta", - index: 0, - delta: { type: "text_delta", text: content }, - }; - break; - case "google-ai": - return JSON.stringify({ - candidates: [ - { - content: { parts: [{ text: content }], role: "model" }, - finishReason: title, - index: 0, - tokenCount: null, - safetyRatings: [], - }, - ], - }); - case "openai-image": - throw new Error(`SSE not supported for ${format} requests`); - default: - assertNever(format); - } - - if (format === "anthropic-text") { - return ( - ["event: completion", `data: ${JSON.stringify(event)}`].join("\n") + - "\n\n" - ); - } - - // ugh. - if (format === "anthropic-chat") { - return ( - [ - [ - "event: message_start", - `data: ${JSON.stringify({ - type: "message_start", - message: { - id: "error-" + id, - type: "message", - role: "assistant", - content: [], - model, - }, - })}`, - ].join("\n"), - [ - "event: content_block_start", - `data: ${JSON.stringify({ - type: "content_block_start", - index: 0, - content_block: { type: "text", text: "" }, - })}`, - ].join("\n"), - ["event: content_block_delta", `data: ${JSON.stringify(event)}`].join( - "\n" - ), - [ - "event: content_block_stop", - `data: ${JSON.stringify({ type: "content_block_stop", index: 0 })}`, - ].join("\n"), - [ - "event: message_delta", - `data: ${JSON.stringify({ - type: "message_delta", - delta: { stop_reason: title, stop_sequence: null, usage: null }, - })}`, - ], - [ - "event: message_stop", - `data: ${JSON.stringify({ type: "message_stop" })}`, - ].join("\n"), - ].join("\n\n") + "\n\n" - ); - } - - return `data: ${JSON.stringify(event)}\n\n`; -} diff --git a/src/user/web/self-service.ts b/src/user/web/self-service.ts index 4046245..607df47 100644 --- a/src/user/web/self-service.ts +++ b/src/user/web/self-service.ts @@ -1,7 +1,7 @@ import { Router } from "express"; import { UserPartialSchema } from "../../shared/users/schema"; import * as userStore from "../../shared/users/user-store"; -import { ForbiddenError, UserInputError } from "../../shared/errors"; +import { ForbiddenError, BadRequestError } from "../../shared/errors"; import { sanitizeAndTrim } from "../../shared/utils"; import { config } from "../../config"; @@ -62,7 +62,7 @@ router.post("/edit-nickname", (req, res) => { const result = schema.safeParse(req.body); if (!result.success) { - throw new UserInputError(result.error.message); + throw new BadRequestError(result.error.message); } const newNickname = result.data.nickname || null;