Refactor request middleware (khanon/oai-reverse-proxy!18)

This commit is contained in:
khanon
2023-06-02 04:03:16 +00:00
parent a26979f7bc
commit dae1262f7a
19 changed files with 440 additions and 337 deletions
+73 -69
View File
@@ -1,22 +1,66 @@
import { Request, Router } from "express";
import { Request, RequestHandler, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
limitOutputTokens,
setApiFormat,
transformOutboundPayload,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
handleInternalError,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
// Cached fake /v1/models payload and the time it was last built.
let modelsCache: any = null;
let modelsCacheTime = 0;

/**
 * Builds an OpenAI-style model-list response advertising the Claude variants,
 * so OpenAI-compatible clients can "discover" models on this proxy.
 * Returns an empty list when no Anthropic key is configured; otherwise the
 * result is cached and reused for up to 60 seconds.
 */
const getModelsResponse = () => {
  // Serve the cached response if it is less than a minute old.
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  if (!config.anthropicKey) return { object: "list", data: [] };

  // Hard-coded list of Claude model IDs exposed to clients.
  const claudeVariants = [
    "claude-v1",
    "claude-v1-100k",
    "claude-instant-v1",
    "claude-instant-v1-100k",
    "claude-v1.3",
    "claude-v1.3-100k",
    "claude-v1.2",
    "claude-v1.0",
    "claude-instant-v1.1",
    "claude-instant-v1.1-100k",
    "claude-instant-v1.0",
  ];

  const models = claudeVariants.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "anthropic",
    permission: [],
    root: "claude",
    parent: null,
  }));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();
  return modelsCache;
};
/** Responds to model-list requests with the (possibly cached) fake model list. */
const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};
const rewriteAnthropicRequest = (
proxyReq: http.ClientRequest,
@@ -27,7 +71,6 @@ const rewriteAnthropicRequest = (
addKey,
languageFilter,
limitOutputTokens,
transformOutboundPayload,
finalizeBody,
];
@@ -96,22 +139,23 @@ function transformAnthropicResponse(
};
}
// Proxy middleware that forwards requests to the Anthropic API.
// selfHandleResponse: true — the proxyRes handler chain writes the response
// itself (needed to transform/inspect upstream bodies).
const anthropicProxy = createProxyMiddleware({
  target: "https://api.anthropic.com",
  changeOrigin: true,
  on: {
    proxyReq: rewriteAnthropicRequest,
    proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
    error: handleInternalError,
  },
  selfHandleResponse: true,
  logger,
  pathRewrite: {
    // Send OpenAI-compat requests to the real Anthropic endpoint.
    "^/v1/chat/completions": "/v1/complete",
  },
});
// Wrap the proxy in the request queue so concurrency limits are enforced.
const queuedAnthropicProxy = createQueueMiddleware(anthropicProxy);
// Queue-wrapped proxy middleware that forwards requests to the Anthropic API.
// selfHandleResponse: true — the proxyRes handler chain writes the response
// itself (needed to transform/inspect upstream bodies).
const anthropicProxy = createQueueMiddleware(
  createProxyMiddleware({
    target: "https://api.anthropic.com",
    changeOrigin: true,
    on: {
      proxyReq: rewriteAnthropicRequest,
      proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
      error: handleProxyError,
    },
    selfHandleResponse: true,
    logger,
    pathRewrite: {
      // Send OpenAI-compat requests to the real Anthropic endpoint.
      "^/v1/chat/completions": "/v1/complete",
    },
  })
);
const anthropicRouter = Router();
// Fix paths because clients don't consistently use the /v1 prefix.
@@ -121,19 +165,19 @@ anthropicRouter.use((req, _res, next) => {
}
next();
});
anthropicRouter.get("/v1/models", (req, res) => {
res.json(buildFakeModelsResponse());
});
anthropicRouter.get("/v1/models", handleModelRequest);
anthropicRouter.post(
"/v1/complete",
setApiFormat({ in: "anthropic", out: "anthropic" }),
queuedAnthropicProxy
ipLimiter,
createPreprocessorMiddleware({ inApi: "anthropic", outApi: "anthropic" }),
anthropicProxy
);
// OpenAI-to-Anthropic compatibility endpoint.
anthropicRouter.post(
"/v1/chat/completions",
setApiFormat({ in: "openai", out: "anthropic" }),
queuedAnthropicProxy
ipLimiter,
createPreprocessorMiddleware({ inApi: "openai", outApi: "anthropic" }),
anthropicProxy
);
// Redirect browser requests to the homepage.
anthropicRouter.get("*", (req, res, next) => {
@@ -145,44 +189,4 @@ anthropicRouter.get("*", (req, res, next) => {
}
});
// Cached fake /v1/models payload and the time it was last built.
let modelsCache: any = null;
let modelsCacheTime = 0;

/**
 * Builds an OpenAI-style model-list response advertising the Claude variants.
 * Returns an empty list when no Anthropic key is configured; otherwise the
 * result is cached and reused for up to 60 seconds.
 */
function buildFakeModelsResponse() {
  // Serve the cached response if it is less than a minute old.
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  if (!config.anthropicKey) return { object: "list", data: [] };

  // Hard-coded list of Claude model IDs exposed to clients.
  const claudeVariants = [
    "claude-v1",
    "claude-v1-100k",
    "claude-instant-v1",
    "claude-instant-v1-100k",
    "claude-v1.3",
    "claude-v1.3-100k",
    "claude-v1.2",
    "claude-v1.0",
    "claude-instant-v1.1",
    "claude-instant-v1.1-100k",
    "claude-instant-v1.0",
  ];

  const models = claudeVariants.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "anthropic",
    permission: [],
    root: "claude",
    parent: null,
  }));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();
  return modelsCache;
}
export const anthropic = anthropicRouter;
+4 -4
View File
@@ -7,17 +7,17 @@ import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
limitOutputTokens,
setApiFormat,
transformKoboldPayload,
} from "./middleware/request";
import {
createOnProxyResHandler,
handleInternalError,
ProxyResHandlerWithBody,
} from "./middleware/response";
@@ -91,7 +91,7 @@ const koboldOaiProxy = createProxyMiddleware({
on: {
proxyReq: rewriteRequest,
proxyRes: createOnProxyResHandler([koboldResponseHandler]),
error: handleInternalError,
error: handleProxyError,
},
selfHandleResponse: true,
logger,
@@ -102,8 +102,8 @@ koboldRouter.get("/api/v1/model", handleModelRequest);
koboldRouter.get("/api/v1/config/soft_prompts_list", handleSoftPromptsRequest);
koboldRouter.post(
"/api/v1/generate",
setApiFormat({ in: "kobold", out: "openai" }),
ipLimiter,
createPreprocessorMiddleware({ inApi: "kobold", outApi: "openai" }),
koboldOaiProxy
);
koboldRouter.use((req, res) => {
+120
View File
@@ -0,0 +1,120 @@
import { Request, Response } from "express";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";

/** Returns true if we're making a request to a completion endpoint. */
export function isCompletionRequest(req: Request) {
  if (req.method !== "POST") {
    return false;
  }
  const completionEndpoints = [
    OPENAI_CHAT_COMPLETION_ENDPOINT,
    ANTHROPIC_COMPLETION_ENDPOINT,
  ];
  return completionEndpoints.some((prefix) => req.path.startsWith(prefix));
}
/**
 * Writes an error payload to the client. If the response is already an SSE
 * stream (headers sent, or content-type is text/event-stream), the error is
 * delivered as a fake SSE data event followed by `[DONE]` so streaming
 * clients terminate cleanly; otherwise a normal JSON error response is sent.
 *
 * @param req - The inbound request (used to shape the fake SSE event).
 * @param res - The response to write to.
 * @param statusCode - HTTP status to use for non-streaming responses.
 * @param errorPayload - Error body; `error.type` prefixed with "proxy"
 * marks the error as originating from this proxy rather than upstream.
 */
export function writeErrorResponse(
  req: Request,
  res: Response,
  statusCode: number,
  errorPayload: Record<string, any>
) {
  // BUGFIX: `error.type` may be absent on arbitrary upstream payloads;
  // optional-chain the whole path so we don't throw a TypeError while
  // trying to report a different error.
  const errorSource = errorPayload.error?.type?.startsWith("proxy")
    ? "proxy"
    : "upstream";

  // If we're mid-SSE stream, send a data event with the error payload and end
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
    res.getHeader("content-type") === "text/event-stream"
  ) {
    const msg = buildFakeSseMessage(
      `${errorSource} error (${statusCode})`,
      JSON.stringify(errorPayload, null, 2),
      req
    );
    res.write(msg);
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
    res.status(statusCode).json(errorPayload);
  }
}
/**
 * http-proxy error callback: logs the error on the request logger, then
 * delegates to handleInternalError to write the client-facing response.
 * The http-proxy callback types are wider than express's, hence the casts.
 */
export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => {
  req.log.error({ err }, `Error during proxy request middleware`);
  handleInternalError(err, req as Request, res as Response);
};
/**
 * Writes an error response for a failure that occurred inside the proxy
 * before the upstream API could respond. Zod validation errors become a 400
 * with the individual issues attached; anything else becomes a 500.
 * Never throws: if writing the response itself fails (e.g. headers already
 * sent in an incompatible state), the failure is logged and swallowed.
 */
export const handleInternalError = (
  err: Error,
  req: Request,
  res: Response
) => {
  try {
    const isZod = err instanceof ZodError;
    if (isZod) {
      writeErrorResponse(req, res, 400, {
        error: {
          type: "proxy_validation_error",
          proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
          issues: err.issues,
          stack: err.stack,
          message: err.message,
        },
      });
    } else {
      writeErrorResponse(req, res, 500, {
        error: {
          type: "proxy_rewriter_error",
          proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
          message: err.message,
          stack: err.stack,
        },
      });
    }
  } catch (e) {
    // Last resort: we could not even write the error response.
    req.log.error(
      { error: e },
      `Error writing error response headers, giving up.`
    );
  }
};
/**
 * Builds a single SSE `data:` frame that mimics an upstream completion chunk,
 * carrying a proxy-generated notice (e.g. an error) in the completion text.
 * The event shape matches the client's inbound API (Anthropic vs OpenAI).
 */
export function buildFakeSseMessage(
  type: string,
  string: string,
  req: Request
) {
  // The notice is wrapped in a code fence so it renders distinctly in chat UIs.
  const noticeText = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;

  const fakeEvent =
    req.inboundApi === "anthropic"
      ? {
          completion: noticeText,
          stop_reason: type,
          truncated: false, // I've never seen this be true
          stop: null,
          model: req.body?.model,
          log_id: "proxy-req-" + req.id,
        }
      : {
          id: "chatcmpl-" + req.id,
          object: "chat.completion.chunk",
          created: Date.now(),
          model: req.body?.model,
          choices: [
            {
              delta: { content: noticeText },
              index: 0,
              finish_reason: type,
            },
          ],
        };

  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}
+4 -3
View File
@@ -1,8 +1,9 @@
import { Key, keyPool } from "../../../key-management";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/** Add a key that can service this request to the request object. */
export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
let assignedKey: Key;
if (!isCompletionRequest(req)) {
@@ -16,7 +17,7 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
if (!req.inboundApi || !req.outboundApi) {
const err = new Error(
"Request API format missing. Did you forget to add the `setApiFormat` middleware to your route?"
"Request API format missing. Did you forget to add the request preprocessor to your router?"
);
req.log.error(
{ in: req.inboundApi, out: req.outboundApi, path: req.path },
@@ -1,8 +1,8 @@
import { fixRequestBody } from "http-proxy-middleware";
import type { ExpressHttpProxyReqCallback } from ".";
import type { ProxyRequestMiddleware } from ".";
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: ExpressHttpProxyReqCallback = (proxyReq, req) => {
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
const updatedBody = JSON.stringify(req.body);
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
+33 -18
View File
@@ -2,29 +2,44 @@ import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
// Express middleware (runs before http-proxy-middleware, can be async)
export { createPreprocessorMiddleware } from "./preprocess";
export { setApiFormat } from "./set-api-format";
export { transformOutboundPayload } from "./transform-outbound-payload";
// HPM middleware (runs on onProxyReq, cannot be async)
export { addKey } from "./add-key";
export { finalizeBody } from "./finalize-body";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { limitOutputTokens } from "./limit-output-tokens";
export { setApiFormat } from "./set-api-format";
export { transformKoboldPayload } from "./transform-kobold-payload";
export { transformOutboundPayload } from "./transform-outbound-payload";
// Endpoint paths used to recognize completion requests.
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";

/**
 * Middleware that runs prior to the request being handled by http-proxy-
 * middleware.
 *
 * Async functions can be used here, but you will not have access to the proxied
 * request/response objects, nor the data set by ProxyRequestMiddleware
 * functions as they have not yet been run.
 *
 * User will have been authenticated by the time this middleware runs, but your
 * request won't have been assigned an API key yet.
 *
 * Note that these functions only run once ever per request, even if the request
 * is automatically retried by the request queue middleware.
 */
export type RequestPreprocessor = (req: Request) => void | Promise<void>;

/** Returns true if we're making a request to a completion endpoint. */
export function isCompletionRequest(req: Request) {
  return (
    req.method === "POST" &&
    [OPENAI_CHAT_COMPLETION_ENDPOINT, ANTHROPIC_COMPLETION_ENDPOINT].some(
      (endpoint) => req.path.startsWith(endpoint)
    )
  );
}
// NOTE(review): this alias appears to be superseded by ProxyRequestMiddleware
// below (same underlying type) — confirm remaining usages before removing.
export type ExpressHttpProxyReqCallback = ProxyReqCallback<
  ClientRequest,
  Request
>;

/**
 * Middleware that runs immediately before the request is sent to the API in
 * response to http-proxy-middleware's `proxyReq` event.
 *
 * Async functions cannot be used here as HPM's event emitter is not async and
 * will not wait for the promise to resolve before sending the request.
 *
 * Note that these functions may be run multiple times per request if the
 * first attempt is rate limited and the request is automatically retried by the
 * request queue middleware.
 */
export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
@@ -1,6 +1,8 @@
import { Request } from "express";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
const DISALLOWED_REGEX =
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
@@ -19,18 +21,31 @@ const containsDisallowedCharacters = (text: string) => {
};
/** Block requests containing too many disallowed characters. */
export const languageFilter: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!config.rejectDisallowed) {
return;
}
if (req.method === "POST" && req.body?.messages) {
const combinedText = req.body.messages
.map((m: { role: string; content: string }) => m.content)
.join(",");
if (isCompletionRequest(req)) {
const combinedText = getPromptFromRequest(req);
if (containsDisallowedCharacters(combinedText)) {
logger.warn(`Blocked request containing bad characters`);
_proxyReq.destroy(new Error(config.rejectMessage));
}
}
};
/**
 * Extracts the prompt text from the request body, using the field layout of
 * the outbound API: `prompt` for Anthropic, joined `messages` content for
 * OpenAI. Throws on an unrecognized outbound API.
 */
function getPromptFromRequest(req: Request) {
  const service = req.outboundApi;
  const body = req.body;
  if (service === "anthropic") {
    return body.prompt;
  }
  if (service === "openai") {
    const contents = body.messages.map(
      (m: { content: string }) => m.content
    );
    return contents.join("\n");
  }
  throw new Error(`Unknown service: ${service}`);
}
@@ -1,14 +1,12 @@
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
* Don't allow multiple completions to be requested to prevent abuse.
* OpenAI-only, Anthropic provides no such parameter.
**/
export const limitCompletions: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (isCompletionRequest(req)) {
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
if (isCompletionRequest(req) && req.outboundApi === "openai") {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {
@@ -1,15 +1,13 @@
import { Request } from "express";
import { config } from "../../../config";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
const MAX_TOKENS = config.maxOutputTokens;
/** Enforce a maximum number of tokens requested from the model. */
export const limitOutputTokens: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (isCompletionRequest(req) && req.body?.max_tokens) {
export const limitOutputTokens: ProxyRequestMiddleware = (_proxyReq, req) => {
if (isCompletionRequest(req)) {
const requestedMaxTokens = Number.parseInt(getMaxTokensFromRequest(req));
let maxTokens = requestedMaxTokens;
@@ -39,5 +37,12 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
};
/**
 * Reads the client-requested maximum output tokens from the body, using the
 * parameter name appropriate to the outbound API (`max_tokens_to_sample` for
 * Anthropic, `max_tokens` for OpenAI). Throws on an unrecognized API.
 *
 * Fix: as rendered, an unconditional `return` preceded the `switch`, making
 * the per-service dispatch unreachable dead code; the dispatch is kept and
 * the dead fallback return removed.
 */
function getMaxTokensFromRequest(req: Request) {
  switch (req.outboundApi) {
    case "anthropic":
      return req.body?.max_tokens_to_sample;
    case "openai":
      return req.body?.max_tokens;
    default:
      throw new Error(`Unknown service: ${req.outboundApi}`);
  }
}
@@ -0,0 +1,30 @@
import { RequestHandler } from "express";
import { handleInternalError } from "../common";
import { RequestPreprocessor, setApiFormat, transformOutboundPayload } from ".";
/**
 * Returns a middleware function that processes the request body into the given
 * API format, and then sequentially runs the given additional preprocessors.
 *
 * @param apiFormat - Inbound/outbound API pair passed to setApiFormat.
 * @param additionalPreprocessors - Optional extra RequestPreprocessors run
 * after the format is set and the payload transformed.
 */
export const createPreprocessorMiddleware = (
  apiFormat: Parameters<typeof setApiFormat>[0],
  additionalPreprocessors?: RequestPreprocessor[]
): RequestHandler => {
  const preprocessors: RequestPreprocessor[] = [
    setApiFormat(apiFormat),
    transformOutboundPayload,
    ...(additionalPreprocessors ?? []),
  ];

  return async function executePreprocessors(req, res, next) {
    try {
      // Run strictly in order; each preprocessor may be async and may
      // mutate req before the next one sees it.
      for (const preprocessor of preprocessors) {
        await preprocessor(req);
      }
      next();
    } catch (error) {
      // Write the error response here (instead of next(error)) so the client
      // gets a structured payload via handleInternalError.
      req.log.error(error, "Error while executing request preprocessor");
      handleInternalError(error as Error, req, res);
    }
  };
};
@@ -1,13 +1,13 @@
import { Request, RequestHandler } from "express";
import { Request } from "express";
import { AIService } from "../../../key-management";
import { RequestPreprocessor } from ".";
export const setApiFormat = (api: {
in: Request["inboundApi"];
out: AIService;
}): RequestHandler => {
return (req, _res, next) => {
req.inboundApi = api.in;
req.outboundApi = api.out;
next();
inApi: Request["inboundApi"];
outApi: AIService;
}): RequestPreprocessor => {
return (req) => {
req.inboundApi = api.inApi;
req.outboundApi = api.outApi;
};
};
@@ -5,7 +5,7 @@
* many edge cases to be worth maintaining and doesn't work with newer features.
*/
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
import type { ProxyRequestMiddleware } from ".";
// Kobold requests look like this:
// body:
@@ -64,8 +64,11 @@ import type { ExpressHttpProxyReqCallback } from ".";
// lines into user and assistant messages, but that's not always correct. For
// now it will have to do.
/** Transforms a KoboldAI payload into an OpenAI payload. */
export const transformKoboldPayload: ExpressHttpProxyReqCallback = (
/**
* Transforms a KoboldAI payload into an OpenAI payload.
* @deprecated Probably doesn't work anymore, idk.
**/
export const transformKoboldPayload: ProxyRequestMiddleware = (
_proxyReq,
req
) => {
@@ -1,6 +1,8 @@
import { Request } from "express";
import { z } from "zod";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
import { isCompletionRequest } from "../common";
import { RequestPreprocessor } from ".";
// import { countTokens } from "../../../tokenization";
// https://console.anthropic.com/docs/api/reference#-v1-complete
const AnthropicV1CompleteSchema = z.object({
@@ -34,7 +36,13 @@ const OpenAIV1ChatCompletionSchema = z.object({
),
temperature: z.number().optional().default(1),
top_p: z.number().optional().default(1),
n: z.literal(1).optional(),
n: z
.literal(1, {
errorMap: () => ({
message: "You may only request a single completion at a time.",
}),
})
.optional(),
stream: z.boolean().optional().default(false),
stop: z.union([z.string(), z.array(z.string())]).optional(),
max_tokens: z.coerce.number().optional(),
@@ -45,10 +53,7 @@ const OpenAIV1ChatCompletionSchema = z.object({
});
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
const notTransformable = !isCompletionRequest(req);
@@ -66,7 +71,7 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
const result = validator.safeParse(req.body);
if (!result.success) {
req.log.error(
{ issues: result.error.issues, params: req.body },
{ issues: result.error.issues, body: req.body },
"Request validation failed"
);
throw result.error;
@@ -87,11 +92,8 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
function openaiToAnthropic(body: any, req: Request) {
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
// don't log the prompt (usually `messages` but maybe `prompt` if the user
// misconfigured their client)
const { messages, prompt, ...params } = body;
req.log.error(
{ issues: result.error.issues, params },
{ issues: result.error.issues, body: req.body },
"Invalid OpenAI-to-Anthropic request"
);
throw result.error;
@@ -129,6 +131,14 @@ function openaiToAnthropic(body: any, req: Request) {
// For big prompts (v1, auto-selects the latest model) is all we can use.
const model = prompt.length > 25000 ? "claude-v1-100k" : "claude-v1.2";
// wip
// const tokens = countTokens({
// prompt,
// req,
// service: "anthropic",
// });
// req.log.info({ tokens }, "Token count");
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
@@ -1,7 +1,7 @@
import { Request, Response } from "express";
import * as http from "http";
import { buildFakeSseMessage } from "../common";
import { RawResponseBodyHandler, decodeResponseBody } from ".";
import { buildFakeSseMessage } from "../../queue";
type OpenAiChatCompletionResponse = {
id: string;
+2 -67
View File
@@ -1,16 +1,14 @@
/* This file is fucking horrendous, sorry */
import { Request, Response } from "express";
import * as http from "http";
import * as httpProxy from "http-proxy";
import util from "util";
import zlib from "zlib";
import { ZodError } from "zod";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { enqueue, trackWaitTime } from "../../queue";
import { incrementPromptCount } from "../../auth/user-store";
import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue";
import { isCompletionRequest } from "../request";
import { isCompletionRequest, writeErrorResponse } from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
@@ -354,69 +352,6 @@ function handleOpenAIRateLimitError(
return errorPayload;
}
/**
 * Writes an error payload to the client. If the response is already an SSE
 * stream (headers sent, or content-type is text/event-stream), the error is
 * delivered as a fake SSE data event followed by `[DONE]` so streaming
 * clients terminate cleanly; otherwise a normal JSON error response is sent.
 */
function writeErrorResponse(
  req: Request,
  res: Response,
  statusCode: number,
  errorPayload: Record<string, any>
) {
  // BUGFIX: `error.type` may be absent on arbitrary upstream payloads;
  // optional-chain the whole path so we don't throw a TypeError while
  // trying to report a different error.
  const errorSource = errorPayload.error?.type?.startsWith("proxy")
    ? "proxy"
    : "upstream";

  // If we're mid-SSE stream, send a data event with the error payload and end
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
    res.getHeader("content-type") === "text/event-stream"
  ) {
    const msg = buildFakeSseMessage(
      `${errorSource} error (${statusCode})`,
      JSON.stringify(errorPayload, null, 2),
      req
    );
    res.write(msg);
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
    res.status(statusCode).json(errorPayload);
  }
}
/**
 * Handles errors in rewriter pipelines. Zod validation errors become a 400
 * with the individual issues attached; anything else becomes a 500. Never
 * throws: if writing the response itself fails, the failure is logged and
 * swallowed. The http-proxy callback types are wider than express's, hence
 * the casts on req/res.
 */
export const handleInternalError: httpProxy.ErrorCallback = (err, req, res) => {
  logger.error({ error: err }, "Error in http-proxy-middleware pipeline.");
  try {
    const isZod = err instanceof ZodError;
    if (isZod) {
      writeErrorResponse(req as Request, res as Response, 400, {
        error: {
          type: "proxy_validation_error",
          proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
          issues: err.issues,
          stack: err.stack,
          message: err.message,
        },
      });
    } else {
      writeErrorResponse(req as Request, res as Response, 500, {
        error: {
          type: "proxy_rewriter_error",
          proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
          message: err.message,
          stack: err.stack,
        },
      });
    }
  } catch (e) {
    // Last resort: we could not even write the error response.
    logger.error(
      { error: e },
      `Error writing error response headers, giving up.`
    );
  }
};
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isCompletionRequest(req)) {
keyPool.incrementPrompt(req.key!);
+1 -1
View File
@@ -2,7 +2,7 @@ import { Request } from "express";
import { config } from "../../../config";
import { AIService } from "../../../key-management";
import { logQueue } from "../../../prompt-logging";
import { isCompletionRequest } from "../request";
import { isCompletionRequest } from "../common";
import { ProxyResHandlerWithBody } from ".";
/** If prompt logging is enabled, enqueues the prompt for logging. */
+94 -97
View File
@@ -1,4 +1,4 @@
import { Request, Router } from "express";
import { RequestHandler, Request, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
@@ -6,115 +6,24 @@ import { keyPool } from "../key-management";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
languageFilter,
createPreprocessorMiddleware,
finalizeBody,
limitOutputTokens,
languageFilter,
limitCompletions,
setApiFormat,
transformOutboundPayload,
limitOutputTokens,
} from "./middleware/request";
import {
createOnProxyResHandler,
handleInternalError,
ProxyResHandlerWithBody,
} from "./middleware/response";
/**
 * proxyReq hook: runs the outbound-request rewriter pipeline in order
 * (key assignment, filters, limits, payload transform, body finalization).
 * Rewriters are synchronous; on any failure the proxied request is destroyed
 * so nothing half-rewritten reaches the upstream API.
 */
const rewriteRequest = (
  proxyReq: http.ClientRequest,
  req: Request,
  res: http.ServerResponse
) => {
  // Order matters: finalizeBody must run last.
  const rewriterPipeline = [
    addKey,
    languageFilter,
    limitOutputTokens,
    limitCompletions,
    transformOutboundPayload,
    finalizeBody,
  ];

  try {
    for (const rewriter of rewriterPipeline) {
      rewriter(proxyReq, req, res, {});
    }
  } catch (error) {
    req.log.error(error, "Error while executing proxy rewriter");
    proxyReq.destroy(error as Error);
  }
};
/**
 * Final response handler: relays the (already decoded) upstream body to the
 * client, appending a prompt-logging notice when logging is enabled.
 */
const openaiResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Disclose prompt logging to the client via an extra response field.
  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  res.status(200).json(body);
};
// Proxy middleware that forwards requests to the OpenAI API.
// selfHandleResponse: true — the proxyRes handler chain writes the response
// itself (needed to transform/inspect upstream bodies).
const openaiProxy = createProxyMiddleware({
  target: "https://api.openai.com",
  changeOrigin: true,
  on: {
    proxyReq: rewriteRequest,
    proxyRes: createOnProxyResHandler([openaiResponseHandler]),
    error: handleInternalError,
  },
  selfHandleResponse: true,
  logger,
});
// Wrap the proxy in the request queue so concurrency limits are enforced.
const queuedOpenaiProxy = createQueueMiddleware(openaiProxy);
const openaiRouter = Router();
// Fix paths because clients don't consistently use the /v1 prefix.
openaiRouter.use((req, _res, next) => {
if (!req.path.startsWith("/v1/")) {
req.url = `/v1${req.url}`;
}
next();
});
openaiRouter.get(
"/v1/models",
setApiFormat({ in: "openai", out: "openai" }),
(_req, res) => {
res.json(buildFakeModelsResponse());
}
);
openaiRouter.post(
"/v1/chat/completions",
setApiFormat({ in: "openai", out: "openai" }),
ipLimiter,
queuedOpenaiProxy
);
// Redirect browser requests to the homepage.
openaiRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
if (isBrowser) {
res.redirect("/");
} else {
next();
}
});
openaiRouter.use((req, res) => {
req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
res.status(404).json({ error: "Not found" });
});
let modelsCache: any = null;
let modelsCacheTime = 0;
function buildFakeModelsResponse() {
function getModelsResponse() {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
@@ -164,4 +73,92 @@ function buildFakeModelsResponse() {
return modelsCache;
}
/** Responds to model-list requests with the (possibly cached) model list. */
const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};
/**
 * proxyReq hook: runs the outbound-request rewriter pipeline in order
 * (key assignment, filters, limits, body finalization). Rewriters are
 * synchronous; on any failure the proxied request is destroyed so nothing
 * half-rewritten reaches the upstream API.
 */
const rewriteRequest = (
  proxyReq: http.ClientRequest,
  req: Request,
  res: http.ServerResponse
) => {
  // Order matters: finalizeBody must run last.
  const rewriterPipeline = [
    addKey,
    languageFilter,
    limitOutputTokens,
    limitCompletions,
    finalizeBody,
  ];

  try {
    for (const rewriter of rewriterPipeline) {
      rewriter(proxyReq, req, res, {});
    }
  } catch (error) {
    req.log.error(error, "Error while executing proxy rewriter");
    proxyReq.destroy(error as Error);
  }
};
/**
 * Final response handler: relays the (already decoded) upstream body to the
 * client, appending a prompt-logging notice when logging is enabled.
 */
const openaiResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Disclose prompt logging to the client via an extra response field.
  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  res.status(200).json(body);
};
// Queue-wrapped proxy middleware that forwards requests to the OpenAI API.
// selfHandleResponse: true — the proxyRes handler chain writes the response
// itself (needed to transform/inspect upstream bodies).
const openaiProxy = createQueueMiddleware(
  createProxyMiddleware({
    target: "https://api.openai.com",
    changeOrigin: true,
    on: {
      proxyReq: rewriteRequest,
      proxyRes: createOnProxyResHandler([openaiResponseHandler]),
      error: handleProxyError,
    },
    selfHandleResponse: true,
    logger,
  })
);
const openaiRouter = Router();

// Fix paths because clients don't consistently use the /v1 prefix.
openaiRouter.use((req, _res, next) => {
  if (!req.path.startsWith("/v1/")) {
    req.url = `/v1${req.url}`;
  }
  next();
});

// Model list is served locally; it never reaches the upstream API.
openaiRouter.get("/v1/models", handleModelRequest);

// Chat completions: rate-limit, preprocess the body, then queue + proxy.
openaiRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({ inApi: "openai", outApi: "openai" }),
  openaiProxy
);

// Redirect browser requests to the homepage.
openaiRouter.get("*", (req, res, next) => {
  const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
  if (isBrowser) {
    res.redirect("/");
  } else {
    next();
  }
});

// Anything else is not a supported endpoint.
openaiRouter.use((req, res) => {
  req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
  res.status(404).json({ error: "Not found" });
});

export const openai = openaiRouter;
+1 -34
View File
@@ -20,6 +20,7 @@ import { config, DequeueMode } from "../config";
import { keyPool, SupportedModel } from "../key-management";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
import { buildFakeSseMessage } from "./middleware/common";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@@ -326,40 +327,6 @@ function initStreaming(req: Request) {
res.write(": joining queue\n\n");
}
/**
 * Builds a single SSE `data:` frame that mimics an upstream completion chunk,
 * carrying a proxy-generated notice (e.g. an error) in the completion text.
 * The event shape matches the client's inbound API (Anthropic vs OpenAI).
 */
export function buildFakeSseMessage(
  type: string,
  string: string,
  req: Request
) {
  let fakeEvent;
  if (req.inboundApi === "anthropic") {
    fakeEvent = {
      completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`,
      stop_reason: type,
      truncated: false, // I've never seen this be true
      stop: null,
      model: req.body?.model,
      log_id: "proxy-req-" + req.id,
    };
  } else {
    fakeEvent = {
      id: "chatcmpl-" + req.id,
      object: "chat.completion.chunk",
      created: Date.now(),
      model: req.body?.model,
      choices: [
        {
          delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` },
          index: 0,
          finish_reason: type,
        },
      ],
    };
  }
  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}
/**
* http-proxy-middleware attaches a bunch of event listeners to the req and
* res objects which causes problems with our approach to re-enqueuing failed
+3
View File
@@ -38,6 +38,9 @@ app.use(
'req.headers["x-real-ip"]',
'req.headers["true-client-ip"]',
'req.headers["cf-connecting-ip"]',
// Don't log the prompt text on transform errors
"body.messages",
"body.prompt",
],
censor: "********",
},