implements basic key rotation

This commit is contained in:
nai-degen
2023-04-08 05:32:24 -05:00
committed by nai-degen
parent 5ed37bf035
commit a4840e0fe6
4 changed files with 89 additions and 14 deletions
+13 -2
View File
@@ -15,7 +15,7 @@ type KeySchema = {
}; };
/** Runtime information about a key. */ /** Runtime information about a key. */
type Key = KeySchema & { export type Key = KeySchema & {
/** Whether this key is currently disabled. We set this if we get a 429 or 401 response from OpenAI. */ /** Whether this key is currently disabled. We set this if we get a 429 or 401 response from OpenAI. */
isDisabled?: boolean; isDisabled?: boolean;
/** Threshold at which a warning email will be sent by OpenAI. */ /** Threshold at which a warning email will be sent by OpenAI. */
@@ -71,6 +71,17 @@ function list() {
})); }));
} }
/**
 * Marks the pool entry matching `key` as disabled so it is no longer reported
 * as available.
 *
 * The lookup is done by the raw `key` string. If no pool entry matches (for
 * example a stale `Key` object), or the entry is already disabled, the call
 * is a no-op and nothing is logged.
 *
 * @param key The key to disable; only `key.key` (lookup) and `key.hash`
 *   (logging) are read.
 */
function disable(key: Key) {
  const keyFromPool = keyPool.find((k) => k.key === key.key);
  // Guard instead of a non-null assertion: an unknown key must not crash the
  // response handler that is trying to report an upstream error.
  if (!keyFromPool || keyFromPool.isDisabled) return;
  keyFromPool.isDisabled = true;
  logger.warn("Key disabled", { key: key.hash });
}
/**
 * Reports whether the pool still holds at least one key that has not been
 * flagged as disabled.
 *
 * @returns `true` as soon as one non-disabled key is found, `false` if every
 *   key is disabled or the pool is empty.
 */
function anyAvailable() {
  for (const candidate of keyPool) {
    if (!candidate.isDisabled) {
      return true;
    }
  }
  return false;
}
function get(model: string) { function get(model: string) {
const needsGpt4Key = model.startsWith("gpt-4"); const needsGpt4Key = model.startsWith("gpt-4");
const availableKeys = keyPool.filter( const availableKeys = keyPool.filter(
@@ -99,4 +110,4 @@ function get(model: string) {
return oldestKey; return oldestKey;
} }
export const keys = { init, list, get }; export const keys = { init, list, get, anyAvailable, disable };
+64 -11
View File
@@ -9,6 +9,7 @@ import { keys } from "./keys";
*/ */
const rewriteRequest = (proxyReq: http.ClientRequest, req: Request) => { const rewriteRequest = (proxyReq: http.ClientRequest, req: Request) => {
const key = keys.get(req.body?.model || "gpt-3.5")!; const key = keys.get(req.body?.model || "gpt-3.5")!;
req.key = key;
proxyReq.setHeader("Authorization", `Bearer ${key}`); proxyReq.setHeader("Authorization", `Bearer ${key}`);
if (req.body?.stream) { if (req.body?.stream) {
@@ -20,24 +21,67 @@ const rewriteRequest = (proxyReq: http.ClientRequest, req: Request) => {
} }
}; };
// TODO: extract this since Kobold will use it too
/**
 * Forwards an upstream response to the client, inspecting error responses so
 * that dead keys can be taken out of rotation.
 *
 * - Status < 400: the upstream body is piped to the client unchanged.
 * - Status >= 400: the body is buffered and parsed as JSON.
 *   - Unparseable body: a generic proxy error payload is returned.
 *   - 401: the key attached to the request is disabled and a `proxy_note`
 *     is added to the payload.
 *   - 429 with `error.type === "insufficient_quota"`: same as 401, with an
 *     "exhausted" note.
 *   - Any other 429: logged only, the key stays enabled.
 *   The (possibly annotated) payload is sent with the upstream status code.
 *
 * @param proxyRes Response received from the upstream API.
 * @param req Incoming client request; `req.key` is read to identify the key
 *   that was used for the upstream call.
 * @param res Outgoing client response.
 */
const handleResponse = (
  proxyRes: http.IncomingMessage,
  req: Request,
  res: Response
) => {
  const statusCode = proxyRes.statusCode || 500;

  if (statusCode < 400) {
    proxyRes.pipe(res);
    return;
  }

  // Evaluated lazily so the hint reflects the pool state *after* the failing
  // key has been disabled; otherwise the last key would still count as
  // available and the client would be told to retry in vain.
  const retryHint = () =>
    keys.anyAvailable()
      ? "You can try again to get a different key."
      : "There are no more keys available.";

  // `req.key` is optional on the request type, so guard rather than assert.
  const disableRequestKey = () => {
    if (req.key) keys.disable(req.key);
  };

  // Consume body and then decide what to do. Chunks are collected as Buffers
  // and decoded once so multi-byte characters split across chunks survive.
  const chunks: Buffer[] = [];
  proxyRes.on("data", (chunk: Buffer | string) => {
    chunks.push(typeof chunk === "string" ? Buffer.from(chunk) : chunk);
  });
  proxyRes.on("end", () => {
    const body = Buffer.concat(chunks).toString("utf8");
    let errorPayload: any = {
      error: "Proxy couldn't parse error from OpenAI",
    };

    try {
      errorPayload = JSON.parse(body);
    } catch (err) {
      logger.error(errorPayload.error, { error: err });
      res.status(statusCode).json(errorPayload);
      return;
    }

    // Valid JSON that is not an object (null, string, number, ...) cannot be
    // inspected or annotated; pass it through untouched instead of throwing.
    if (typeof errorPayload !== "object" || errorPayload === null) {
      res.status(statusCode).json(errorPayload);
      return;
    }

    if (statusCode === 401) {
      // Key is invalid or was revoked
      logger.warn(
        `OpenAI key is invalid or revoked. Keyhash ${req.key?.hash}`
      );
      disableRequestKey();
      errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${retryHint()}`;
    } else if (statusCode === 429) {
      // 429 is sent for several distinct situations:
      // - Quota exceeded: the key is dead and must be disabled.
      // - Rate limit exceeded: the key is still good, but backoff is needed.
      // - Model overloaded: an upstream capacity problem, unrelated to the key.
      if (errorPayload.error?.type === "insufficient_quota") {
        logger.warn(`OpenAI key is exhausted. Keyhash ${req.key?.hash}`);
        disableRequestKey();
        errorPayload.proxy_note = `The OpenAI key is exhausted. ${retryHint()}`;
      } else {
        logger.warn(
          `OpenAI rate limit exceeded or model overloaded. Keyhash ${req.key?.hash}`,
          { errorCode: errorPayload.error?.type }
        );
      }
    }

    res.status(statusCode).json(errorPayload);
  });
};
const openaiProxy = createProxyMiddleware({ const openaiProxy = createProxyMiddleware({
@@ -49,9 +93,16 @@ const openaiProxy = createProxyMiddleware({
pathRewrite: { pathRewrite: {
"^/proxy/openai": "", "^/proxy/openai": "",
}, },
logProvider: () => ({
debug: logger.debug.bind(logger),
info: logger.info.bind(logger),
warn: logger.warn.bind(logger),
error: logger.error.bind(logger),
log: logger.info.bind(logger),
}),
}); });
export const openaiRouter = Router(); const openaiRouter = Router();
openaiRouter.post("/v1/chat/completions", openaiProxy); openaiRouter.post("/v1/chat/completions", openaiProxy);
// openaiRouter.post("/v1/completions", openaiProxy); // openaiRouter.post("/v1/completions", openaiProxy);
// openaiRouter.get("/v1/models", handleModels); // openaiRouter.get("/v1/models", handleModels);
@@ -60,3 +111,5 @@ openaiRouter.use((req, res) => {
logger.warn(`Blocked openai proxy request: ${req.method} ${req.path}`); logger.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
res.status(404).json({ error: "Not found" }); res.status(404).json({ error: "Not found" });
}); });
export const openai = openaiRouter;
+10
View File
@@ -0,0 +1,10 @@
import { Express } from "express-serve-static-core";
import { Key } from "../keys";
/**
 * Global augmentation of the `Express.Request` interface so that handlers can
 * read and write `req.key` without casts.
 */
declare global {
namespace Express {
interface Request {
/** Optional `Key` carried on the request; undefined until something assigns it. */
key?: Key;
}
}
}
+2 -1
View File
@@ -11,5 +11,6 @@
"outDir": "build" "outDir": "build"
}, },
"include": ["src"], "include": ["src"],
"exclude": ["node_modules"] "exclude": ["node_modules"],
"files": ["src/types/custom.d.ts"]
} }