Refactor request middleware (khanon/oai-reverse-proxy!18)
This commit is contained in:
+73
-69
@@ -1,22 +1,66 @@
|
||||
import { Request, Router } from "express";
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import * as http from "http";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
languageFilter,
|
||||
limitOutputTokens,
|
||||
setApiFormat,
|
||||
transformOutboundPayload,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
ProxyResHandlerWithBody,
|
||||
createOnProxyResHandler,
|
||||
handleInternalError,
|
||||
} from "./middleware/response";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
/**
 * Produces the OpenAI-style model listing served by this router.
 *
 * The listing is memoised in `modelsCache`; a cached value is returned as long
 * as it was built less than 60 seconds ago. When no Anthropic key is
 * configured an empty list is returned and nothing is cached.
 */
const getModelsResponse = () => {
  const cacheAgeMs = new Date().getTime() - modelsCacheTime;
  if (cacheAgeMs < 1000 * 60) {
    return modelsCache;
  }

  if (!config.anthropicKey) {
    return { object: "list", data: [] };
  }

  // Model ids advertised to clients, in the order they are listed.
  const data = [
    "claude-v1",
    "claude-v1-100k",
    "claude-instant-v1",
    "claude-instant-v1-100k",
    "claude-v1.3",
    "claude-v1.3-100k",
    "claude-v1.2",
    "claude-v1.0",
    "claude-instant-v1.1",
    "claude-instant-v1.1-100k",
    "claude-instant-v1.0",
  ].map((id) => {
    return {
      id,
      object: "model",
      created: new Date().getTime(),
      owned_by: "anthropic",
      permission: [],
      root: "claude",
      parent: null,
    };
  });

  modelsCache = { object: "list", data };
  modelsCacheTime = new Date().getTime();

  return modelsCache;
};
|
||||
|
||||
/** Express handler that answers with HTTP 200 and the current model listing. */
const handleModelRequest: RequestHandler = (_req, res) => {
  const listing = getModelsResponse();
  res.status(200).json(listing);
};
|
||||
|
||||
const rewriteAnthropicRequest = (
|
||||
proxyReq: http.ClientRequest,
|
||||
@@ -27,7 +71,6 @@ const rewriteAnthropicRequest = (
|
||||
addKey,
|
||||
languageFilter,
|
||||
limitOutputTokens,
|
||||
transformOutboundPayload,
|
||||
finalizeBody,
|
||||
];
|
||||
|
||||
@@ -96,22 +139,23 @@ function transformAnthropicResponse(
|
||||
};
|
||||
}
|
||||
|
||||
const anthropicProxy = createProxyMiddleware({
|
||||
target: "https://api.anthropic.com",
|
||||
changeOrigin: true,
|
||||
on: {
|
||||
proxyReq: rewriteAnthropicRequest,
|
||||
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
|
||||
error: handleInternalError,
|
||||
},
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
pathRewrite: {
|
||||
// Send OpenAI-compat requests to the real Anthropic endpoint.
|
||||
"^/v1/chat/completions": "/v1/complete",
|
||||
},
|
||||
});
|
||||
const queuedAnthropicProxy = createQueueMiddleware(anthropicProxy);
|
||||
const anthropicProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
target: "https://api.anthropic.com",
|
||||
changeOrigin: true,
|
||||
on: {
|
||||
proxyReq: rewriteAnthropicRequest,
|
||||
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
pathRewrite: {
|
||||
// Send OpenAI-compat requests to the real Anthropic endpoint.
|
||||
"^/v1/chat/completions": "/v1/complete",
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const anthropicRouter = Router();
|
||||
// Fix paths because clients don't consistently use the /v1 prefix.
|
||||
@@ -121,19 +165,19 @@ anthropicRouter.use((req, _res, next) => {
|
||||
}
|
||||
next();
|
||||
});
|
||||
anthropicRouter.get("/v1/models", (req, res) => {
|
||||
res.json(buildFakeModelsResponse());
|
||||
});
|
||||
anthropicRouter.get("/v1/models", handleModelRequest);
|
||||
anthropicRouter.post(
|
||||
"/v1/complete",
|
||||
setApiFormat({ in: "anthropic", out: "anthropic" }),
|
||||
queuedAnthropicProxy
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({ inApi: "anthropic", outApi: "anthropic" }),
|
||||
anthropicProxy
|
||||
);
|
||||
// OpenAI-to-Anthropic compatibility endpoint.
|
||||
anthropicRouter.post(
|
||||
"/v1/chat/completions",
|
||||
setApiFormat({ in: "openai", out: "anthropic" }),
|
||||
queuedAnthropicProxy
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({ inApi: "openai", outApi: "anthropic" }),
|
||||
anthropicProxy
|
||||
);
|
||||
// Redirect browser requests to the homepage.
|
||||
anthropicRouter.get("*", (req, res, next) => {
|
||||
@@ -145,44 +189,4 @@ anthropicRouter.get("*", (req, res, next) => {
|
||||
}
|
||||
});
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
function buildFakeModelsResponse() {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
if (!config.anthropicKey) return { object: "list", data: [] };
|
||||
|
||||
const claudeVariants = [
|
||||
"claude-v1",
|
||||
"claude-v1-100k",
|
||||
"claude-instant-v1",
|
||||
"claude-instant-v1-100k",
|
||||
"claude-v1.3",
|
||||
"claude-v1.3-100k",
|
||||
"claude-v1.2",
|
||||
"claude-v1.0",
|
||||
"claude-instant-v1.1",
|
||||
"claude-instant-v1.1-100k",
|
||||
"claude-instant-v1.0",
|
||||
];
|
||||
|
||||
const models = claudeVariants.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "anthropic",
|
||||
permission: [],
|
||||
root: "claude",
|
||||
parent: null,
|
||||
}));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
export const anthropic = anthropicRouter;
|
||||
|
||||
+4
-4
@@ -7,17 +7,17 @@ import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
languageFilter,
|
||||
limitOutputTokens,
|
||||
setApiFormat,
|
||||
transformKoboldPayload,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
handleInternalError,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
|
||||
@@ -91,7 +91,7 @@ const koboldOaiProxy = createProxyMiddleware({
|
||||
on: {
|
||||
proxyReq: rewriteRequest,
|
||||
proxyRes: createOnProxyResHandler([koboldResponseHandler]),
|
||||
error: handleInternalError,
|
||||
error: handleProxyError,
|
||||
},
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
@@ -102,8 +102,8 @@ koboldRouter.get("/api/v1/model", handleModelRequest);
|
||||
koboldRouter.get("/api/v1/config/soft_prompts_list", handleSoftPromptsRequest);
|
||||
koboldRouter.post(
|
||||
"/api/v1/generate",
|
||||
setApiFormat({ in: "kobold", out: "openai" }),
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({ inApi: "kobold", outApi: "openai" }),
|
||||
koboldOaiProxy
|
||||
);
|
||||
koboldRouter.use((req, res) => {
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import { Request, Response } from "express";
|
||||
import httpProxy from "http-proxy";
|
||||
import { ZodError } from "zod";
|
||||
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
||||
|
||||
/**
 * Tells whether `req` is a completion request: a POST whose path begins with
 * either the OpenAI chat-completion endpoint or the Anthropic completion
 * endpoint.
 */
export function isCompletionRequest(req: Request) {
  if (req.method !== "POST") {
    return false;
  }

  const completionEndpoints = [
    OPENAI_CHAT_COMPLETION_ENDPOINT,
    ANTHROPIC_COMPLETION_ENDPOINT,
  ];
  for (const endpoint of completionEndpoints) {
    if (req.path.startsWith(endpoint)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
/**
 * Sends an error payload to the client.
 *
 * If response headers were already sent, or the response is declared as
 * `text/event-stream`, the error is delivered as a fake SSE data event
 * followed by `data: [DONE]` and the response is ended. Otherwise a regular
 * JSON response with `statusCode` is sent.
 *
 * @param req Incoming request; forwarded to `buildFakeSseMessage`.
 * @param res Response to write to.
 * @param statusCode HTTP status used for the JSON response and quoted in the
 *   SSE message label.
 * @param errorPayload Body to send. `errorPayload.error.type` starting with
 *   "proxy" labels the error as coming from the proxy; anything else
 *   (including a missing or non-string type) is labelled "upstream".
 */
export function writeErrorResponse(
  req: Request,
  res: Response,
  statusCode: number,
  errorPayload: Record<string, any>
) {
  // Payloads may carry an `error` object without a string `type`; guard the
  // lookup so building the label can never throw.
  const errorType: unknown = errorPayload.error?.type;
  const errorSource =
    typeof errorType === "string" && errorType.startsWith("proxy")
      ? "proxy"
      : "upstream";

  // If we're mid-SSE stream, send a data event with the error payload and end
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
    res.getHeader("content-type") === "text/event-stream"
  ) {
    const msg = buildFakeSseMessage(
      `${errorSource} error (${statusCode})`,
      JSON.stringify(errorPayload, null, 2),
      req
    );
    res.write(msg);
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
    res.status(statusCode).json(errorPayload);
  }
}
|
||||
|
||||
/**
 * http-proxy `error` callback: logs the failure on the request logger, then
 * delegates to `handleInternalError` to write the error response.
 */
export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => {
  const logFields = { err };
  req.log.error(logFields, `Error during proxy request middleware`);

  const expressReq = req as Request;
  const expressRes = res as Response;
  handleInternalError(err, expressReq, expressRes);
};
|
||||
|
||||
/**
 * Converts an error raised inside the proxy into a client-facing response.
 *
 * A `ZodError` is reported as a 400 `proxy_validation_error` including the
 * validation issues; any other error is reported as a 500
 * `proxy_rewriter_error`. If writing the response itself throws, the failure
 * is logged and swallowed.
 */
export const handleInternalError = (
  err: Error,
  req: Request,
  res: Response
) => {
  try {
    if (!(err instanceof ZodError)) {
      writeErrorResponse(req, res, 500, {
        error: {
          type: "proxy_rewriter_error",
          proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
          message: err.message,
          stack: err.stack,
        },
      });
      return;
    }

    writeErrorResponse(req, res, 400, {
      error: {
        type: "proxy_validation_error",
        proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
        issues: err.issues,
        stack: err.stack,
        message: err.message,
      },
    });
  } catch (e) {
    req.log.error(
      { error: e },
      `Error writing error response headers, giving up.`
    );
  }
};
|
||||
|
||||
/**
 * Builds a single server-sent-event `data:` line that carries a
 * proxy-generated message in the shape of the client's inbound API.
 *
 * For `req.inboundApi === "anthropic"` the event has the Anthropic completion
 * fields; for anything else it is an OpenAI `chat.completion.chunk`. In both
 * shapes the text is `[type: string]` wrapped in a fenced code block, and
 * `type` is also used as the stop/finish reason.
 *
 * @param type Short label for the message; also used as the stop reason.
 * @param string Message text to embed.
 * @param req Request whose `inboundApi`, `body.model` and `id` are read.
 * @returns The serialized event, terminated by a blank line.
 */
export function buildFakeSseMessage(
  type: string,
  string: string,
  req: Request
) {
  const isAnthropic = req.inboundApi === "anthropic";

  const fakeEvent = isAnthropic
    ? {
        completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`,
        stop_reason: type,
        truncated: false, // I've never seen this be true
        stop: null,
        model: req.body?.model,
        log_id: "proxy-req-" + req.id,
      }
    : {
        id: "chatcmpl-" + req.id,
        object: "chat.completion.chunk",
        created: Date.now(),
        model: req.body?.model,
        choices: [
          {
            delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` },
            index: 0,
            finish_reason: type,
          },
        ],
      };

  const serialized = JSON.stringify(fakeEvent);
  return `data: ${serialized}\n\n`;
}
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Key, keyPool } from "../../../key-management";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
/** Add a key that can service this request to the request object. */
|
||||
export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
|
||||
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
let assignedKey: Key;
|
||||
|
||||
if (!isCompletionRequest(req)) {
|
||||
@@ -16,7 +17,7 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
|
||||
|
||||
if (!req.inboundApi || !req.outboundApi) {
|
||||
const err = new Error(
|
||||
"Request API format missing. Did you forget to add the `setApiFormat` middleware to your route?"
|
||||
"Request API format missing. Did you forget to add the request preprocessor to your router?"
|
||||
);
|
||||
req.log.error(
|
||||
{ in: req.inboundApi, out: req.outboundApi, path: req.path },
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { fixRequestBody } from "http-proxy-middleware";
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
import type { ProxyRequestMiddleware } from ".";
|
||||
|
||||
/** Finalize the rewritten request body. Must be the last rewriter. */
|
||||
export const finalizeBody: ExpressHttpProxyReqCallback = (proxyReq, req) => {
|
||||
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
|
||||
const updatedBody = JSON.stringify(req.body);
|
||||
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
|
||||
|
||||
@@ -2,29 +2,44 @@ import type { Request } from "express";
|
||||
import type { ClientRequest } from "http";
|
||||
import type { ProxyReqCallback } from "http-proxy";
|
||||
|
||||
// Express middleware (runs before http-proxy-middleware, can be async)
|
||||
export { createPreprocessorMiddleware } from "./preprocess";
|
||||
export { setApiFormat } from "./set-api-format";
|
||||
export { transformOutboundPayload } from "./transform-outbound-payload";
|
||||
|
||||
// HPM middleware (runs on onProxyReq, cannot be async)
|
||||
export { addKey } from "./add-key";
|
||||
export { finalizeBody } from "./finalize-body";
|
||||
export { languageFilter } from "./language-filter";
|
||||
export { limitCompletions } from "./limit-completions";
|
||||
export { limitOutputTokens } from "./limit-output-tokens";
|
||||
export { setApiFormat } from "./set-api-format";
|
||||
export { transformKoboldPayload } from "./transform-kobold-payload";
|
||||
export { transformOutboundPayload } from "./transform-outbound-payload";
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
||||
/**
|
||||
* Middleware that runs prior to the request being handled by http-proxy-
|
||||
* middleware.
|
||||
*
|
||||
* Async functions can be used here, but you will not have access to the proxied
|
||||
* request/response objects, nor the data set by ProxyRequestMiddleware
|
||||
* functions as they have not yet been run.
|
||||
*
|
||||
* User will have been authenticated by the time this middleware runs, but your
|
||||
* request won't have been assigned an API key yet.
|
||||
*
|
||||
* Note that these functions only run once ever per request, even if the request
|
||||
* is automatically retried by the request queue middleware.
|
||||
*/
|
||||
export type RequestPreprocessor = (req: Request) => void | Promise<void>;
|
||||
|
||||
/** Returns true if we're making a request to a completion endpoint. */
|
||||
export function isCompletionRequest(req: Request) {
|
||||
return (
|
||||
req.method === "POST" &&
|
||||
[OPENAI_CHAT_COMPLETION_ENDPOINT, ANTHROPIC_COMPLETION_ENDPOINT].some(
|
||||
(endpoint) => req.path.startsWith(endpoint)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export type ExpressHttpProxyReqCallback = ProxyReqCallback<
|
||||
ClientRequest,
|
||||
Request
|
||||
>;
|
||||
/**
|
||||
* Middleware that runs immediately before the request is sent to the API in
|
||||
* response to http-proxy-middleware's `proxyReq` event.
|
||||
*
|
||||
* Async functions cannot be used here as HPM's event emitter is not async and
|
||||
* will not wait for the promise to resolve before sending the request.
|
||||
*
|
||||
* Note that these functions may be run multiple times per request if the
|
||||
* first attempt is rate limited and the request is automatically retried by the
|
||||
* request queue middleware.
|
||||
*/
|
||||
export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
const DISALLOWED_REGEX =
|
||||
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
|
||||
@@ -19,18 +21,31 @@ const containsDisallowedCharacters = (text: string) => {
|
||||
};
|
||||
|
||||
/** Block requests containing too many disallowed characters. */
|
||||
export const languageFilter: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
|
||||
export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
|
||||
if (!config.rejectDisallowed) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.method === "POST" && req.body?.messages) {
|
||||
const combinedText = req.body.messages
|
||||
.map((m: { role: string; content: string }) => m.content)
|
||||
.join(",");
|
||||
if (isCompletionRequest(req)) {
|
||||
const combinedText = getPromptFromRequest(req);
|
||||
if (containsDisallowedCharacters(combinedText)) {
|
||||
logger.warn(`Blocked request containing bad characters`);
|
||||
_proxyReq.destroy(new Error(config.rejectMessage));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Extracts the prompt text from a request body according to the outbound API.
 *
 * - "anthropic": returns `body.prompt` as-is.
 * - "openai": returns the `content` of every entry in `body.messages`,
 *   joined with newlines.
 *
 * @throws Error when the outbound API is neither of the above.
 */
function getPromptFromRequest(req: Request) {
  const service = req.outboundApi;
  const body = req.body;

  if (service === "anthropic") {
    return body.prompt;
  }

  if (service === "openai") {
    const contents = body.messages.map((m: { content: string }) => m.content);
    return contents.join("\n");
  }

  throw new Error(`Unknown service: ${service}`);
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
/**
|
||||
* Don't allow multiple completions to be requested to prevent abuse.
|
||||
* OpenAI-only, Anthropic provides no such parameter.
|
||||
**/
|
||||
export const limitCompletions: ExpressHttpProxyReqCallback = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
if (isCompletionRequest(req)) {
|
||||
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
|
||||
if (isCompletionRequest(req) && req.outboundApi === "openai") {
|
||||
const originalN = req.body?.n || 1;
|
||||
req.body.n = 1;
|
||||
if (originalN !== req.body.n) {
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
const MAX_TOKENS = config.maxOutputTokens;
|
||||
|
||||
/** Enforce a maximum number of tokens requested from the model. */
|
||||
export const limitOutputTokens: ExpressHttpProxyReqCallback = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
if (isCompletionRequest(req) && req.body?.max_tokens) {
|
||||
export const limitOutputTokens: ProxyRequestMiddleware = (_proxyReq, req) => {
|
||||
if (isCompletionRequest(req)) {
|
||||
const requestedMaxTokens = Number.parseInt(getMaxTokensFromRequest(req));
|
||||
let maxTokens = requestedMaxTokens;
|
||||
|
||||
@@ -39,5 +37,12 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
|
||||
};
|
||||
|
||||
/**
 * Reads the requested output-token limit from the request body, using the
 * field that belongs to the outbound API (`max_tokens_to_sample` for
 * Anthropic, `max_tokens` for OpenAI).
 *
 * When the field is absent (null/undefined) the configured `MAX_TOKENS` is
 * returned, so the caller's `Number.parseInt` never receives `undefined`.
 *
 * @throws Error when the outbound API is neither "anthropic" nor "openai".
 */
function getMaxTokensFromRequest(req: Request) {
  switch (req.outboundApi) {
    case "anthropic":
      return req.body?.max_tokens_to_sample ?? MAX_TOKENS;
    case "openai":
      return req.body?.max_tokens ?? MAX_TOKENS;
    default:
      throw new Error(`Unknown service: ${req.outboundApi}`);
  }
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import { RequestHandler } from "express";
|
||||
import { handleInternalError } from "../common";
|
||||
import { RequestPreprocessor, setApiFormat, transformOutboundPayload } from ".";
|
||||
|
||||
/**
 * Returns an Express middleware that first records the request's API format
 * (`setApiFormat`), then transforms the body for the outbound API
 * (`transformOutboundPayload`), and then runs any additional preprocessors,
 * strictly in order and awaiting each one.
 *
 * On success `next()` is called. If any preprocessor throws or rejects, the
 * error is logged and answered through `handleInternalError`; `next()` is not
 * called in that case.
 *
 * @param apiFormat Inbound/outbound API pair passed to `setApiFormat`.
 * @param additionalPreprocessors Optional extra preprocessors run after the
 *   built-in ones.
 */
export const createPreprocessorMiddleware = (
  apiFormat: Parameters<typeof setApiFormat>[0],
  additionalPreprocessors?: RequestPreprocessor[]
): RequestHandler => {
  const preprocessors: RequestPreprocessor[] = [
    setApiFormat(apiFormat),
    transformOutboundPayload,
    ...(additionalPreprocessors ?? []),
  ];

  return async function executePreprocessors(req, res, next) {
    try {
      for (const preprocessor of preprocessors) {
        await preprocessor(req);
      }
      next();
    } catch (error) {
      req.log.error(error, "Error while executing request preprocessor");
      // A preprocessor may throw something that is not an Error (a string, a
      // plain object). Normalise it so the error response always has a
      // message instead of `undefined` fields.
      const normalized =
        error instanceof Error ? error : new Error(String(error));
      handleInternalError(normalized, req, res);
    }
  };
};
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Request, RequestHandler } from "express";
|
||||
import { Request } from "express";
|
||||
import { AIService } from "../../../key-management";
|
||||
import { RequestPreprocessor } from ".";
|
||||
|
||||
export const setApiFormat = (api: {
|
||||
in: Request["inboundApi"];
|
||||
out: AIService;
|
||||
}): RequestHandler => {
|
||||
return (req, _res, next) => {
|
||||
req.inboundApi = api.in;
|
||||
req.outboundApi = api.out;
|
||||
next();
|
||||
inApi: Request["inboundApi"];
|
||||
outApi: AIService;
|
||||
}): RequestPreprocessor => {
|
||||
return (req) => {
|
||||
req.inboundApi = api.inApi;
|
||||
req.outboundApi = api.outApi;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
* many edge cases to be worth maintaining and doesn't work with newer features.
|
||||
*/
|
||||
import { logger } from "../../../logger";
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
import type { ProxyRequestMiddleware } from ".";
|
||||
|
||||
// Kobold requests look like this:
|
||||
// body:
|
||||
@@ -64,8 +64,11 @@ import type { ExpressHttpProxyReqCallback } from ".";
|
||||
// lines into user and assistant messages, but that's not always correct. For
|
||||
// now it will have to do.
|
||||
|
||||
/** Transforms a KoboldAI payload into an OpenAI payload. */
|
||||
export const transformKoboldPayload: ExpressHttpProxyReqCallback = (
|
||||
/**
|
||||
* Transforms a KoboldAI payload into an OpenAI payload.
|
||||
* @deprecated Probably doesn't work anymore, idk.
|
||||
**/
|
||||
export const transformKoboldPayload: ProxyRequestMiddleware = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { RequestPreprocessor } from ".";
|
||||
// import { countTokens } from "../../../tokenization";
|
||||
|
||||
// https://console.anthropic.com/docs/api/reference#-v1-complete
|
||||
const AnthropicV1CompleteSchema = z.object({
|
||||
@@ -34,7 +36,13 @@ const OpenAIV1ChatCompletionSchema = z.object({
|
||||
),
|
||||
temperature: z.number().optional().default(1),
|
||||
top_p: z.number().optional().default(1),
|
||||
n: z.literal(1).optional(),
|
||||
n: z
|
||||
.literal(1, {
|
||||
errorMap: () => ({
|
||||
message: "You may only request a single completion at a time.",
|
||||
}),
|
||||
})
|
||||
.optional(),
|
||||
stream: z.boolean().optional().default(false),
|
||||
stop: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
max_tokens: z.coerce.number().optional(),
|
||||
@@ -45,10 +53,7 @@ const OpenAIV1ChatCompletionSchema = z.object({
|
||||
});
|
||||
|
||||
/** Transforms an incoming request body to one that matches the target API. */
|
||||
export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
|
||||
_proxyReq,
|
||||
req
|
||||
) => {
|
||||
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
||||
const sameService = req.inboundApi === req.outboundApi;
|
||||
const alreadyTransformed = req.retryCount > 0;
|
||||
const notTransformable = !isCompletionRequest(req);
|
||||
@@ -66,7 +71,7 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
|
||||
const result = validator.safeParse(req.body);
|
||||
if (!result.success) {
|
||||
req.log.error(
|
||||
{ issues: result.error.issues, params: req.body },
|
||||
{ issues: result.error.issues, body: req.body },
|
||||
"Request validation failed"
|
||||
);
|
||||
throw result.error;
|
||||
@@ -87,11 +92,8 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
|
||||
function openaiToAnthropic(body: any, req: Request) {
|
||||
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
|
||||
if (!result.success) {
|
||||
// don't log the prompt (usually `messages` but maybe `prompt` if the user
|
||||
// misconfigured their client)
|
||||
const { messages, prompt, ...params } = body;
|
||||
req.log.error(
|
||||
{ issues: result.error.issues, params },
|
||||
{ issues: result.error.issues, body: req.body },
|
||||
"Invalid OpenAI-to-Anthropic request"
|
||||
);
|
||||
throw result.error;
|
||||
@@ -129,6 +131,14 @@ function openaiToAnthropic(body: any, req: Request) {
|
||||
// For big prompts (v1, auto-selects the latest model) is all we can use.
|
||||
const model = prompt.length > 25000 ? "claude-v1-100k" : "claude-v1.2";
|
||||
|
||||
// wip
|
||||
// const tokens = countTokens({
|
||||
// prompt,
|
||||
// req,
|
||||
// service: "anthropic",
|
||||
// });
|
||||
// req.log.info({ tokens }, "Token count");
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
? rest.stop
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import { buildFakeSseMessage } from "../common";
|
||||
import { RawResponseBodyHandler, decodeResponseBody } from ".";
|
||||
import { buildFakeSseMessage } from "../../queue";
|
||||
|
||||
type OpenAiChatCompletionResponse = {
|
||||
id: string;
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
/* This file is fucking horrendous, sorry */
|
||||
import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import * as httpProxy from "http-proxy";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { ZodError } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { keyPool } from "../../../key-management";
|
||||
import { enqueue, trackWaitTime } from "../../queue";
|
||||
import { incrementPromptCount } from "../../auth/user-store";
|
||||
import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue";
|
||||
import { isCompletionRequest } from "../request";
|
||||
import { isCompletionRequest, writeErrorResponse } from "../common";
|
||||
import { handleStreamedResponse } from "./handle-streamed-response";
|
||||
import { logPrompt } from "./log-prompt";
|
||||
|
||||
@@ -354,69 +352,6 @@ function handleOpenAIRateLimitError(
|
||||
return errorPayload;
|
||||
}
|
||||
|
||||
/**
 * Sends an error payload to the client: as a fake SSE data event followed by
 * `data: [DONE]` when headers were already sent or the response is declared
 * as `text/event-stream`, otherwise as a JSON response with `statusCode`.
 *
 * The error is labelled "proxy" when `errorPayload.error.type` is a string
 * starting with "proxy", and "upstream" in every other case.
 */
function writeErrorResponse(
  req: Request,
  res: Response,
  statusCode: number,
  errorPayload: Record<string, any>
) {
  // `error` may be present without a string `type`; guard the lookup so that
  // building the label cannot throw.
  const errorType: unknown = errorPayload.error?.type;
  const errorSource =
    typeof errorType === "string" && errorType.startsWith("proxy")
      ? "proxy"
      : "upstream";

  // If we're mid-SSE stream, send a data event with the error payload and end
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
    res.getHeader("content-type") === "text/event-stream"
  ) {
    const msg = buildFakeSseMessage(
      `${errorSource} error (${statusCode})`,
      JSON.stringify(errorPayload, null, 2),
      req
    );
    res.write(msg);
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
    res.status(statusCode).json(errorPayload);
  }
}
|
||||
|
||||
/** Handles errors in rewriter pipelines. */
|
||||
export const handleInternalError: httpProxy.ErrorCallback = (err, req, res) => {
|
||||
logger.error({ error: err }, "Error in http-proxy-middleware pipeline.");
|
||||
|
||||
try {
|
||||
const isZod = err instanceof ZodError;
|
||||
if (isZod) {
|
||||
writeErrorResponse(req as Request, res as Response, 400, {
|
||||
error: {
|
||||
type: "proxy_validation_error",
|
||||
proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
|
||||
issues: err.issues,
|
||||
stack: err.stack,
|
||||
message: err.message,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
writeErrorResponse(req as Request, res as Response, 500, {
|
||||
error: {
|
||||
type: "proxy_rewriter_error",
|
||||
proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
|
||||
message: err.message,
|
||||
stack: err.stack,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
{ error: e },
|
||||
`Error writing error response headers, giving up.`
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
|
||||
if (isCompletionRequest(req)) {
|
||||
keyPool.incrementPrompt(req.key!);
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { AIService } from "../../../key-management";
|
||||
import { logQueue } from "../../../prompt-logging";
|
||||
import { isCompletionRequest } from "../request";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
|
||||
+94
-97
@@ -1,4 +1,4 @@
|
||||
import { Request, Router } from "express";
|
||||
import { RequestHandler, Request, Router } from "express";
|
||||
import * as http from "http";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
@@ -6,115 +6,24 @@ import { keyPool } from "../key-management";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
languageFilter,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
limitOutputTokens,
|
||||
languageFilter,
|
||||
limitCompletions,
|
||||
setApiFormat,
|
||||
transformOutboundPayload,
|
||||
limitOutputTokens,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
handleInternalError,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
|
||||
/**
 * `proxyReq` handler for the OpenAI proxy. Runs each request rewriter in
 * order against the outgoing request; if any rewriter throws, the error is
 * logged and the outgoing request is destroyed with that error.
 */
const rewriteRequest = (
  proxyReq: http.ClientRequest,
  req: Request,
  res: http.ServerResponse
) => {
  // Order matters: rewriters run first to last.
  const rewriterPipeline = [
    addKey,
    languageFilter,
    limitOutputTokens,
    limitCompletions,
    transformOutboundPayload,
    finalizeBody,
  ];

  try {
    rewriterPipeline.forEach((rewriter) => {
      rewriter(proxyReq, req, res, {});
    });
  } catch (error) {
    req.log.error(error, "Error while executing proxy rewriter");
    proxyReq.destroy(error as Error);
  }
};
|
||||
|
||||
/**
 * Final response handler for proxied OpenAI requests: forwards the decoded
 * upstream body to the client as JSON with status 200. When prompt logging is
 * enabled, a `proxy_note` field pointing at the request's host is added to
 * the body first.
 *
 * @throws Error when `body` is not a non-null object.
 */
const openaiResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null === "object"`, so null must be rejected explicitly;
  // otherwise the `proxy_note` assignment below would throw a TypeError.
  if (typeof body !== "object" || body === null) {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  res.status(200).json(body);
};
|
||||
|
||||
const openaiProxy = createProxyMiddleware({
|
||||
target: "https://api.openai.com",
|
||||
changeOrigin: true,
|
||||
on: {
|
||||
proxyReq: rewriteRequest,
|
||||
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
|
||||
error: handleInternalError,
|
||||
},
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
});
|
||||
const queuedOpenaiProxy = createQueueMiddleware(openaiProxy);
|
||||
|
||||
const openaiRouter = Router();
|
||||
// Fix paths because clients don't consistently use the /v1 prefix.
|
||||
openaiRouter.use((req, _res, next) => {
|
||||
if (!req.path.startsWith("/v1/")) {
|
||||
req.url = `/v1${req.url}`;
|
||||
}
|
||||
next();
|
||||
});
|
||||
openaiRouter.get(
|
||||
"/v1/models",
|
||||
setApiFormat({ in: "openai", out: "openai" }),
|
||||
(_req, res) => {
|
||||
res.json(buildFakeModelsResponse());
|
||||
}
|
||||
);
|
||||
openaiRouter.post(
|
||||
"/v1/chat/completions",
|
||||
setApiFormat({ in: "openai", out: "openai" }),
|
||||
ipLimiter,
|
||||
queuedOpenaiProxy
|
||||
);
|
||||
// Redirect browser requests to the homepage.
|
||||
openaiRouter.get("*", (req, res, next) => {
|
||||
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
|
||||
if (isBrowser) {
|
||||
res.redirect("/");
|
||||
} else {
|
||||
next();
|
||||
}
|
||||
});
|
||||
openaiRouter.use((req, res) => {
|
||||
req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
|
||||
res.status(404).json({ error: "Not found" });
|
||||
});
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
function buildFakeModelsResponse() {
|
||||
function getModelsResponse() {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
@@ -164,4 +73,92 @@ function buildFakeModelsResponse() {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
const rewriteRequest = (
|
||||
proxyReq: http.ClientRequest,
|
||||
req: Request,
|
||||
res: http.ServerResponse
|
||||
) => {
|
||||
const rewriterPipeline = [
|
||||
addKey,
|
||||
languageFilter,
|
||||
limitOutputTokens,
|
||||
limitCompletions,
|
||||
finalizeBody,
|
||||
];
|
||||
|
||||
try {
|
||||
for (const rewriter of rewriterPipeline) {
|
||||
rewriter(proxyReq, req, res, {});
|
||||
}
|
||||
} catch (error) {
|
||||
req.log.error(error, "Error while executing proxy rewriter");
|
||||
proxyReq.destroy(error as Error);
|
||||
}
|
||||
};
|
||||
|
||||
const openaiResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
if (config.promptLogging) {
|
||||
const host = req.get("host");
|
||||
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
|
||||
}
|
||||
|
||||
res.status(200).json(body);
|
||||
};
|
||||
|
||||
const openaiProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
target: "https://api.openai.com",
|
||||
changeOrigin: true,
|
||||
on: {
|
||||
proxyReq: rewriteRequest,
|
||||
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
})
|
||||
);
|
||||
|
||||
const openaiRouter = Router();
|
||||
// Fix paths because clients don't consistently use the /v1 prefix.
|
||||
openaiRouter.use((req, _res, next) => {
|
||||
if (!req.path.startsWith("/v1/")) {
|
||||
req.url = `/v1${req.url}`;
|
||||
}
|
||||
next();
|
||||
});
|
||||
openaiRouter.get("/v1/models", handleModelRequest);
|
||||
openaiRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({ inApi: "openai", outApi: "openai" }),
|
||||
openaiProxy
|
||||
);
|
||||
// Redirect browser requests to the homepage.
|
||||
openaiRouter.get("*", (req, res, next) => {
|
||||
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
|
||||
if (isBrowser) {
|
||||
res.redirect("/");
|
||||
} else {
|
||||
next();
|
||||
}
|
||||
});
|
||||
openaiRouter.use((req, res) => {
|
||||
req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
|
||||
res.status(404).json({ error: "Not found" });
|
||||
});
|
||||
|
||||
export const openai = openaiRouter;
|
||||
|
||||
+1
-34
@@ -20,6 +20,7 @@ import { config, DequeueMode } from "../config";
|
||||
import { keyPool, SupportedModel } from "../key-management";
|
||||
import { logger } from "../logger";
|
||||
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
|
||||
import { buildFakeSseMessage } from "./middleware/common";
|
||||
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
@@ -326,40 +327,6 @@ function initStreaming(req: Request) {
|
||||
res.write(": joining queue\n\n");
|
||||
}
|
||||
|
||||
export function buildFakeSseMessage(
|
||||
type: string,
|
||||
string: string,
|
||||
req: Request
|
||||
) {
|
||||
let fakeEvent;
|
||||
|
||||
if (req.inboundApi === "anthropic") {
|
||||
fakeEvent = {
|
||||
completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`,
|
||||
stop_reason: type,
|
||||
truncated: false, // I've never seen this be true
|
||||
stop: null,
|
||||
model: req.body?.model,
|
||||
log_id: "proxy-req-" + req.id,
|
||||
};
|
||||
} else {
|
||||
fakeEvent = {
|
||||
id: "chatcmpl-" + req.id,
|
||||
object: "chat.completion.chunk",
|
||||
created: Date.now(),
|
||||
model: req.body?.model,
|
||||
choices: [
|
||||
{
|
||||
delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` },
|
||||
index: 0,
|
||||
finish_reason: type,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
|
||||
}
|
||||
|
||||
/**
|
||||
* http-proxy-middleware attaches a bunch of event listeners to the req and
|
||||
* res objects which causes problems with our approach to re-enqueuing failed
|
||||
|
||||
@@ -38,6 +38,9 @@ app.use(
|
||||
'req.headers["x-real-ip"]',
|
||||
'req.headers["true-client-ip"]',
|
||||
'req.headers["cf-connecting-ip"]',
|
||||
// Don't log the prompt text on transform errors
|
||||
"body.messages",
|
||||
"body.prompt",
|
||||
],
|
||||
censor: "********",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user