1 Commit

Author SHA1 Message Date
nai-degen 59141813d9 adds quick scale keyprovider 2023-07-05 22:11:25 -05:00
55 changed files with 901 additions and 2237 deletions
+14 -11
@@ -10,7 +10,8 @@
# REJECT_DISALLOWED=false
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
# CHECK_KEYS=true
# TURBO_ONLY=false
# QUOTA_DISPLAY_MODE=full
# QUEUE_MODE=fair
# BLOCKED_ORIGINS=reddit.com,9gag.com
# BLOCK_MESSAGE="You must be over the age of majority in your country to use this service."
# BLOCK_REDIRECT="https://roblox.com/"
@@ -18,8 +19,7 @@
# Note: CHECK_KEYS is disabled by default in local development mode, but enabled
# by default in production mode.
# Optional settings for user management and access control. See
# `docs/user-management.md` to learn how to use these.
# Optional settings for user management. See docs/user-management.md.
# GATEKEEPER=none
# GATEKEEPER_STORE=memory
# MAX_IPS_PER_USER=20
@@ -28,8 +28,7 @@
# PROMPT_LOGGING=false
# ------------------------------------------------------------------------------
# The values below are secret -- make sure they are set securely. Do NOT set
# them in the .env file of a public repository.
# The values below are secret -- make sure they are set securely.
# For Huggingface, set them via the Secrets section in your Space's config UI.
# For Render, create a "secret file" called .env using the Environment tab.
@@ -37,20 +36,24 @@
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# TEMPORARY: This will eventually be replaced by a more robust system.
# You can adjust the models used when sending OpenAI prompts to /anthropic.
# Refer to Anthropic's docs for more info (note that they don't list older
# versions of the models, but they still work).
# CLAUDE_SMALL_MODEL=claude-v1.2
# CLAUDE_BIG_MODEL=claude-v1-100k
# You can require a Bearer token for requests when using proxy_token gatekeeper.
# PROXY_KEY=your-secret-key
# You can set an admin key for user management when using user_token gatekeeper.
# ADMIN_KEY=your-very-secret-key
# These are used to push data to a Huggingface Dataset repository.
# HF_DATASET_REPO_URL=https://huggingface.co/datasets/your-username/your-dataset-name
# HF_PRIVATE_SSH_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# These are used to persist user data to Firebase across restarts.
# These are used for various persistence features. Refer to the docs for more
# info.
# FIREBASE_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# FIREBASE_RTDB_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.firebaseio.com
# These are used to log prompts to Google Sheets.
# This is only relevant if you want to use the prompt logging feature.
# GOOGLE_SHEETS_SPREADSHEET_ID=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# GOOGLE_SHEETS_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+9 -5
@@ -24,19 +24,19 @@ To set the password, create a `PROXY_KEY` secret in your environment.
## Per-user authentication (`GATEKEEPER=user_token`)
This mode allows you to provision separate Bearer tokens for each user. You can manage users via the /admin/users via REST or through the admin interface at `/admin`.
This mode allows you to provision separate Bearer tokens for each user. You can manage users via the /admin/users REST API, which itself requires an admin Bearer token.
To begin, set `ADMIN_KEY` to a secret value. This will be used to authenticate requests to the REST API or to log in to the UI.
To begin, set `ADMIN_KEY` to a secret value. This will be used to authenticate requests to the /admin/users REST API.
[You can find an OpenAPI specification for the /admin/users REST API here.](openapi-admin-users.yaml)
By default, the proxy will store user data in memory. Naturally, this means that user data will be lost when the proxy is restarted, though you can use the user import/export feature to save and restore user data manually or via a script. However, the proxy also supports persisting user data to an external data store with some additional configuration.
By default, the proxy will store user data in memory. Naturally, this means that user data will be lost when the proxy is restarted, though you can use the bulk user import/export feature to save and restore user data manually or via a script. However, the proxy also supports persisting user data to an external data store with some additional configuration.
Below are the supported data stores and their configuration options.
### Memory
This is the default data store (`GATEKEEPER_STORE=memory`). User data will be stored in memory and will be lost when the server is restarted. You are responsible for exporting and re-importing user data after a restart.
This is the default data store (`GATEKEEPER_STORE=memory`). User data will be stored in memory and will be lost when the proxy is restarted. You are responsible for downloading and re-uploading user data via the REST API if you want to persist it.
### Firebase Realtime Database
@@ -58,4 +58,8 @@ To use Firebase Realtime Database to persist user data, set the following enviro
7. Set `FIREBASE_RTDB_URL` to the reference URL of your Firebase Realtime Database, e.g. `https://my-project-default-rtdb.firebaseio.com`.
8. Set `GATEKEEPER_STORE` to `firebase_rtdb` in your environment if you haven't already.
The proxy server will attempt to connect to your Firebase Realtime Database at startup and will throw an error if it cannot connect. If you see this error, check that your `FIREBASE_RTDB_URL` and `FIREBASE_KEY` secrets are set correctly.
The proxy will attempt to connect to your Firebase Realtime Database at startup and will throw an error if it cannot connect. If you see this error, check that your `FIREBASE_RTDB_URL` and `FIREBASE_KEY` secrets are set correctly.
---
Users are loaded from the database and changes are flushed periodically. You can use the PUT /admin/users API to bulk import users and force a flush to the database.
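The bulk import flow above can be sketched with curl. The hostname, port, and admin key below are placeholders for your own deployment:

```shell
# Build a users.json payload in the shape PUT /admin/users expects:
# a top-level "users" array where every entry includes its token.
cat > users.json <<'EOF'
{
  "users": [
    { "token": "example-user-token", "type": "normal", "promptCount": 0, "tokenCount": 0 }
  ]
}
EOF

# Bulk import the users and force a flush to the configured data store.
# Uncomment and adjust the URL and admin key for a real deployment:
# curl -X PUT http://localhost:7860/admin/users \
#   -H "Authorization: Bearer your-very-secret-key" \
#   -H "Content-Type: application/json" \
#   --data @users.json
echo "payload written to users.json"
```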
+195 -515
File diff suppressed because it is too large.
+8 -17
@@ -3,8 +3,10 @@
"version": "1.0.0",
"description": "Reverse proxy for the OpenAI API",
"scripts": {
"build": "tsc && copyfiles -u 1 src/**/*.ejs build",
"start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts",
"build:watch": "esbuild src/server.ts --outfile=build/server.js --platform=node --target=es2020 --format=cjs --bundle --sourcemap --watch",
"build": "tsc",
"start:dev": "concurrently \"npm run build:watch\" \"npm run start:watch\"",
"start:dev:tsc": "nodemon --watch src --exec ts-node --transpile-only src/server.ts",
"start:watch": "nodemon --require source-map-support/register build/server.js",
"start:replit": "tsc && node build/server.js",
"start": "node build/server.js",
@@ -16,43 +18,32 @@
"author": "",
"license": "MIT",
"dependencies": {
"@anthropic-ai/tokenizer": "^0.0.4",
"axios": "^1.3.5",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
"csrf-csrf": "^2.3.0",
"dotenv": "^16.0.3",
"ejs": "^3.1.9",
"express": "^4.18.2",
"firebase-admin": "^11.10.1",
"googleapis": "^122.0.0",
"firebase-admin": "^11.8.0",
"googleapis": "^117.0.0",
"http-proxy-middleware": "^3.0.0-beta.1",
"multer": "^1.4.5-lts.1",
"openai": "^3.2.1",
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"showdown": "^2.1.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
"zlib": "^1.0.5",
"zod": "^3.21.4"
},
"devDependencies": {
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
"@types/express": "^4.17.17",
"@types/multer": "^1.4.7",
"@types/showdown": "^2.0.0",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
"esbuild-register": "^3.4.2",
"nodemon": "^3.0.1",
"nodemon": "^2.0.22",
"source-map-support": "^0.5.21",
"ts-node": "^10.9.1",
"typescript": "^5.0.4"
},
"overrides": {
"google-gax": "^3.6.1"
}
}
-58
@@ -1,58 +0,0 @@
import { Request, Response, RequestHandler } from "express";
import { config } from "../config";
const ADMIN_KEY = config.adminKey;
const failedAttempts = new Map<string, number>();
type AuthorizeParams = { via: "cookie" | "header" };
export const authorize: ({ via }: AuthorizeParams) => RequestHandler =
({ via }) =>
(req, res, next) => {
const bearerToken = req.headers.authorization?.slice("Bearer ".length);
const cookieToken = req.cookies["admin-token"];
const token = via === "cookie" ? cookieToken : bearerToken;
const attempts = failedAttempts.get(req.ip) ?? 0;
if (!token) {
return res.status(401).json({ error: "Unauthorized" });
}
if (!ADMIN_KEY) {
req.log.warn(
{ ip: req.ip },
`Blocked admin request because no admin key is configured`
);
return res.status(401).json({ error: "Unauthorized" });
}
if (attempts > 5) {
req.log.warn(
{ ip: req.ip, token: bearerToken },
`Blocked admin request due to too many failed attempts`
);
return res.status(401).json({ error: "Too many attempts" });
}
if (token !== ADMIN_KEY) {
req.log.warn(
{ ip: req.ip, attempts, token },
`Attempted admin request with invalid token`
);
return handleFailedLogin(req, res);
}
req.log.info({ ip: req.ip }, `Admin request authorized`);
next();
};
function handleFailedLogin(req: Request, res: Response) {
const attempts = failedAttempts.get(req.ip) ?? 0;
const newAttempts = attempts + 1;
failedAttempts.set(req.ip, newAttempts);
if (req.accepts("json", "html") === "json") {
return res.status(401).json({ error: "Unauthorized" });
}
res.clearCookie("admin-token");
return res.redirect("/admin/login?failed=true");
}
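The lockout logic in the removed middleware keeps a per-IP counter of failed attempts. A minimal standalone sketch of the same idea (names are illustrative, not the proxy's actual API):

```typescript
const failedAttempts = new Map<string, number>();
const MAX_FAILURES = 5;

// Returns "locked" once an IP has exceeded the failure budget,
// "unauthorized" on a bad token (recording the failure), "ok" otherwise.
function checkAdminToken(
  ip: string,
  tokenMatches: boolean
): "ok" | "unauthorized" | "locked" {
  const attempts = failedAttempts.get(ip) ?? 0;
  if (attempts > MAX_FAILURES) return "locked";
  if (!tokenMatches) {
    failedAttempts.set(ip, attempts + 1);
    return "unauthorized";
  }
  return "ok";
}

// Six bad attempts exhaust the budget; the seventh is locked out.
for (let i = 0; i < 6; i++) checkAdminToken("198.51.100.7", false);
console.log(checkAdminToken("198.51.100.7", false)); // → locked
console.log(checkAdminToken("203.0.113.9", true)); // → ok
```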
-58
@@ -1,58 +0,0 @@
import { z } from "zod";
import { Query } from "express-serve-static-core";
export function parseSort(sort: Query["sort"]) {
if (!sort) return null;
if (typeof sort === "string") return sort.split(",");
if (Array.isArray(sort)) return sort.splice(3) as string[];
return null;
}
export function sortBy(fields: string[], asc = true) {
return (a: any, b: any) => {
for (const field of fields) {
if (a[field] !== b[field]) {
// always sort nulls to the end
if (a[field] == null) return 1;
if (b[field] == null) return -1;
const valA = Array.isArray(a[field]) ? a[field].length : a[field];
const valB = Array.isArray(b[field]) ? b[field].length : b[field];
const result = valA < valB ? -1 : 1;
return asc ? result : -result;
}
}
return 0;
};
}
export function paginate(set: unknown[], page: number, pageSize: number = 20) {
const p = Math.max(1, Math.min(page, Math.ceil(set.length / pageSize)));
return {
page: p,
items: set.slice((p - 1) * pageSize, p * pageSize),
pageSize,
pageCount: Math.ceil(set.length / pageSize),
totalCount: set.length,
nextPage: p * pageSize < set.length ? p + 1 : null,
prevPage: p > 1 ? p - 1 : null,
};
}
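The `paginate` helper being removed here clamps the requested page into the valid range. A quick standalone check of that behavior (function reproduced for illustration):

```typescript
// Reproduction of the removed paginate helper.
function paginate(set: unknown[], page: number, pageSize: number = 20) {
  const p = Math.max(1, Math.min(page, Math.ceil(set.length / pageSize)));
  return {
    page: p,
    items: set.slice((p - 1) * pageSize, p * pageSize),
    pageSize,
    pageCount: Math.ceil(set.length / pageSize),
    totalCount: set.length,
    nextPage: p * pageSize < set.length ? p + 1 : null,
    prevPage: p > 1 ? p - 1 : null,
  };
}

// 45 items at 20 per page -> 3 pages; requesting page 99 clamps to the last page.
const result = paginate(Array.from({ length: 45 }, (_, i) => i), 99);
console.log(result.page, result.items.length, result.nextPage); // → 3 5 null
```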
export const UserSchema = z
.object({
ip: z.array(z.string()).optional(),
type: z.enum(["normal", "special"]).optional(),
promptCount: z.number().optional(),
tokenCount: z.number().optional(),
createdAt: z.number().optional(),
lastUsedAt: z.number().optional(),
disabledAt: z.number().optional(),
disabledReason: z.string().optional(),
})
.strict();
export const UserSchemaWithToken = UserSchema.extend({
token: z.string(),
}).strict();
-24
@@ -1,24 +0,0 @@
import { doubleCsrf } from "csrf-csrf";
import { v4 as uuid } from "uuid";
import express from "express";
const CSRF_SECRET = uuid();
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => CSRF_SECRET,
cookieName: "csrf",
cookieOptions: { sameSite: "strict", path: "/" },
getTokenFromRequest: (req) => req.body["_csrf"] || req.query["_csrf"],
});
const injectCsrfToken: express.RequestHandler = (req, res, next) => {
res.locals.csrfToken = generateToken(res, req);
// force generation of new token on back button
// TODO: implement session-based CSRF tokens
res.setHeader("Cache-Control", "no-cache, no-store, must-revalidate");
res.setHeader("Pragma", "no-cache");
res.setHeader("Expires", "0");
next();
};
export { injectCsrfToken, doubleCsrfProtection as checkCsrfToken };
-29
@@ -1,29 +0,0 @@
import { Router } from "express";
const loginRouter = Router();
loginRouter.get("/login", (req, res) => {
res.render("admin/login", { failed: req.query.failed });
});
loginRouter.post("/login", (req, res) => {
res.cookie("admin-token", req.body.token, {
maxAge: 1000 * 60 * 60 * 24 * 14,
httpOnly: true,
});
res.redirect("/admin");
});
loginRouter.get("/logout", (req, res) => {
res.clearCookie("admin-token");
res.redirect("/admin/login");
});
loginRouter.get("/", (req, res) => {
if (req.cookies["admin-token"]) {
return res.redirect("/admin/manage");
}
res.redirect("/admin/login");
});
export { loginRouter };
+29 -17
@@ -1,24 +1,36 @@
import express, { Router } from "express";
import cookieParser from "cookie-parser";
import { authorize } from "./auth";
import { injectCsrfToken, checkCsrfToken } from "./csrf";
import { usersApiRouter as apiRouter } from "./api/users";
import { usersUiRouter as uiRouter } from "./ui/users";
import { loginRouter } from "./login";
import { RequestHandler, Router } from "express";
import { config } from "../config";
import { usersRouter } from "./users";
const ADMIN_KEY = config.adminKey;
const failedAttempts = new Map<string, number>();
const adminRouter = Router();
adminRouter.use(
express.json({ limit: "20mb" }),
express.urlencoded({ extended: true, limit: "20mb" })
);
adminRouter.use(cookieParser());
adminRouter.use(injectCsrfToken);
const auth: RequestHandler = (req, res, next) => {
const token = req.headers.authorization?.slice("Bearer ".length);
const attempts = failedAttempts.get(req.ip) ?? 0;
if (attempts > 5) {
req.log.warn(
{ ip: req.ip, token },
`Blocked request to admin API due to too many failed attempts`
);
return res.status(401).json({ error: "Too many attempts" });
}
adminRouter.use("/users", authorize({ via: "header" }), apiRouter);
if (token !== ADMIN_KEY) {
const newAttempts = attempts + 1;
failedAttempts.set(req.ip, newAttempts);
req.log.warn(
{ ip: req.ip, attempts: newAttempts, token },
`Attempted admin API request with invalid token`
);
return res.status(401).json({ error: "Unauthorized" });
}
adminRouter.use(checkCsrfToken); // All UI routes require CSRF token
adminRouter.use("/", loginRouter);
adminRouter.use("/manage", authorize({ via: "cookie" }), uiRouter);
next();
};
adminRouter.use(auth);
adminRouter.use("/users", usersRouter);
export { adminRouter };
-135
@@ -1,135 +0,0 @@
import { Router } from "express";
import multer from "multer";
import { z } from "zod";
import { config } from "../../config";
import * as userStore from "../../proxy/auth/user-store";
import {
UserSchemaWithToken,
parseSort,
sortBy,
paginate,
UserSchema,
} from "../common";
const router = Router();
const upload = multer({
storage: multer.memoryStorage(),
fileFilter: (_req, file, cb) => {
if (file.mimetype !== "application/json") {
cb(new Error("Invalid file type"));
} else {
cb(null, true);
}
},
});
router.get("/create-user", (req, res) => {
const recentUsers = userStore
.getUsers()
.sort(sortBy(["createdAt"], false))
.slice(0, 5);
res.render("admin/create-user", {
recentUsers,
newToken: !!req.query.created,
});
});
router.post("/create-user", (_req, res) => {
userStore.createUser();
return res.redirect(`/admin/manage/create-user?created=true`);
});
router.get("/view-user/:token", (req, res) => {
const user = userStore.getUser(req.params.token);
if (!user) {
return res.status(404).send("User not found");
}
res.render("admin/view-user", { user });
});
router.get("/list-users", (req, res) => {
const sort = parseSort(req.query.sort) || ["promptCount", "lastUsedAt"];
const requestedPageSize =
Number(req.query.perPage) || Number(req.cookies.perPage) || 20;
const perPage = Math.max(1, Math.min(1000, requestedPageSize));
const users = userStore.getUsers().sort(sortBy(sort, false));
const page = Number(req.query.page) || 1;
const { items, ...pagination } = paginate(users, page, perPage);
return res.render("admin/list-users", {
sort: sort.join(","),
users: items,
...pagination,
});
});
router.get("/import-users", (req, res) => {
const imported = Number(req.query.imported) || 0;
res.render("admin/import-users", { imported });
});
router.post("/import-users", upload.single("users"), (req, res) => {
if (!req.file) {
return res.status(400).json({ error: "No file uploaded" });
}
const data = JSON.parse(req.file.buffer.toString());
const result = z.array(UserSchemaWithToken).safeParse(data.users);
if (!result.success) {
return res.status(400).json({ error: result.error });
}
const upserts = result.data.map((user) => userStore.upsertUser(user));
res.redirect(`/admin/manage/import-users?imported=${upserts.length}`);
});
router.get("/export-users", (_req, res) => {
res.render("admin/export-users");
});
router.get("/export-users.json", (_req, res) => {
const users = userStore.getUsers();
res.setHeader("Content-Disposition", "attachment; filename=users.json");
res.setHeader("Content-Type", "application/json");
res.send(JSON.stringify({ users }, null, 2));
});
router.get("/", (_req, res) => {
res.render("admin/index", {
isPersistenceEnabled: config.gatekeeperStore !== "memory",
});
});
router.post("/edit-user/:token", (req, res) => {
const result = UserSchema.safeParse(req.body);
if (!result.success) {
return res.status(400).send(result.error);
}
userStore.upsertUser({ ...result.data, token: req.params.token });
return res.sendStatus(204);
});
router.post("/reactivate-user/:token", (req, res) => {
const user = userStore.getUser(req.params.token);
if (!user) {
return res.status(404).send("User not found");
}
userStore.upsertUser({
token: user.token,
disabledAt: 0,
disabledReason: "",
});
return res.sendStatus(204);
});
router.post("/disable-user/:token", (req, res) => {
const user = userStore.getUser(req.params.token);
if (!user) {
return res.status(404).send("User not found");
}
userStore.disableUser(req.params.token, req.body.reason);
return res.sendStatus(204);
});
export { router as usersUiRouter };
+33 -14
@@ -1,17 +1,37 @@
import { Router } from "express";
import { z } from "zod";
import * as userStore from "../../proxy/auth/user-store";
import { UserSchema, UserSchemaWithToken, parseSort, sortBy } from "../common";
import * as userStore from "../proxy/auth/user-store";
const router = Router();
const usersRouter = Router();
const UserSchema = z
.object({
ip: z.array(z.string()).optional(),
type: z.enum(["normal", "special"]).optional(),
promptCount: z.number().optional(),
tokenCount: z.number().optional(),
createdAt: z.number().optional(),
lastUsedAt: z.number().optional(),
disabledAt: z.number().optional(),
disabledReason: z.string().optional(),
})
.strict();
const UserSchemaWithToken = UserSchema.extend({
token: z.string(),
}).strict();
/**
* Returns a list of all users, sorted by prompt count and then last used time.
* GET /admin/users
*/
router.get("/", (req, res) => {
const sort = parseSort(req.query.sort) || ["promptCount", "lastUsedAt"];
const users = userStore.getUsers().sort(sortBy(sort, false));
usersRouter.get("/", (_req, res) => {
const users = userStore.getUsers().sort((a, b) => {
if (a.promptCount !== b.promptCount) {
return b.promptCount - a.promptCount;
}
return (b.lastUsedAt ?? 0) - (a.lastUsedAt ?? 0);
});
res.json({ users, count: users.length });
});
@@ -19,7 +39,7 @@ router.get("/", (req, res) => {
* Returns the user with the given token.
* GET /admin/users/:token
*/
router.get("/:token", (req, res) => {
usersRouter.get("/:token", (req, res) => {
const user = userStore.getUser(req.params.token);
if (!user) {
return res.status(404).json({ error: "Not found" });
@@ -32,9 +52,8 @@ router.get("/:token", (req, res) => {
* Returns the created user's token.
* POST /admin/users
*/
router.post("/", (req, res) => {
const token = userStore.createUser();
res.json({ token });
usersRouter.post("/", (_req, res) => {
res.json({ token: userStore.createUser() });
});
/**
@@ -43,7 +62,7 @@ router.post("/", (req, res) => {
* Returns the upserted user.
* PUT /admin/users/:token
*/
router.put("/:token", (req, res) => {
usersRouter.put("/:token", (req, res) => {
const result = UserSchema.safeParse(req.body);
if (!result.success) {
return res.status(400).json({ error: result.error });
@@ -58,7 +77,7 @@ router.put("/:token", (req, res) => {
* Returns an object containing the upserted users and the number of upserts.
* PUT /admin/users
*/
router.put("/", (req, res) => {
usersRouter.put("/", (req, res) => {
const result = z.array(UserSchemaWithToken).safeParse(req.body.users);
if (!result.success) {
return res.status(400).json({ error: result.error });
@@ -76,7 +95,7 @@ router.put("/", (req, res) => {
* Returns the disabled user.
* DELETE /admin/users/:token
*/
router.delete("/:token", (req, res) => {
usersRouter.delete("/:token", (req, res) => {
const user = userStore.getUser(req.params.token);
const disabledReason = z
.string()
@@ -92,4 +111,4 @@ router.delete("/:token", (req, res) => {
res.json(userStore.getUser(req.params.token));
});
export { router as usersApiRouter };
export { usersRouter };
+35 -43
@@ -9,14 +9,17 @@ const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
const isDev = process.env.NODE_ENV !== "production";
type PromptLoggingBackend = "google_sheets";
export type DequeueMode = "fair" | "random" | "none";
export type Config = {
type Config = {
/** The port the proxy server will listen on. */
port: number;
/** Comma-delimited list of OpenAI API keys. */
openaiKey?: string;
/** Comma-delimited list of Anthropic API keys. */
anthropicKey?: string;
scaleKey?: string;
scaleMinDeployments: number;
/**
* The proxy key to require for requests. Only applicable if the user
* management mode is set to 'proxy_key', and required if so.
@@ -25,7 +28,7 @@ export type Config = {
/**
* The admin key used to access the /admin API. Required if the user
* management mode is set to 'user_token'.
**/
*/
adminKey?: string;
/**
* Which user management mode to use.
@@ -47,21 +50,13 @@ export type Config = {
* `memory`: Users are stored in memory and are lost on restart (default)
*
* `firebase_rtdb`: Users are stored in a Firebase Realtime Database; requires
* `firebaseKey` and `firebaseRtdbUrl` to be set. (deprecated)
*
* `huggingface_datasets`: Users are stored in a Huggingface Datasets git
* repository; requires `hfDatasetRepoUrl` and `hfPrivateSshKey` to be set.
**/
gatekeeperStore: "memory" | "firebase_rtdb" | "huggingface_datasets";
* `firebaseKey` and `firebaseRtdbUrl` to be set.
*/
gatekeeperStore: "memory" | "firebase_rtdb";
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
firebaseRtdbUrl?: string;
/** Base64-encoded Firebase service account key if using the Firebase RTDB store. */
firebaseKey?: string;
/** URL of the Huggingface Datasets git repository if using the Huggingface
* Datasets store. */
hfDatasetRepoUrl?: string;
/** Private SSH key used to push to the Huggingface Dataset repository. */
hfPrivateSshKey?: string;
/**
* Maximum number of IPs per user, after which their token is disabled.
* Users with the manually-assigned `special` role are exempt from this limit.
@@ -70,20 +65,6 @@ export type Config = {
maxIpsPerUser: number;
/** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
modelRateLimit: number;
/**
* For OpenAI, the maximum number of context tokens (prompt + max output) a
* user can request before their request is rejected.
* Context limits can help prevent excessive spend.
* Defaults to 0, which means no limit beyond OpenAI's stated maximums.
*/
maxContextTokensOpenAI: number;
/**
* For Anthropic, the maximum number of context tokens a user can request.
* Claude context limits can prevent requests from tying up concurrency slots
* for too long, which can lengthen queue times for other users.
* Defaults to 0, which means no limit beyond Anthropic's stated maximums.
*/
maxContextTokensAnthropic: number;
/** For OpenAI, the maximum number of sampled tokens a user can request. */
maxOutputTokensOpenAI: number;
/** For Anthropic, the maximum number of sampled tokens a user can request. */
@@ -104,6 +85,26 @@ export type Config = {
googleSheetsSpreadsheetId?: string;
/** Whether to periodically check keys for usage and validity. */
checkKeys?: boolean;
/**
* How to display quota information on the info page.
*
* `none`: Hide quota information
*
* `partial`: Display quota information only as a percentage
*
* `full`: Display quota information as usage against total capacity
*/
quotaDisplayMode: "none" | "partial" | "full";
/**
* Which request queueing strategy to use when keys are over their rate limit.
*
* `fair`: Requests are serviced in the order they were received (default)
*
* `random`: Requests are serviced randomly
*
* `none`: Requests are not queued and users have to retry manually
*/
queueMode: DequeueMode;
/**
* Comma-separated list of origins to block. Requests matching any of these
* origins or referers will be rejected.
@@ -120,11 +121,6 @@ export type Config = {
* Destination URL to redirect blocked requests to, for non-JSON requests.
*/
blockRedirect?: string;
/**
* Whether the proxy should disallow requests for GPT-4 models in order to
* prevent excessive spend. Applies only to OpenAI.
*/
turboOnly?: boolean;
};
// To change configs, create a file called .env in the root directory.
@@ -133,6 +129,8 @@ export const config: Config = {
port: getEnvWithDefault("PORT", 7860),
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
scaleKey: getEnvWithDefault("SCALE_KEY", ""),
scaleMinDeployments: getEnvWithDefault("SCALE_MIN_DEPLOYMENTS", 0),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
@@ -140,18 +138,11 @@ export const config: Config = {
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
hfDatasetRepoUrl: getEnvWithDefault("HF_DATASET_REPO_URL", undefined),
hfPrivateSshKey: getEnvWithDefault("HF_PRIVATE_SSH_KEY", undefined),
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
0
),
maxOutputTokensOpenAI: getEnvWithDefault("MAX_OUTPUT_TOKENS_OPENAI", 300),
maxOutputTokensAnthropic: getEnvWithDefault(
"MAX_OUTPUT_TOKENS_ANTHROPIC",
400
600
),
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
rejectMessage: getEnvWithDefault(
@@ -160,6 +151,7 @@ export const config: Config = {
),
logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev),
quotaDisplayMode: getEnvWithDefault("QUOTA_DISPLAY_MODE", "partial"),
promptLogging: getEnvWithDefault("PROMPT_LOGGING", false),
promptLoggingBackend: getEnvWithDefault("PROMPT_LOGGING_BACKEND", undefined),
googleSheetsKey: getEnvWithDefault("GOOGLE_SHEETS_KEY", undefined),
@@ -167,13 +159,13 @@ export const config: Config = {
"GOOGLE_SHEETS_SPREADSHEET_ID",
undefined
),
queueMode: getEnvWithDefault("QUEUE_MODE", "fair"),
blockedOrigins: getEnvWithDefault("BLOCKED_ORIGINS", undefined),
blockMessage: getEnvWithDefault(
"BLOCK_MESSAGE",
"You must be over the age of majority in your country to use this service."
),
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
turboOnly: getEnvWithDefault("TURBO_ONLY", false),
} as const;
function migrateConfigs() {
@@ -274,14 +266,14 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"logLevel",
"openaiKey",
"anthropicKey",
"scaleKey",
"proxyKey",
"adminKey",
"checkKeys",
"quotaDisplayMode",
"googleSheetsKey",
"firebaseKey",
"firebaseRtdbUrl",
"hfDatasetRepoUrl",
"hfPrivateSshKey",
"gatekeeperStore",
"maxIpsPerUser",
"blockedOrigins",
+60 -42
@@ -2,7 +2,7 @@ import fs from "fs";
import { Request, Response } from "express";
import showdown from "showdown";
import { config, listConfig } from "./config";
import { OpenAIKey, keyPool } from "./key-management";
import { keyPool } from "./key-management";
import { getUniqueIps } from "./proxy/rate-limit";
import {
QueuePartition,
@@ -78,9 +78,7 @@ function cacheInfoPageHtml(baseUrl: string) {
type ServiceInfo = {
activeKeys: number;
trialKeys?: number;
// activeLimit: string;
revokedKeys?: number;
overQuotaKeys?: number;
quota: string;
proomptersInQueue: number;
estimatedQueueTime: string;
};
@@ -90,55 +88,68 @@ type ServiceInfo = {
function getOpenAIInfo() {
const info: { [model: string]: Partial<ServiceInfo> } = {};
const keys = keyPool
.list()
.filter((k) => k.service === "openai") as OpenAIKey[];
const hasGpt4 = keys.some((k) => k.isGpt4) && !config.turboOnly;
const keys = keyPool.list().filter((k) => k.service === "openai");
const hasGpt4 = keys.some((k) => k.isGpt4);
if (keyPool.anyUnchecked()) {
const uncheckedKeys = keys.filter((k) => !k.lastChecked);
info.status =
`Performing startup key checks (${uncheckedKeys.length} left).` as any;
info.status = `Still checking ${uncheckedKeys.length} keys...` as any;
} else {
delete info.status;
}
if (config.checkKeys) {
const turboKeys = keys.filter((k) => !k.isGpt4);
const gpt4Keys = keys.filter((k) => k.isGpt4);
const turboKeys = keys.filter((k) => !k.isGpt4 && !k.isDisabled);
const gpt4Keys = keys.filter((k) => k.isGpt4 && !k.isDisabled);
const quota: Record<string, string> = { turbo: "", gpt4: "" };
const turboQuota = keyPool.remainingQuota("openai") * 100;
const gpt4Quota = keyPool.remainingQuota("openai", { gpt4: true }) * 100;
if (config.quotaDisplayMode === "full") {
const turboUsage = keyPool.usageInUsd("openai");
const gpt4Usage = keyPool.usageInUsd("openai", { gpt4: true });
quota.turbo = `${turboUsage} (${Math.round(turboQuota)}% remaining)`;
quota.gpt4 = `${gpt4Usage} (${Math.round(gpt4Quota)}% remaining)`;
} else {
quota.turbo = `${Math.round(turboQuota)}%`;
quota.gpt4 = `${Math.round(gpt4Quota)}%`;
}
info.turbo = {
activeKeys: turboKeys.filter((k) => !k.isDisabled).length,
trialKeys: turboKeys.filter((k) => k.isTrial).length,
revokedKeys: turboKeys.filter((k) => k.isRevoked).length,
overQuotaKeys: turboKeys.filter((k) => k.isOverQuota).length,
quota: quota.turbo,
};
if (hasGpt4) {
info.gpt4 = {
activeKeys: gpt4Keys.filter((k) => !k.isDisabled).length,
trialKeys: gpt4Keys.filter((k) => k.isTrial).length,
revokedKeys: gpt4Keys.filter((k) => k.isRevoked).length,
overQuotaKeys: gpt4Keys.filter((k) => k.isOverQuota).length,
quota: quota.gpt4,
};
}
if (config.quotaDisplayMode === "none") {
delete info.turbo?.quota;
delete info.gpt4?.quota;
}
} else {
info.status = "Key checking is disabled." as any;
info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
info.gpt4 = {
activeKeys: keys.filter((k) => !k.isDisabled && k.isGpt4).length,
};
}
const turboQueue = getQueueInformation("turbo");
if (config.queueMode !== "none") {
const turboQueue = getQueueInformation("turbo");
info.turbo.proomptersInQueue = turboQueue.proomptersInQueue;
info.turbo.estimatedQueueTime = turboQueue.estimatedQueueTime;
info.turbo.proomptersInQueue = turboQueue.proomptersInQueue;
info.turbo.estimatedQueueTime = turboQueue.estimatedQueueTime;
if (hasGpt4) {
const gpt4Queue = getQueueInformation("gpt-4");
info.gpt4.proomptersInQueue = gpt4Queue.proomptersInQueue;
info.gpt4.estimatedQueueTime = gpt4Queue.estimatedQueueTime;
if (hasGpt4) {
const gpt4Queue = getQueueInformation("gpt-4");
info.gpt4.proomptersInQueue = gpt4Queue.proomptersInQueue;
info.gpt4.estimatedQueueTime = gpt4Queue.estimatedQueueTime;
}
}
return info;
@@ -148,9 +159,11 @@ function getAnthropicInfo() {
const claudeInfo: Partial<ServiceInfo> = {};
const keys = keyPool.list().filter((k) => k.service === "anthropic");
claudeInfo.activeKeys = keys.filter((k) => !k.isDisabled).length;
if (config.queueMode !== "none") {
const queue = getQueueInformation("claude");
claudeInfo.proomptersInQueue = queue.proomptersInQueue;
claudeInfo.estimatedQueueTime = queue.estimatedQueueTime;
}
return { claude: claudeInfo };
}
@@ -176,23 +189,25 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t
**If you are uncomfortable with this, don't send prompts to this proxy!**`;
}
if (config.queueMode !== "none") {
const waits = [];
infoBody += `\n## Estimated Wait Times\nIf the AI is busy, your prompt will be processed when a slot frees up.`;
if (config.openaiKey) {
const turboWait = getQueueInformation("turbo").estimatedQueueTime;
const gpt4Wait = getQueueInformation("gpt-4").estimatedQueueTime;
waits.push(`**Turbo:** ${turboWait}`);
if (keyPool.list().some((k) => k.isGpt4)) {
waits.push(`**GPT-4:** ${gpt4Wait}`);
}
}
if (config.anthropicKey) {
const claudeWait = getQueueInformation("claude").estimatedQueueTime;
waits.push(`**Claude:** ${claudeWait}`);
}
infoBody += "\n\n" + waits.join(" / ");
}
if (customGreeting) {
infoBody += `\n## Server Greeting\n
@@ -203,6 +218,9 @@ ${customGreeting}`;
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: QueuePartition) {
if (config.queueMode === "none") {
return {};
}
const waitMs = getEstimatedWaitTime(partition);
const waitTime =
waitMs < 60000
+22 -17
View File
@@ -3,13 +3,11 @@ import { Key, KeyProvider } from "..";
import { config } from "../../config";
import { logger } from "../../logger";
// https://docs.anthropic.com/claude/reference/selecting-a-model
export const ANTHROPIC_SUPPORTED_MODELS = [
"claude-instant-v1",
"claude-instant-v1-100k",
"claude-v1",
"claude-v1-100k",
"claude-2",
] as const;
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
@@ -40,16 +38,10 @@ export interface AnthropicKey extends Key {
}
/**
* We don't get rate limit headers from Anthropic so if we get a 429, we just
* lock out the key for a few seconds
*/
const RATE_LIMIT_LOCKOUT = 5000;
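As an illustrative sketch (key shape simplified from the provider in this file), the lockout timestamp gates key selection like this:

```typescript
// Sketch of the 429 lockout mechanic: the constant mirrors
// RATE_LIMIT_LOCKOUT above; the key shape is a simplified stand-in.
const LOCKOUT_MS = 5000;

interface LockableKey {
  hash: string;
  rateLimitedUntil: number;
}

// On a 429, push the key's availability out by the lockout window.
function markRateLimited(key: LockableKey, now: number): void {
  key.rateLimitedUntil = now + LOCKOUT_MS;
}

// A key is selectable again once the lockout has elapsed.
function isAvailable(key: LockableKey, now: number): boolean {
  return key.rateLimitedUntil <= now;
}

const key = { hash: "ant-abc123", rateLimitedUntil: 0 };
markRateLimited(key, 1_000);
console.log(isAvailable(key, 2_000)); // false: locked until t=6000
console.log(isAvailable(key, 7_000)); // true
```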
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
readonly service = "anthropic";
@@ -135,7 +127,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
// Intended to throttle the queue processor as otherwise it will just
// flood the API with requests and we want to wait a sec to see if we're
// going to get a rate limit error on this key.
selectedKey.rateLimitedUntil = now + 1000;
return { ...selectedKey };
}
@@ -189,9 +181,15 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve so all we can do is wait a bit and try
* again.
* We will lock the key for 10 seconds, which should let a few of the other
* generations finish. This is an arbitrary number but the goal is to balance
* between not hammering the API with requests and not locking out a key that
* is actually available.
* TODO: Try to assign requests to slots on each key so we have an idea of how
* long each slot has been running and can make a more informed decision on
* how long to lock the key.
*/
public markRateLimited(keyHash: string) {
this.log.warn({ key: keyHash }, "Key rate limited");
@@ -201,7 +199,14 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public remainingQuota() {
const activeKeys = this.keys.filter((k) => !k.isDisabled).length;
const allKeys = this.keys.length;
if (activeKeys === 0) return 0;
return Math.round((activeKeys / allKeys) * 100) / 100;
}
public usageInUsd() {
return "$0.00 / ∞";
}
}
+3 -2
View File
@@ -5,7 +5,7 @@ import {
} from "./anthropic/provider";
import { KeyPool } from "./key-pool";
export type AIService = "openai" | "anthropic" | "scale";
export type Model = OpenAIModel | AnthropicModel;
export interface Key {
@@ -52,7 +52,8 @@ export interface KeyProvider<T extends Key = Key> {
anyUnchecked(): boolean;
incrementPrompt(hash: string): void;
getLockoutPeriod(model: Model): number;
remainingQuota(options?: Record<string, unknown>): number;
usageInUsd(options?: Record<string, unknown>): string;
markRateLimited(hash: string): void;
}
+10 -9
View File
@@ -32,15 +32,9 @@ export class KeyPool {
return this.keyProviders.flatMap((provider) => provider.list());
}
public disable(key: Key): void {
const service = this.getKeyProvider(key.service);
service.disable(key);
}
public update(key: Key, props: AllowedPartial): void {
@@ -81,11 +75,18 @@ export class KeyPool {
}
}
public remainingQuota(
service: AIService,
options?: Record<string, unknown>
): number {
return this.getKeyProvider(service).remainingQuota(options);
}
public usageInUsd(
service: AIService,
options?: Record<string, unknown>
): string {
return this.getKeyProvider(service).usageInUsd(options);
}
private getService(model: Model): AIService {
+92 -161
View File
@@ -1,24 +1,14 @@
import axios, { AxiosError } from "axios";
import { Configuration, OpenAIApi } from "openai";
import { logger } from "../../logger";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
/** Minimum time in between any two key checks. */
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 5 * 60 * 1000; // 5 minutes
const POST_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
const GET_MODELS_URL = "https://api.openai.com/v1/models";
const GET_SUBSCRIPTION_URL =
"https://api.openai.com/dashboard/billing/subscription";
type GetModelsResponse = {
data: [{ id: string }];
};
const GET_USAGE_URL = "https://api.openai.com/dashboard/billing/usage";
type GetSubscriptionResponse = {
plan: { title: string };
@@ -28,6 +18,10 @@ type GetSubscriptionResponse = {
system_hard_limit_usd: number;
};
type GetUsageResponse = {
total_usage: number;
};
type OpenAIError = {
error: { type: string; code: string; param: unknown; message: string };
};
@@ -60,8 +54,7 @@ export class OpenAIKeyChecker {
/**
* Schedules the next check. If there are still keys yet to be checked, it
* will schedule a check immediately for the next unchecked key. Otherwise,
* it will schedule a check in several minutes for the oldest key.
**/
private scheduleNextCheck() {
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
@@ -101,8 +94,8 @@ export class OpenAIKeyChecker {
key.lastChecked < oldest.lastChecked ? key : oldest
);
// Don't check any individual key more than once every 5 minutes.
// Also, don't check anything more often than once every 3 seconds.
const nextCheck = Math.max(
oldestKey.lastChecked + KEY_CHECK_PERIOD,
this.lastCheck + MIN_CHECK_INTERVAL
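The two constraints combine into a single next-check time; a small standalone sketch with hypothetical timestamps:

```typescript
// Sketch of the scheduling rule above: the next check fires at the later
// of the oldest key becoming due (per-key period) and the global minimum
// spacing since the last check of any key.
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 5 * 60 * 1000; // 5 minutes

function nextCheckAt(oldestKeyLastChecked: number, lastCheck: number): number {
  return Math.max(
    oldestKeyLastChecked + KEY_CHECK_PERIOD,
    lastCheck + MIN_CHECK_INTERVAL
  );
}

// Oldest key checked at t=0, most recent check at t=295s: the 5-minute
// per-key period dominates, so the next check lands at t=300s.
console.log(nextCheckAt(0, 295_000)); // 300000
```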
@@ -129,43 +122,47 @@ export class OpenAIKeyChecker {
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// During the initial check we need to get the subscription first because
// trials have different behavior.
if (isInitialCheck) {
const subscription = await this.getSubscription(key);
this.updateKey(key.hash, { isTrial: !subscription.has_payment_method });
if (key.isTrial) {
this.log.debug(
{ key: key.hash },
"Attempting generation on trial key."
);
await this.assertCanGenerate(key);
}
const [provisionedModels, usage] = await Promise.all([
this.getProvisionedModels(key),
this.getUsage(key),
]);
const updates = {
isGpt4: provisionedModels.gpt4,
softLimit: subscription.soft_limit_usd,
hardLimit: subscription.hard_limit_usd,
systemHardLimit: subscription.system_hard_limit_usd,
usage,
};
this.updateKey(key.hash, updates);
} else {
// Don't check provisioned models after the initial check because it's
// not likely to change.
const [subscription, usage] = await Promise.all([
this.getSubscription(key),
this.getUsage(key),
]);
const updates = {
softLimit: subscription.soft_limit_usd,
hardLimit: subscription.hard_limit_usd,
systemHardLimit: subscription.system_hard_limit_usd,
usage,
};
this.updateKey(key.hash, updates);
}
this.log.info(
{ key: key.hash, usage: key.usage, hardLimit: key.hardLimit },
"Key check complete."
);
} catch (error) {
@@ -178,28 +175,17 @@ export class OpenAIKeyChecker {
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
}
private async getProvisionedModels(
key: OpenAIKey
): Promise<{ turbo: boolean; gpt4: boolean }> {
const openai = new OpenAIApi(new Configuration({ apiKey: key.key }));
const models = (await openai.listModels()!).data.data;
const turbo = models.some(({ id }) => id.startsWith("gpt-3.5"));
const gpt4 = models.some(({ id }) => id.startsWith("gpt-4"));
return { turbo, gpt4 };
}
@@ -208,137 +194,82 @@ export class OpenAIKeyChecker {
GET_SUBSCRIPTION_URL,
{ headers: { Authorization: `Bearer ${key.key}` } }
);
return data;
}
private async getUsage(key: OpenAIKey) {
const querystring = OpenAIKeyChecker.getUsageQuerystring(key.isTrial);
const url = `${GET_USAGE_URL}?${querystring}`;
const { data } = await axios.get<GetUsageResponse>(url, {
headers: { Authorization: `Bearer ${key.key}` },
});
return parseFloat((data.total_usage / 100).toFixed(2));
}
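The billing usage endpoint reports `total_usage` in cents; the conversion in `getUsage` can be sketched in isolation (values hypothetical):

```typescript
// Mirrors the conversion in getUsage above: cents in, dollars out,
// rounded to two decimal places.
function centsToUsd(totalUsageCents: number): number {
  return parseFloat((totalUsageCents / 100).toFixed(2));
}

console.log(centsToUsd(12345)); // 123.45
console.log(centsToUsd(1234.567)); // 12.35
```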
private handleAxiosError(key: OpenAIKey, error: AxiosError) {
if (error.response && OpenAIKeyChecker.errorIsOpenAiError(error)) {
const { status, data } = error.response;
if (status === 401) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true });
} else if (status === 429 && data.error.type === "insufficient_quota") {
this.log.warn(
{ key: key.hash, isTrial: key.isTrial, error: data },
"Key is out of quota. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true });
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered API error while checking key."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
this.log.error(
{ key: key.hash, error },
"Network error while checking key; trying again later."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
/**
* Trial key usage reporting is inaccurate, so we need to run an actual
* completion to test them for liveness.
*/
private async assertCanGenerate(key: OpenAIKey): Promise<void> {
const openai = new OpenAIApi(new Configuration({ apiKey: key.key }));
// This will throw an AxiosError if the key is invalid or out of quota.
await openai.createChatCompletion({
model: "gpt-3.5-turbo",
messages: [{ role: "user", content: "Hello" }],
max_tokens: 1,
});
}
static getUsageQuerystring(isTrial: boolean) {
// For paid keys, the limit resets every month, so we can use the first day
// of the current month.
// For trial keys, the limit does not reset and we don't know when the key
// was created, so we use 99 days ago because that's as far back as the API
// will let us go.
// End date needs to be set to the beginning of the next day so that we get
// usage for the current day.
const today = new Date();
const startDate = isTrial
? new Date(today.getTime() - 99 * 24 * 60 * 60 * 1000)
: new Date(today.getFullYear(), today.getMonth(), 1);
const endDate = new Date(today.getTime() + 24 * 60 * 60 * 1000);
return `start_date=${startDate.toISOString().split("T")[0]}&end_date=${
endDate.toISOString().split("T")[0]
}`;
}
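The window math in `getUsageQuerystring` can be exercised against a fixed reference date instead of the real clock. A sketch (note the paid-key start date uses the local timezone, so only the UTC-safe trial path is shown with expected values):

```typescript
// Recomputes the usage window from a fixed "today", mirroring the logic
// above: trial keys look back 99 days, paid keys start from the 1st of
// the month (local time); the end date is the start of tomorrow.
function usageWindow(today: Date, isTrial: boolean) {
  const start = isTrial
    ? new Date(today.getTime() - 99 * 24 * 60 * 60 * 1000)
    : new Date(today.getFullYear(), today.getMonth(), 1);
  const end = new Date(today.getTime() + 24 * 60 * 60 * 1000);
  return {
    start: start.toISOString().split("T")[0],
    end: end.toISOString().split("T")[0],
  };
}

const trial = usageWindow(new Date("2023-07-05T12:00:00Z"), true);
console.log(trial.start); // 2023-03-28
console.log(trial.end); // 2023-07-06
```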
static errorIsOpenAiError(
error: AxiosError
): error is AxiosError<OpenAIError> {
const data = error.response?.data as any;
+37 -25
View File
@@ -18,10 +18,8 @@ export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
export interface OpenAIKey extends Key {
readonly service: "openai";
/** The current usage of this key. */
usage: number;
/** Threshold at which a warning email will be sent by OpenAI. */
softLimit: number;
/** Threshold at which the key will be disabled because it has reached the user-defined limit. */
@@ -56,7 +54,7 @@ export interface OpenAIKey extends Key {
export type OpenAIKeyUpdate = Omit<
Partial<OpenAIKey>,
"key" | "hash" | "lastUsed" | "lastChecked" | "promptCount"
>;
export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
@@ -79,11 +77,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
const newKey = {
key: k,
service: "openai" as const,
isGpt4: false,
isTrial: false,
isDisabled: false,
softLimit: 0,
hardLimit: 0,
systemHardLimit: 0,
@@ -132,17 +128,11 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
);
if (availableKeys.length === 0) {
let message = needGpt4
? "No GPT-4 keys available. Try selecting a non-GPT-4 model."
: "No active OpenAI keys available.";
throw new Error(message);
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. We ignore rate limits from over a minute ago
@@ -187,7 +177,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
/** Called by the key checker to update key information. */
public update(keyHash: string, update: OpenAIKeyUpdate) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { ...update, lastChecked: Date.now() });
// this.writeKeyStatus();
}
@@ -196,6 +186,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
const keyFromPool = this.keys.find((k) => k.key === key.key);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
// If it's disabled just set the usage to the hard limit so it doesn't
// mess with the aggregate usage.
keyFromPool.usage = keyFromPool.hardLimit;
this.log.warn({ key: key.hash }, "Key disabled");
}
@@ -302,16 +295,35 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
}
}
/** Returns the remaining aggregate quota for all keys as a percentage. */
public remainingQuota({ gpt4 }: { gpt4: boolean } = { gpt4: false }): number {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return 0;
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
const totalLimit = keys.reduce((acc, { hardLimit }) => acc + hardLimit, 0);
return 1 - totalUsage / totalLimit;
}
/** Returns used and available usage in USD. */
public usageInUsd({ gpt4 }: { gpt4: boolean } = { gpt4: false }): string {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return "???";
const totalHardLimit = keys.reduce(
(acc, { hardLimit }) => acc + hardLimit,
0
);
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`;
}
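The aggregate math shared by the two methods above can be sketched standalone with hypothetical key data:

```typescript
// Minimal model of the aggregation above: usage is clamped to each key's
// hard limit (keys can slightly exceed their quota), then summed across
// the pool.
type QuotaKey = { usage: number; hardLimit: number };

function poolRemainingQuota(keys: QuotaKey[]): number {
  if (keys.length === 0) return 0;
  const used = keys.reduce((acc, k) => acc + Math.min(k.usage, k.hardLimit), 0);
  const limit = keys.reduce((acc, k) => acc + k.hardLimit, 0);
  return 1 - used / limit;
}

// Two $120 keys with $60 spent in total: 75% of the pool's quota remains.
console.log(poolRemainingQuota([
  { usage: 40, hardLimit: 120 },
  { usage: 20, hardLimit: 120 },
])); // 0.75
```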
/** Writes key status to disk. */
+155
View File
@@ -0,0 +1,155 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../config";
import { logger } from "../../logger";
export interface ScaleDeployment extends Key {
readonly service: "scale";
deploymentUrl: string;
createdAt: number;
}
/*
Scale is a bit different from the other providers. It doesn't have set API keys;
instead there are "deployments", which are created in the Scale dashboard and
are accessible via a URL and API key together.
The operator can provide these accounts via the SCALE_KEY environment variable,
but more likely they will want the proxy to just automatically create new
accounts and deployments as older ones reach their usage limits.
*/
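Judging by the constructor below, each SCALE_KEY entry pairs an API key with its deployment URL using a `$` separator; a sketch of that parsing (values hypothetical):

```typescript
// Mirrors the constructor's handling of SCALE_KEY: comma-separated,
// deduplicated entries of the form "key$deploymentUrl".
function parseScaleKeys(raw: string): { key: string; deploymentUrl: string }[] {
  const entries = [...new Set(raw.split(",").map((e) => e.trim()))];
  return entries.map((entry) => {
    const [key, deploymentUrl] = entry.split("$");
    return { key, deploymentUrl };
  });
}

const parsed = parseScaleKeys(
  "abc123$https://example.com/deploy/1, def456$https://example.com/deploy/2"
);
console.log(parsed.length); // 2
console.log(parsed[0].deploymentUrl); // https://example.com/deploy/1
```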
export class ScaleKeyProvider implements KeyProvider<ScaleDeployment> {
readonly service = "scale";
private deployments: ScaleDeployment[] = [];
private log = logger.child({ module: "key-provider", service: this.service });
private churnerEnabled = false;
constructor() {
const keyConfig = config.scaleKey?.trim();
if (!keyConfig) return;
const initialKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const keyStr of initialKeys) {
const [key, deploymentUrl] = keyStr.split("$");
const newDeployment: ScaleDeployment = {
key,
deploymentUrl,
service: this.service,
isGpt4: false,
isTrial: false,
isDisabled: false,
promptCount: 0,
lastUsed: 0,
createdAt: Date.now(),
hash: `sca-${crypto
.createHash("sha256")
.update(keyStr)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
};
this.deployments.push(newDeployment);
}
this.log.info(
{ keyCount: this.deployments.length },
"Loaded initial Scale deployments"
);
}
public init() {
// TODO: Start account churner
this.churnerEnabled = true;
}
public list() {
return this.deployments.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: unknown) {
// Scale doesn't support changing models on the fly
const availableDeployments = this.deployments.filter((a) => !a.isDisabled);
const canCreateNewAccounts = config.scaleMinDeployments > 0;
if (availableDeployments.length === 0) {
if (canCreateNewAccounts) {
this.log.warn(
"Ran out of Scale deployments and the churner is not creating new ones fast enough."
);
throw new Error(
"No Scale deployments available. Try again in a few minutes when the churner has created new deployments."
);
} else {
throw new Error(
"No Scale deployments available and account churner is disabled (possible IP ban or signup rate limit)."
);
}
}
// Unlike other providers, Scale doesn't want to rotate keys. Instead, we
// want to use the same key for as long as possible while building up a
// reserve of new accounts. Once an account dies there should be a fresh
// one ready to go.
const now = Date.now();
const deploymentsByPriority = availableDeployments.sort((a, b) => {
return a.createdAt - b.createdAt;
});
const selectedKey = deploymentsByPriority[0];
selectedKey.lastUsed = now;
return { ...selectedKey };
}
public disable(deployment: ScaleDeployment) {
const deploymentFromPool = this.deployments.find(
(d) => d.hash === deployment.hash
);
if (!deploymentFromPool || deploymentFromPool.isDisabled) return;
deploymentFromPool.isDisabled = true;
this.log.warn({ key: deployment.hash }, "Scale deployment disabled");
}
public update(hash: string, update: Partial<ScaleDeployment>) {
const deploymentFromPool = this.deployments.find((d) => d.hash === hash)!;
Object.assign(deploymentFromPool, update);
}
public available() {
return this.deployments.filter((k) => !k.isDisabled).length;
}
// Normally this would return the number of unchecked keys but we will
// repurpose it to return the number of pending accounts the churner is
// creating.
public anyUnchecked() {
return config.scaleMinDeployments - this.available() > 0;
}
public incrementPrompt(hash?: string) {
const deployment = this.deployments.find((d) => d.hash === hash);
if (!deployment) return;
deployment.promptCount++;
}
public getLockoutPeriod(_model: unknown) {
// TODO: Scale doesn't have rate limits but this may need to be repurposed
// to lock out the request queue if the account churner enabled but falling
// behind.
return 0;
}
public markRateLimited(keyHash: string) {
// Do nothing
}
/** Doesn't really mean anything for Scale */
public remainingQuota() {
return 1;
}
public usageInUsd() {
return "$0.00 / ∞";
}
}
-167
View File
@@ -1,167 +0,0 @@
/**
* Very scuffed persistence system using a Huggingface's Datasets git repo as a
* file system. We use this because it's free and everyone is already deploying
* to Huggingface's Spaces feature anyway, so they can easily create a Dataset
* repository too rather than having to find some other place to host files.
*
* We periodically commit to the repo, and then pull from it when we need to
* read data. This is a bit slow, but it's fine for our purposes.
*/
import fs from "fs";
import os from "os";
import path from "path";
import { spawn } from "child_process";
import { config, Config } from "./config";
import { logger } from "./logger";
const log = logger.child({ module: "dataset-persistence" });
let singleton: DatasetPersistence | null = null;
class DatasetPersistence {
private initialized: boolean = false;
private keyPath = `${os.tmpdir()}/id_rsa`;
private repoPath = `${os.tmpdir()}/oai-proxy-dataset`;
private repoUrl!: string;
private sshKey!: string;
constructor() {
if (singleton) return singleton;
if (config.gatekeeperStore !== "huggingface_datasets") return;
DatasetPersistence.assertConfigured(config);
this.repoUrl = config.hfDatasetRepoUrl;
this.sshKey = config.hfPrivateSshKey.trim();
singleton = this;
}
async init() {
if (this.initialized) return;
log.info(
{ repoUrl: this.repoUrl, keyPath: this.keyPath, repoPath: this.repoPath },
"Initializing Huggingface Datasets persistence."
);
try {
await this.setupSshKey();
await this.runGit(
"config user.email 'oai-proxy-persistence@example.com'"
);
await this.runGit("config user.name 'Proxy Persistence'");
log.info("Cloning repo...");
const cloneOutput = await this.runGit(
`clone --depth 1 ${this.repoUrl} ${this.repoPath}`
);
log.info({ output: cloneOutput.toString() }, "Cloned repo.");
// Test write access
const pushOutput = await this.runGit("push");
if (pushOutput !== "Everything up-to-date") {
log.error({ output: pushOutput }, "Unexpected output from git push.");
throw new Error("Unable to push to repo.");
}
log.info("Datasets configuration looks good.");
} catch (e) {
log.error(
{ error: e },
"Failed to initialize Huggingface Datasets persistence."
);
throw e;
}
this.initialized = true;
}
async get(key: string): Promise<Buffer | null> {
try {
await this.init();
await this.runGit(`checkout HEAD -- ${key}`);
const filePath = path.join(this.repoPath, key);
return fs.promises.readFile(filePath);
} catch (e) {
log.error({ error: e }, "Failed to get key from Dataset repo.");
return null;
}
}
async set(key: string, value: Buffer) {
try {
await this.init();
await fs.promises.writeFile(`${this.repoPath}/${key}`, value);
// TODO: Need to set up LFS for >10MB files
if (fs.statSync(`${this.repoPath}/${key}`).size > 10 * 1024 * 1024) {
throw new Error("File too large for non-LFS storage.");
}
await this.runGit(`add ${key}`);
await this.runGit(`commit -m "Update ${key}"`);
await this.runGit("push");
} catch (e) {
log.error({ error: e }, "Failed to set key in Dataset repo.");
}
}
protected async cleanup() {
try {
await this.init();
await this.runGit("fetch --depth 1");
await this.runGit("reset --hard FETCH_HEAD");
} catch (e) {
log.error({ error: e }, "Failed to cleanup Dataset repo.");
}
}
protected async setupSshKey() {
fs.writeFileSync(this.keyPath, this.sshKey);
fs.chmodSync(this.keyPath, 0o600);
await this.runGit(`config core.sshCommand 'ssh -i ${this.keyPath}'`);
}
protected async runGit(command: string) {
const cmd = `git -C ${this.repoPath} ${command}`;
log.debug({ command: cmd }, "Running git command.");
return new Promise<string>((resolve, reject) => {
const proc = spawn(cmd, { shell: true });
const stdout: string[] = [];
const stderr: string[] = [];
proc.stdout.on("data", (data) => stdout.push(data.toString()));
proc.stderr.on("data", (data) => stderr.push(data.toString()));
proc.on("close", (code) => {
if (code !== 0) {
const errorOutput = stderr.join("");
log.error({ code, errorOutput }, "Git command failed.");
reject(
new Error(
`Git command failed with exit code ${code}: ${errorOutput}`
)
);
} else {
resolve(stdout.join(""));
}
});
});
}
static assertConfigured(input: Config): asserts input is ConfigWithDatasets {
if (!input.hfDatasetRepoUrl) {
throw new Error("HF_DATASET_REPO_URL is required when using Datasets.");
}
if (!input.hfPrivateSshKey) {
throw new Error("HF_PRIVATE_SSH_KEY is required when using Datasets.");
}
}
}
type ConfigWithDatasets = Config & {
hfDatasetRepoUrl: string;
hfPrivateSshKey: string;
};
export { DatasetPersistence };
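The promise-wrapped `spawn` used by `runGit` generalizes to a small helper. The sketch below (the `runShell` name is ours, not the repo's) shows the pattern: buffer stdout and stderr as they arrive, then resolve or reject based on the exit code.

```typescript
import { spawn } from "child_process";

// Runs a shell command, buffering stdout; rejects with stderr on a non-zero
// exit code. Mirrors the runGit() wrapper above, minus the `git -C` prefix.
function runShell(command: string): Promise<string> {
  return new Promise((resolve, reject) => {
    const proc = spawn(command, { shell: true });
    const stdout: string[] = [];
    const stderr: string[] = [];
    proc.stdout.on("data", (data) => stdout.push(data.toString()));
    proc.stderr.on("data", (data) => stderr.push(data.toString()));
    proc.on("close", (code) => {
      if (code !== 0) {
        reject(new Error(`Command failed (${code}): ${stderr.join("")}`));
      } else {
        resolve(stdout.join(""));
      }
    });
  });
}
```

Usage would look like `await runShell("git status --porcelain")`; since `shell: true` passes the whole string to the shell, callers must not interpolate untrusted input into `command`.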
+3 -3
@@ -256,9 +256,9 @@ export const appendBatch = async (batch: PromptLogEntry[]) => {
return [
entry.model,
entry.endpoint,
entry.promptRaw.slice(0, 50000),
entry.promptFlattened.slice(0, 50000),
entry.response.slice(0, 50000),
entry.promptRaw,
entry.promptFlattened,
entry.response,
];
});
log.info({ sheetName, rowCount: newRows.length }, "Appending log batch.");
+3 -9
@@ -13,6 +13,7 @@ import {
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
limitOutputTokens,
removeOriginHeaders,
} from "./middleware/request";
import {
@@ -42,8 +43,6 @@ const getModelsResponse = () => {
"claude-instant-v1.1",
"claude-instant-v1.1-100k",
"claude-instant-v1.0",
"claude-2", // claude-2 is 100k by default it seems
"claude-2.0",
];
const models = claudeVariants.map((id) => ({
@@ -75,6 +74,7 @@ const rewriteAnthropicRequest = (
addKey,
addAnthropicPreamble,
languageFilter,
limitOutputTokens,
blockZoomerOrigins,
removeOriginHeaders,
finalizeBody,
@@ -106,16 +106,10 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
if (!req.originalUrl.includes("/v1/complete")) {
req.log.info("Transforming Anthropic response to OpenAI format");
body = transformAnthropicResponse(body);
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
}
res.status(200).json(body);
};
-64
@@ -1,64 +0,0 @@
/**
* Authenticates RisuAI.xyz users using a special x-risu-tk header provided by
* RisuAI.xyz. This lets us rate limit and limit queue concurrency properly,
* since otherwise RisuAI.xyz users share the same IP address and can't be
* distinguished.
* Contributors: @kwaroran
*/
import axios from "axios";
import { Request, Response, NextFunction } from "express";
const RISUAI_TOKEN_CHECKER_URL = "https://sv.risuai.xyz/public/api/checktoken";
const validRisuTokens = new Set<string>();
let lastFailedRisuTokenCheck = 0;
export async function checkRisuToken(
req: Request,
_res: Response,
next: NextFunction
) {
let header = req.header("x-risu-tk") || null;
if (!header) {
return next();
}
const timeSinceLastFailedCheck = Date.now() - lastFailedRisuTokenCheck;
if (timeSinceLastFailedCheck < 60 * 1000) {
req.log.warn(
{ timeSinceLastFailedCheck },
"Skipping RisuAI token check due to recent failed check"
);
return next();
}
try {
if (!validRisuTokens.has(header)) {
req.log.info("Authenticating new RisuAI token");
// (sic: the upstream API really spells the field "vaild")
const validCheck = await axios.post<{ vaild: boolean }>(
RISUAI_TOKEN_CHECKER_URL,
{ token: header },
{ headers: { "Content-Type": "application/json" } }
);
if (!validCheck.data.vaild) {
req.log.warn("Invalid RisuAI token; using IP instead");
} else {
req.log.info("RisuAI token authenticated");
validRisuTokens.add(header);
req.risuToken = header;
}
} else {
req.log.debug("RisuAI token already known");
req.risuToken = header;
}
} catch (err) {
lastFailedRisuTokenCheck = Date.now();
req.log.warn(
{ error: err.message },
"Error authenticating RisuAI token; using IP instead"
);
}
next();
}
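The 60-second back-off after a failed upstream check can be isolated as a pure predicate. This sketch (the `shouldCheckToken` and `CHECK_COOLDOWN_MS` names are ours; the clock is injected for testability) mirrors the `lastFailedRisuTokenCheck` logic above:

```typescript
const CHECK_COOLDOWN_MS = 60 * 1000;

// Decides whether to re-attempt an upstream token check given the timestamp
// of the last failed attempt. Within the cooldown window, the middleware
// skips the check and falls back to IP-based identification.
function shouldCheckToken(lastFailedCheckMs: number, now: number): boolean {
  return now - lastFailedCheckMs >= CHECK_COOLDOWN_MS;
}
```

Injecting `now` instead of calling `Date.now()` inside the function keeps the cooldown logic deterministic under test.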
+1 -1
@@ -33,7 +33,7 @@ export const gatekeeper: RequestHandler = (req, res, next) => {
// TODO: Generate anonymous users based on IP address for public or proxy_key
// modes so that all middleware can assume a user of some sort is present.
if (ADMIN_KEY && token === ADMIN_KEY) {
if (token === ADMIN_KEY) {
return next();
}
+8
@@ -13,6 +13,7 @@ import {
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
limitOutputTokens,
transformKoboldPayload,
} from "./middleware/request";
import {
@@ -33,11 +34,18 @@ const rewriteRequest = (
req: Request,
res: Response
) => {
if (config.queueMode !== "none") {
const msg = `Queueing is enabled on this proxy instance and is incompatible with the KoboldAI endpoint. Use the OpenAI endpoint instead.`;
proxyReq.destroy(new Error(msg));
return;
}
req.body.stream = false;
const rewriterPipeline = [
addKey,
transformKoboldPayload,
languageFilter,
limitOutputTokens,
finalizeBody,
];
+2 -5
@@ -21,7 +21,7 @@ export function writeErrorResponse(
statusCode: number,
errorPayload: Record<string, any>
) {
const errorSource = errorPayload.error?.type?.startsWith("proxy")
const errorSource = errorPayload.error?.type.startsWith("proxy")
? "proxy"
: "upstream";
@@ -45,9 +45,6 @@ export function writeErrorResponse(
res.write(`data: [DONE]\n\n`);
res.end();
} else {
if (req.debug) {
errorPayload.error.proxy_tokenizer_debug_info = req.debug;
}
res.status(statusCode).json(errorPayload);
}
}
@@ -89,7 +86,7 @@ export const handleInternalError = (
} else {
writeErrorResponse(req, res, 500, {
error: {
type: "proxy_internal_error",
type: "proxy_rewriter_error",
proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
message: err.message,
stack: err.stack,
+2
@@ -41,6 +41,8 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
// For such cases, ignore the requested model entirely.
if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
req.log.debug("Using an Anthropic key for an OpenAI-compatible request");
// We don't assign the model here, that will happen when transforming the
// request body.
assignedKey = keyPool.get("claude-v1");
} else {
assignedKey = keyPool.get(req.body.model);
@@ -1,129 +0,0 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { countTokens } from "../../../tokenization";
import { RequestPreprocessor } from ".";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
/**
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
* and outbound API format, which combined determine the size of the context.
* If the context is too large, an error is thrown.
* This preprocessor should run after any preprocessor that transforms the
* request body.
*/
export const checkContextSize: RequestPreprocessor = async (req) => {
let prompt;
switch (req.outboundApi) {
case "openai":
req.outputTokens = req.body.max_tokens;
prompt = req.body.messages;
break;
case "anthropic":
req.outputTokens = req.body.max_tokens_to_sample;
prompt = req.body.prompt;
break;
default:
throw new Error(`Unknown outbound API: ${req.outboundApi}`);
}
const result = await countTokens({ req, prompt, service: req.outboundApi });
req.promptTokens = result.token_count;
// TODO: Remove once token counting is stable
req.log.debug({ result: result }, "Counted prompt tokens.");
req.debug = req.debug ?? {};
req.debug = { ...req.debug, ...result };
maybeReassignModel(req);
validateContextSize(req);
};
function validateContextSize(req: Request) {
assertRequestHasTokenCounts(req);
const promptTokens = req.promptTokens;
const outputTokens = req.outputTokens;
const contextTokens = promptTokens + outputTokens;
const model = req.body.model;
const proxyMax =
(req.outboundApi === "openai" ? OPENAI_MAX_CONTEXT : CLAUDE_MAX_CONTEXT) ||
Number.MAX_SAFE_INTEGER;
let modelMax = 0;
if (model.match(/gpt-3.5-turbo-16k/)) {
modelMax = 16384;
} else if (model.match(/gpt-3.5-turbo/)) {
modelMax = 4096;
} else if (model.match(/gpt-4-32k/)) {
modelMax = 32768;
} else if (model.match(/gpt-4/)) {
modelMax = 8192;
} else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?(?:-100k)/)) {
modelMax = 100000;
} else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?$/)) {
modelMax = 9000;
} else if (model.match(/claude-2/)) {
modelMax = 100000;
} else {
// We don't throw for unrecognized models so that this list doesn't need an
// immediate update every time a new model is released.
req.log.warn({ model }, "Unknown model, using 100k token limit.");
modelMax = 100000;
}
const finalMax = Math.min(proxyMax, modelMax);
z.number()
.int()
.max(finalMax, {
message: `Your request exceeds the context size limit for this model or proxy. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
})
.parse(contextTokens);
req.log.debug(
{ promptTokens, outputTokens, contextTokens, modelMax, proxyMax },
"Prompt size validated"
);
req.debug.prompt_tokens = promptTokens;
req.debug.max_model_tokens = modelMax;
req.debug.max_proxy_tokens = proxyMax;
}
function assertRequestHasTokenCounts(
req: Request
): asserts req is Request & { promptTokens: number; outputTokens: number } {
z.object({
promptTokens: z.number().int().min(1),
outputTokens: z.number().int().min(1),
})
.nonstrict()
.parse(req);
}
/**
* For OpenAI-to-Anthropic requests, users can't specify the model, so we need
* to pick one based on the final context size. Ideally this would happen in
* the `transformOutboundPayload` preprocessor, but we don't have the context
* size at that point (and need a transformed body to calculate it).
*/
function maybeReassignModel(req: Request) {
if (req.inboundApi !== "openai" || req.outboundApi !== "anthropic") {
return;
}
const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k";
const contextSize = req.promptTokens! + req.outputTokens!;
if (contextSize > 8500) {
req.log.debug(
{ model: bigModel, contextSize },
"Using Claude 100k model for OpenAI-to-Anthropic request"
);
req.body.model = bigModel;
}
// Small model is the default already set in `transformOutboundPayload`
}
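The regex chain in `validateContextSize` amounts to a model-name-to-context-window lookup. It can be sketched as a standalone function (the `getModelContextLimit` name is ours; the patterns and limits are taken from the code above):

```typescript
// Maps a model name to its context window, mirroring the regex chain in
// validateContextSize. Order matters: the more specific patterns (16k, 32k,
// 100k) must be checked before their base-model counterparts. Unknown models
// fall back to 100k rather than throwing, so new releases don't immediately
// break the proxy.
function getModelContextLimit(model: string): number {
  if (/gpt-3.5-turbo-16k/.test(model)) return 16384;
  if (/gpt-3.5-turbo/.test(model)) return 4096;
  if (/gpt-4-32k/.test(model)) return 32768;
  if (/gpt-4/.test(model)) return 8192;
  if (/claude-(?:instant-)?v1(?:\.\d)?(?:-100k)/.test(model)) return 100000;
  if (/claude-(?:instant-)?v1(?:\.\d)?$/.test(model)) return 9000;
  if (/claude-2/.test(model)) return 100000;
  return 100000; // unknown model: warn and assume the largest window
}
```

The effective limit is then `Math.min(getModelContextLimit(model), proxyMax)`, as in the code above.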
+1 -1
@@ -4,7 +4,6 @@ import type { ProxyReqCallback } from "http-proxy";
// Express middleware (runs before http-proxy-middleware, can be async)
export { createPreprocessorMiddleware } from "./preprocess";
export { checkContextSize } from "./check-context-size";
export { setApiFormat } from "./set-api-format";
export { transformOutboundPayload } from "./transform-outbound-payload";
@@ -15,6 +14,7 @@ export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { limitOutputTokens } from "./limit-output-tokens";
export { removeOriginHeaders } from "./remove-origin-headers";
export { transformKoboldPayload } from "./transform-kobold-payload";
@@ -0,0 +1,46 @@
import { Request } from "express";
import { config } from "../../../config";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/** Enforce a maximum number of tokens requested from the model. */
export const limitOutputTokens: ProxyRequestMiddleware = (_proxyReq, req) => {
// TODO: move all of this into the zod validator
if (isCompletionRequest(req)) {
const requestedMax = Number.parseInt(getMaxTokensFromRequest(req));
const apiMax =
req.outboundApi === "openai"
? config.maxOutputTokensOpenAI
: config.maxOutputTokensAnthropic;
let maxTokens = requestedMax;
// parseInt returns NaN (which is still typeof "number") for missing or
// invalid input, so check for NaN rather than typeof.
if (Number.isNaN(requestedMax)) {
maxTokens = apiMax;
}
maxTokens = Math.min(maxTokens, apiMax);
if (req.outboundApi === "openai") {
req.body.max_tokens = maxTokens;
} else if (req.outboundApi === "anthropic") {
req.body.max_tokens_to_sample = maxTokens;
}
if (requestedMax !== maxTokens) {
req.log.info(
{ requestedMax, configMax: apiMax, final: maxTokens },
"Limiting user's requested max output tokens"
);
}
}
};
function getMaxTokensFromRequest(req: Request) {
switch (req.outboundApi) {
case "anthropic":
return req.body?.max_tokens_to_sample;
case "openai":
return req.body?.max_tokens;
default:
throw new Error(`Unknown service: ${req.outboundApi}`);
}
}
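The clamping performed by `limitOutputTokens` reduces to: parse the requested value, fall back to the configured ceiling if it isn't a number, and never exceed that ceiling. A minimal sketch (the `clampMaxTokens` name is illustrative):

```typescript
// Clamps a user-requested max token count to the proxy's configured ceiling.
// A missing or unparseable value falls back to the configured maximum, as in
// limitOutputTokens above.
function clampMaxTokens(requested: unknown, configMax: number): number {
  const parsed = Number.parseInt(String(requested), 10);
  return Number.isNaN(parsed) ? configMax : Math.min(parsed, configMax);
}
```

The result is written back to `max_tokens` (OpenAI) or `max_tokens_to_sample` (Anthropic) depending on the outbound API.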
+1 -7
@@ -1,11 +1,6 @@
import { RequestHandler } from "express";
import { handleInternalError } from "../common";
import {
RequestPreprocessor,
checkContextSize,
setApiFormat,
transformOutboundPayload,
} from ".";
import { RequestPreprocessor, setApiFormat, transformOutboundPayload } from ".";
/**
* Returns a middleware function that processes the request body into the given
@@ -18,7 +13,6 @@ export const createPreprocessorMiddleware = (
const preprocessors: RequestPreprocessor[] = [
setApiFormat(apiFormat),
transformOutboundPayload,
checkContextSize,
...(additionalPreprocessors ?? []),
];
@@ -1,12 +1,8 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../tokenization";
import { isCompletionRequest } from "../common";
import { RequestPreprocessor } from ".";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// import { countTokens } from "../../../tokenization";
// https://console.anthropic.com/docs/api/reference#-v1-complete
const AnthropicV1CompleteSchema = z.object({
@@ -15,10 +11,7 @@ const AnthropicV1CompleteSchema = z.object({
required_error:
"No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
}),
max_tokens_to_sample: z.coerce
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
max_tokens_to_sample: z.coerce.number(),
stop_sequences: z.array(z.string()).optional(),
stream: z.boolean().optional().default(false),
temperature: z.coerce.number().optional().default(1),
@@ -39,8 +32,6 @@ const OpenAIV1ChatCompletionSchema = z.object({
{
required_error:
"No prompt found. Are you sending an Anthropic-formatted request to the OpenAI endpoint?",
invalid_type_error:
"Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
}
),
temperature: z.number().optional().default(1),
@@ -54,12 +45,7 @@ const OpenAIV1ChatCompletionSchema = z.object({
.optional(),
stream: z.boolean().optional().default(false),
stop: z.union([z.string(), z.array(z.string())]).optional(),
max_tokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, OPENAI_OUTPUT_MAX)),
max_tokens: z.coerce.number().optional(),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
logit_bias: z.any().optional(),
@@ -77,6 +63,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
}
if (sameService) {
// Just validate, don't transform.
const validator =
req.outboundApi === "openai"
? OpenAIV1ChatCompletionSchema
@@ -89,12 +76,11 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
);
throw result.error;
}
req.body = result.data;
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
req.body = await openaiToAnthropic(req.body, req);
req.body = openaiToAnthropic(req.body, req);
return;
}
@@ -103,7 +89,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
);
};
async function openaiToAnthropic(body: any, req: Request) {
function openaiToAnthropic(body: any, req: Request) {
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.error(
@@ -121,7 +107,45 @@ async function openaiToAnthropic(body: any, req: Request) {
req.headers["anthropic-version"] = "2023-01-01";
const { messages, ...rest } = result.data;
const prompt = openAIMessagesToClaudePrompt(messages);
const prompt =
result.data.messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
m.content
}`;
})
.join("") + "\n\nAssistant: ";
// Claude 1.2 has been selected as the default for smaller prompts because it
// is said to be less heavily filtered than the newer 1.3 model. But this is not based
// on any empirical testing, just speculation based on Anthropic stating that
// 1.3 is "safer and less susceptible to adversarial attacks" than 1.2.
// From my own interactions, both are pretty easy to jailbreak so I don't
// think there's much of a difference, honestly.
// If you want to override the model selection, you can set the
// CLAUDE_BIG_MODEL and CLAUDE_SMALL_MODEL environment variables in your
// .env file.
// Using "v1" of a model will automatically select the latest version of that
// model on the Anthropic side.
const CLAUDE_BIG = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k";
const CLAUDE_SMALL = process.env.CLAUDE_SMALL_MODEL || "claude-v1.2";
// TODO: Finish implementing tokenizer for more accurate model selection.
// This currently uses _character count_, not token count.
const model = prompt.length > 25000 ? CLAUDE_BIG : CLAUDE_SMALL;
let stops = rest.stop
? Array.isArray(rest.stop)
@@ -138,35 +162,9 @@ async function openaiToAnthropic(body: any, req: Request) {
return {
...rest,
// Model may be overridden in `calculate-context-size.ts` to avoid having
// a circular dependency (`calculate-context-size.ts` needs an already-
// transformed request body to count tokens, but this function would like
// to know the count to select a model).
model: process.env.CLAUDE_SMALL_MODEL || "claude-v1",
model,
prompt: prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: stops,
};
}
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
return (
messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
m.content
}`;
})
.join("") + "\n\nAssistant:"
);
}
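The message-flattening logic can be exercised in isolation. This sketch mirrors `openAIMessagesToClaudePrompt` above, with a local `Message` type standing in for `OpenAIPromptMessage`:

```typescript
type Message = { role: string; content: string; name?: string };

// Flattens OpenAI chat messages into a single Claude-style prompt string.
// Anthropic doesn't support `name`, so it's folded into the message text,
// and the prompt ends with an empty Assistant turn for the model to complete.
function messagesToClaudePrompt(messages: Message[]): string {
  const roleMap: Record<string, string> = {
    assistant: "Assistant",
    system: "System",
    user: "Human",
  };
  return (
    messages
      .map((m) => {
        const role = roleMap[m.role] ?? m.role;
        const name = m.name?.trim() ? `(as ${m.name}) ` : "";
        return `\n\n${role}: ${name}${m.content}`;
      })
      .join("") + "\n\nAssistant:"
  );
}
```

For example, a system message plus a user message produces `"\n\nSystem: ...\n\nHuman: ...\n\nAssistant:"`.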
@@ -282,7 +282,7 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
* the final SSE event before the "DONE" event, so we can reuse that
*/
const lastEvent = events[events.length - 2].toString();
const data = JSON.parse(lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length));
const data = JSON.parse(lastEvent.slice("data: ".length));
const response: AnthropicCompletionResponse = {
...data,
log_id: req.id,
+20 -16
@@ -269,7 +269,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
}
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!, "revoked");
keyPool.disable(req.key!);
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 429) {
// OpenAI uses this for a bunch of different rate-limiting scenarios.
@@ -341,8 +341,11 @@ function maybeHandleMissingPreambleError(
"Request failed due to missing preamble. Key will be marked as such for subsequent requests."
);
keyPool.update(req.key!, { requiresPreamble: true });
reenqueueRequest(req);
throw new RetryableError("Claude request re-enqueued to add preamble.");
if (config.queueMode !== "none") {
reenqueueRequest(req);
throw new RetryableError("Claude request re-enqueued to add preamble.");
}
errorPayload.proxy_note = `This Claude key requires special prompt formatting. Try again; the proxy will reformat your prompt next time.`;
} else {
errorPayload.proxy_note = `Proxy received unrecognized error from Anthropic. Check the specific error for more information.`;
}
@@ -354,8 +357,11 @@ function handleAnthropicRateLimitError(
) {
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("Claude rate-limited request re-enqueued.");
if (config.queueMode !== "none") {
reenqueueRequest(req);
throw new RetryableError("Claude rate-limited request re-enqueued.");
}
errorPayload.proxy_note = `There are too many in-flight requests for this key. Try again later.`;
} else {
errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic. Key may be over quota.`;
}
@@ -369,24 +375,22 @@ function handleOpenAIRateLimitError(
const type = errorPayload.error?.type;
if (type === "insufficient_quota") {
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
keyPool.disable(req.key!);
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
} else if (type === "access_terminated") {
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
} else if (type === "billing_not_active") {
// Billing is not active (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
keyPool.disable(req.key!);
errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
} else if (type === "requests" || type === "tokens") {
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!);
// I'm aware this is confusing -- throwing this class of error will cause
// the proxy response handler to return without terminating the request,
// so that it can be placed back in the queue.
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
if (config.queueMode !== "none") {
reenqueueRequest(req);
// This is confusing, but it will bubble up to the top-level response
// handler and cause the request to go back into the request queue.
throw new RetryableError("Rate-limited request re-enqueued.");
}
errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`;
} else {
// OpenAI probably overloaded
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
+2 -5
@@ -14,6 +14,7 @@ import {
finalizeBody,
languageFilter,
limitCompletions,
limitOutputTokens,
removeOriginHeaders,
} from "./middleware/request";
import {
@@ -92,6 +93,7 @@ const rewriteRequest = (
const rewriterPipeline = [
addKey,
languageFilter,
limitOutputTokens,
limitCompletions,
blockZoomerOrigins,
removeOriginHeaders,
@@ -123,11 +125,6 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async (
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
}
res.status(200).json(body);
};
+31 -25
@@ -16,6 +16,7 @@
*/
import type { Handler, Request } from "express";
import { config, DequeueMode } from "../config";
import { keyPool, SupportedModel } from "../key-management";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
@@ -26,39 +27,31 @@ export type QueuePartition = "claude" | "turbo" | "gpt-4";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
let dequeueMode: DequeueMode = "fair";
/** Maximum number of queue slots for Agnai.chat requests. */
const AGNAI_CONCURRENCY_LIMIT = 15;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;
/**
* Returns a unique identifier for a request. This is used to determine if a
* request is already in the queue.
* This can be (in order of preference):
* - user token assigned by the proxy operator
* - x-risu-tk header, if the request is from RisuAI.xyz
* - IP address
*/
function getIdentifier(req: Request) {
if (req.user) {
return req.user.token;
}
if (req.risuToken) {
return req.risuToken;
}
return req.ip;
}
const sameIpPredicate = (incoming: Request) => (queued: Request) =>
queued.ip === incoming.ip;
const sameUserPredicate = (incoming: Request) => (queued: Request) => {
const queuedId = getIdentifier(queued);
const incomingId = getIdentifier(incoming);
return queuedId === incomingId;
const incomingUser = incoming.user ?? { token: incoming.ip };
const queuedUser = queued.user ?? { token: queued.ip };
return queuedUser.token === incomingUser.token;
};
export function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
let enqueuedRequestCount = 0;
let isGuest = req.user?.token === undefined;
if (isGuest) {
enqueuedRequestCount = queue.filter(sameIpPredicate(req)).length;
} else {
enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
}
// All Agnai.chat requests come from the same IP, so we allow them to have
// more spots in the queue. Can't make it unlimited because people will
// intentionally abuse it.
@@ -157,9 +150,18 @@ export function dequeue(partition: QueuePartition): Request | undefined {
return undefined;
}
const req = modelQueue.reduce((prev, curr) =>
prev.startTime < curr.startTime ? prev : curr
);
let req: Request;
if (dequeueMode === "fair") {
// Dequeue the request that has been waiting the longest
req = modelQueue.reduce((prev, curr) =>
prev.startTime < curr.startTime ? prev : curr
);
} else {
// Dequeue a random request
const index = Math.floor(Math.random() * modelQueue.length);
req = modelQueue[index];
}
queue.splice(queue.indexOf(req), 1);
if (req.onAborted) {
@@ -281,6 +283,10 @@ export function getQueueLength(partition: QueuePartition | "all" = "all") {
export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
return (req, res, next) => {
if (config.queueMode === "none") {
return proxyMiddleware(req, res, next);
}
req.proceed = () => {
proxyMiddleware(req, res, next);
};
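The two dequeue modes can be sketched as a pure selection function (the `pickNext` name and the reduced `Queued` shape are ours, standing in for the Express `Request`):

```typescript
type Queued = { id: string; startTime: number };

// Picks the next request from a queue partition. "fair" takes the
// longest-waiting request (smallest startTime); "random" picks uniformly,
// mirroring the dequeue() branch above.
function pickNext(queue: Queued[], mode: "fair" | "random"): Queued | undefined {
  if (queue.length === 0) return undefined;
  if (mode === "fair") {
    return queue.reduce((prev, curr) =>
      prev.startTime < curr.startTime ? prev : curr
    );
  }
  return queue[Math.floor(Math.random() * queue.length)];
}
```

The caller is still responsible for removing the chosen request from the queue (`queue.splice(queue.indexOf(req), 1)` in the code above).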
+3 -8
@@ -2,7 +2,6 @@ import { Request, Response, NextFunction } from "express";
import { config } from "../config";
export const AGNAI_DOT_CHAT_IP = "157.230.249.32";
const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;
@@ -53,11 +52,7 @@ export const getUniqueIps = () => {
return lastAttempts.size;
};
export const ipLimiter = async (
req: Request,
res: Response,
next: NextFunction
) => {
export const ipLimiter = (req: Request, res: Response, next: NextFunction) => {
if (!RATE_LIMIT_ENABLED) {
next();
return;
@@ -73,7 +68,7 @@ export const ipLimiter = async (
// If user is authenticated, key rate limiting by their token. Otherwise, key
// rate limiting by their IP address. Mitigates key sharing.
const rateLimitKey = req.user?.token || req.risuToken || req.ip;
const rateLimitKey = req.user?.token || req.ip;
const { remaining, reset } = getStatus(rateLimitKey);
res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
@@ -88,7 +83,7 @@ export const ipLimiter = async (
type: "proxy_rate_limited",
message: `This proxy is rate limited to ${
config.modelRateLimit
} prompts per minute. Please try again in ${Math.ceil(
} model requests per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
},
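The limiter keys attempts by user token or IP and counts only the last minute's worth. A minimal sliding-window sketch (names are illustrative; the clock is injected so the pruning is testable):

```typescript
const ONE_MINUTE_MS = 60 * 1000;

// Sliding-window limiter in the style of ipLimiter above: each key maps to
// the timestamps of its recent attempts; anything older than a minute is
// pruned before counting against the limit.
function getRemaining(
  attempts: Map<string, number[]>,
  key: string,
  limit: number,
  now: number
): number {
  const recent = (attempts.get(key) ?? []).filter(
    (t) => now - t < ONE_MINUTE_MS
  );
  attempts.set(key, recent); // prune in place so the map doesn't grow unbounded
  return Math.max(0, limit - recent.length);
}
```

When `getRemaining` hits zero the middleware would respond 429 with the retry time derived from the oldest timestamp still in the window.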
+7 -17
@@ -6,24 +6,14 @@ equivalent OpenAI requests. */
import * as express from "express";
import { gatekeeper } from "./auth/gatekeeper";
import { checkRisuToken } from "./auth/check-risu-token";
import { kobold } from "./kobold";
import { openai } from "./openai";
import { anthropic } from "./anthropic";
const proxyRouter = express.Router();
proxyRouter.use(
express.json({ limit: "1536kb" }),
express.urlencoded({ extended: true, limit: "1536kb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
proxyRouter.use((req, _res, next) => {
req.startTime = Date.now();
req.retryCount = 0;
next();
});
proxyRouter.use("/kobold", kobold);
proxyRouter.use("/openai", openai);
proxyRouter.use("/anthropic", anthropic);
export { proxyRouter as proxyRouter };
const router = express.Router();
router.use(gatekeeper);
router.use("/kobold", kobold);
router.use("/openai", openai);
router.use("/anthropic", anthropic);
export { router as proxyRouter };
+21 -13
@@ -2,7 +2,6 @@ import { assertConfigIsValid, config } from "./config";
import "source-map-support/register";
import express from "express";
import cors from "cors";
import path from "path";
import pinoHttp from "pino-http";
import childProcess from "child_process";
import { logger } from "./logger";
@@ -13,7 +12,6 @@ import { handleInfoPage } from "./info-page";
import { logQueue } from "./prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./proxy/auth/user-store";
import { init as initTokenizers } from "./tokenization";
import { checkOrigin } from "./proxy/check-origin";
const PORT = config.port;
@@ -36,6 +34,10 @@ app.use(
'res.headers["set-cookie"]',
"req.headers.authorization",
'req.headers["x-api-key"]',
'req.headers["x-forwarded-for"]',
'req.headers["x-real-ip"]',
'req.headers["true-client-ip"]',
'req.headers["cf-connecting-ip"]',
// Don't log the prompt text on transform errors
"body.messages",
"body.prompt",
@@ -45,19 +47,25 @@ app.use(
})
);
app.get("/health", (_req, res) => res.sendStatus(200));
app.use((req, _res, next) => {
req.startTime = Date.now();
req.retryCount = 0;
next();
});
app.use(cors());
app.use(
express.json({ limit: "10mb" }),
express.urlencoded({ extended: true, limit: "10mb" })
);
// TODO: Detect (or support manual configuration of) whether the app is behind
// a load balancer/reverse proxy, which is necessary to determine request IP
// addresses correctly.
app.set("trust proxy", true);
app.set("view engine", "ejs");
app.set("views", path.join(__dirname, "views"));
app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);
// routes
app.use(checkOrigin);
app.get("/", handleInfoPage);
app.use("/admin", adminRouter);
app.use("/proxy", proxyRouter);
@@ -91,8 +99,6 @@ async function start() {
keyPool.init();
await initTokenizers();
if (config.gatekeeper === "user_token") {
await initUserStore();
}
@@ -102,8 +108,10 @@ async function start() {
logQueue.start();
}
logger.info("Starting request queue...");
startRequestQueue();
if (config.queueMode !== "none") {
logger.info("Starting request queue...");
startRequestQueue();
}
app.listen(PORT, async () => {
logger.info({ port: PORT }, "Now listening for connections.");
-27
@@ -1,27 +0,0 @@
import { getTokenizer } from "@anthropic-ai/tokenizer";
import { Tiktoken } from "tiktoken/lite";
let encoder: Tiktoken;
export function init() {
// they export a `countTokens` function too but it instantiates a new
// tokenizer every single time and it is not fast...
encoder = getTokenizer();
return true;
}
export function getTokenCount(prompt: string, _model: string) {
// Don't try tokenizing if the prompt is massive to prevent DoS.
// 500k characters should be sufficient for all supported models.
if (prompt.length > 500000) {
return {
tokenizer: "length fallback",
token_count: 100000,
};
}
return {
tokenizer: "@anthropic-ai/tokenizer",
token_count: encoder.encode(prompt.normalize("NFKC"), "all").length,
};
}
-2
@@ -1,2 +0,0 @@
export { OpenAIPromptMessage } from "./openai";
export { init, countTokens } from "./tokenizer";
-58
@@ -1,58 +0,0 @@
import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
let encoder: Tiktoken;
export function init() {
encoder = new Tiktoken(
cl100k_base.bpe_ranks,
cl100k_base.special_tokens,
cl100k_base.pat_str
);
return true;
}
// Tested against:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
export function getTokenCount(messages: any[], model: string) {
const gpt4 = model.startsWith("gpt-4");
const tokensPerMessage = gpt4 ? 3 : 4;
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
let numTokens = 0;
for (const message of messages) {
numTokens += tokensPerMessage;
for (const key of Object.keys(message)) {
const value = message[key];
// Break if we get a huge message or exceed the token limit to prevent
// DoS.
// 100k tokens allows for future 100k GPT-4 models and 500k characters
// is just a sanity check
if (value.length > 500000 || numTokens > 100000) {
numTokens = 100000;
return {
tokenizer: "tiktoken (prompt length limit exceeded)",
token_count: numTokens,
};
}
numTokens += encoder.encode(message[key]).length;
if (key === "name") {
numTokens += tokensPerName;
}
}
}
numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
return { tokenizer: "tiktoken", token_count: numTokens };
}
export type OpenAIPromptMessage = {
name?: string;
content: string;
role: string;
};
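The per-message accounting above (`tokensPerMessage`, `tokensPerName`, the `+3` reply priming) can be exercised standalone. In this sketch the real tiktoken encoder is replaced by a hypothetical whitespace-splitting stub so the arithmetic is visible; real counts require the cl100k_base encoder:

```typescript
// Sketch of the ChatML overhead accounting used in getTokenCount above.
// stubEncode is a stand-in for tiktoken -- it counts whitespace-separated
// words, so the totals here are illustrative only.
type Message = { role: string; content: string; name?: string };

const stubEncode = (s: string): number =>
  s.split(/\s+/).filter(Boolean).length;

function approxTokenCount(messages: Message[], model: string): number {
  const gpt4 = model.startsWith("gpt-4");
  const tokensPerMessage = gpt4 ? 3 : 4;
  const tokensPerName = gpt4 ? 1 : -1; // turbo omits the role if a name is present
  let numTokens = 0;
  for (const m of messages) {
    numTokens += tokensPerMessage;
    numTokens += stubEncode(m.role) + stubEncode(m.content);
    if (m.name !== undefined) numTokens += stubEncode(m.name) + tokensPerName;
  }
  return numTokens + 3; // every reply is primed with <|start|>assistant<|message|>
}

console.log(approxTokenCount([{ role: "user", content: "Hello there" }], "gpt-4")); // → 9
```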
-58
View File
@@ -1,58 +0,0 @@
import { Request } from "express";
import { config } from "../config";
import {
init as initClaude,
getTokenCount as getClaudeTokenCount,
} from "./claude";
import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
OpenAIPromptMessage,
} from "./openai";
export async function init() {
if (config.anthropicKey) {
initClaude();
}
if (config.openaiKey) {
initOpenAi();
}
}
type TokenCountResult = {
token_count: number;
tokenizer: string;
tokenization_duration_ms: number;
};
type TokenCountRequest = {
req: Request;
} & (
| { prompt: string; service: "anthropic" }
| { prompt: OpenAIPromptMessage[]; service: "openai" }
);
export async function countTokens({
req,
service,
prompt,
}: TokenCountRequest): Promise<TokenCountResult> {
const time = process.hrtime();
switch (service) {
case "anthropic":
return {
...getClaudeTokenCount(prompt, req.body.model),
tokenization_duration_ms: getElapsedMs(time),
};
case "openai":
return {
...getOpenAITokenCount(prompt, req.body.model),
tokenization_duration_ms: getElapsedMs(time),
};
default:
throw new Error(`Unknown service: ${service}`);
}
}
function getElapsedMs(time: [number, number]) {
const diff = process.hrtime(time);
return diff[0] * 1000 + diff[1] / 1e6;
}
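`getElapsedMs` above converts a `process.hrtime()` diff tuple of (seconds, nanoseconds) into milliseconds; a minimal standalone check of that conversion:

```typescript
// (seconds, nanoseconds) -> milliseconds, mirroring getElapsedMs above.
function hrtimeToMs(diff: [number, number]): number {
  return diff[0] * 1000 + diff[1] / 1e6;
}

console.log(hrtimeToMs([2, 500_000])); // → 2000.5
```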
-6
View File
@@ -10,8 +10,6 @@ declare global {
inboundApi: AIService | "kobold";
/** Denotes the format of the request being proxied to the API. */
outboundApi: AIService;
/** If the request comes from a RisuAI.xyz user, this is their token. */
risuToken?: string;
user?: User;
isStreaming?: boolean;
startTime: number;
@@ -20,10 +18,6 @@ declare global {
onAborted?: () => void;
proceed: () => void;
heartbeatInterval?: NodeJS.Timeout;
promptTokens?: number;
outputTokens?: number;
// TODO: remove later
debug: Record<string, any>;
}
}
}
-6
View File
@@ -1,6 +0,0 @@
<hr />
<footer>
<a href="/admin">Index</a> | <a href="/admin/logout">Logout</a>
</footer>
</body>
</html>
-61
View File
@@ -1,61 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="csrf-token" content="<%= csrfToken %>">
<title><%= title %></title>
<style>
.pagination {
list-style-type: none;
padding: 0;
}
.pagination li {
display: inline-block;
}
.pagination li a {
display: block;
padding: 0.5em 1em;
text-decoration: none;
}
.pagination li.active a {
background-color: #58739c;
color: #fff;
}
table {
border-collapse: collapse;
border: 1px solid #ccc;
}
table td, table th {
border: 1px solid #ccc;
padding: 0.25em 0.5em;
}
th.active {
background-color: #e0e6f6;
}
td.actions {
padding: 0;
text-align: center;
}
td.actions a {
text-decoration: none;
padding: 0.5em;
height: 100%;
width: 100%;
}
td.actions:hover {
background-color: #ccc;
}
@media (max-width: 600px) {
table {
width: 100%;
}
table td, table th {
display: block;
width: 100%;
}
}
</style>
</head>
<body style="font-family: sans-serif; background-color: #f0f0f0; padding: 1em;">
-23
View File
@@ -1,23 +0,0 @@
<div>
<label for="pageSize">Page Size</label>
<select id="pageSize" onchange="setPageSize(this.value)" style="margin-bottom: 1rem;">
<option value="10" <% if (pageSize === 10) { %>selected<% } %>>10</option>
<option value="20" <% if (pageSize === 20) { %>selected<% } %>>20</option>
<option value="50" <% if (pageSize === 50) { %>selected<% } %>>50</option>
<option value="100" <% if (pageSize === 100) { %>selected<% } %>>100</option>
<option value="200" <% if (pageSize === 200) { %>selected<% } %>>200</option>
</select>
</div>
<script>
function getPageSize() {
  var match = window.location.search.match(/perPage=(\d+)/) || document.cookie.match(/perPage=(\d+)/);
  return match ? parseInt(match[1], 10) : 10;
}
function setPageSize(size) {
document.cookie = "perPage=" + size + "; path=/admin";
window.location.reload();
}
document.getElementById("pageSize").value = getPageSize();
</script>
-18
View File
@@ -1,18 +0,0 @@
<%- include("../_partials/admin-header", { title: "Create User - OAI Reverse Proxy Admin" }) %>
<h1>Create User Token</h1>
<form action="/admin/manage/create-user" method="post">
<input type="hidden" name="_csrf" value="<%= csrfToken %>" />
<input type="submit" value="Create" />
</form>
<% if (newToken) { %>
<p>Just created <code><%= recentUsers[0].token %></code>.</p>
<% } %>
<h3>Recent Tokens</h3>
<ul>
<% recentUsers.forEach(function(user) { %>
<li><a href="/admin/manage/view-user/<%= user.token %>"><%= user.token %></a></li>
<% }) %>
</ul>
<%- include("../_partials/admin-footer") %>
-28
View File
@@ -1,28 +0,0 @@
<%- include("../_partials/admin-header", { title: "Export Users - OAI Reverse Proxy Admin" }) %>
<h1>Export Users</h1>
<p>
Export users to JSON. The JSON will be an array of objects under the key
<code>users</code>. You can use this JSON to import users later.
</p>
<script>
function exportUsers() {
var xhr = new XMLHttpRequest();
xhr.open("GET", "/admin/manage/export-users.json", true);
xhr.responseType = "blob";
xhr.onload = function() {
if (this.status === 200) {
var blob = new Blob([this.response], { type: "application/json" });
var url = URL.createObjectURL(blob);
var a = document.createElement("a");
a.href = url;
a.download = "users.json";
document.body.appendChild(a);
a.click();
a.remove();
}
};
xhr.send();
}
</script>
<button onclick="exportUsers()">Export</button>
<%- include("../_partials/admin-footer") %>
-44
View File
@@ -1,44 +0,0 @@
<%- include("../_partials/admin-header", { title: "Import Users - OAI Reverse Proxy Admin" }) %>
<h1>Import Users</h1>
<p>
Import users from JSON. The JSON should be an array of objects under the key
<code>users</code>. Each object should have the following fields:
</p>
<ul>
<li><code>token</code> (required): a unique identifier for the user</li>
<li><code>ip</code> (optional): IP addresses the user has connected from</li>
<li>
<code>type</code> (optional): either <code>normal</code> or
<code>special</code>
</li>
<li>
<code>promptCount</code> (optional): the number of times the user has sent a
prompt
</li>
<li>
<code>tokenCount</code> (optional): the number of tokens the user has
consumed (not yet implemented)
</li>
<li>
<code>createdAt</code> (optional): the timestamp when the user was created
</li>
<li>
<code>disabledAt</code> (optional): the timestamp when the user was disabled
</li>
<li>
<code>disabledReason</code> (optional): the reason the user was disabled
</li>
</ul>
<p>
If a user with the same token already exists, the existing user will be
updated with the new values.
</p>
<form action="/admin/manage/import-users?_csrf=<%= csrfToken %>" method="post" enctype="multipart/form-data">
<input type="file" name="users" />
<input type="submit" value="Import" />
</form>
<% if (imported > 0) { %>
<p>Imported <code><%= imported %></code> users.</p>
<% } %>
<%- include("../_partials/admin-footer") %>
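For reference, the import payload described in the page above could look like the following (a hypothetical example — every value is illustrative, and the timestamp is in epoch milliseconds, matching the format the user list renders):

```json
{
  "users": [
    {
      "token": "example-user-token",
      "ip": ["203.0.113.7"],
      "type": "normal",
      "promptCount": 12,
      "createdAt": 1688601600000
    }
  ]
}
```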
-20
View File
@@ -1,20 +0,0 @@
<%- include("../_partials/admin-header", { title: "OAI Reverse Proxy Admin" }) %>
<h1>OAI Reverse Proxy Admin</h1>
<% if (!isPersistenceEnabled) { %>
<p style="color: red; background-color: #eedddd; padding: 1em">
<strong>⚠️ Users will be lost when the server restarts because persistence is
not configured.</strong><br />
<br />Be sure to export your users and import them again after restarting the
server if you want to keep them.<br />
<br /> See the <a target="_blank"
href="https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/docs/user-management.md#firebase-realtime-database">
user management documentation</a> to learn how to set up persistence.
</p>
<% } %>
<ul>
<li><a href="/admin/manage/list-users">List Users</a></li>
<li><a href="/admin/manage/create-user">Create User</a></li>
<li><a href="/admin/manage/import-users">Import Users</a></li>
<li><a href="/admin/manage/export-users">Export Users</a></li>
</ul>
<%- include("../_partials/admin-footer") %>
-105
View File
@@ -1,105 +0,0 @@
<%- include("../_partials/admin-header", { title: "Users - OAI Reverse Proxy Admin" }) %>
<h1>User Token List</h1>
<input type="hidden" name="_csrf" value="<%= csrfToken %>" />
<% if (users.length === 0) { %>
<p>No users found.</p>
<% } else { %>
<table>
<thead>
<tr>
<th>Token</th>
<th <% if (sort.includes("ip")) { %>class="active"<% } %> ><a href="/admin/manage/list-users?sort=ip">IPs</a></th>
<th <% if (sort.includes("promptCount")) { %>class="active"<% } %> ><a href="/admin/manage/list-users?sort=promptCount">Prompts</a></th>
<th>Type</th>
<th <% if (sort.includes("createdAt")) { %>class="active"<% } %> ><a href="/admin/manage/list-users?sort=createdAt">Created (UTC)</a></th>
<th <% if (sort.includes("lastUsedAt")) { %>class="active"<% } %> ><a href="/admin/manage/list-users?sort=lastUsedAt">Last Used (UTC)</a></th>
<th colspan="2">Banned?</th>
</tr>
</thead>
<tbody>
<% users.forEach(function(user){ %>
<tr>
<td>
<code><a href="/admin/manage/view-user/<%= user.token %>"><%= user.token %></a></code>
</td>
<td><%= user.ip.length %></td>
<td><%= user.promptCount %></td>
<td><%= user.type %></td>
<td><%= user.createdAt %></td>
<td><%= user.lastUsedAt ?? "never" %></td>
<td class="actions">
<% if (user.disabledAt) { %>
<a title="Unban" href="#" class="unban" data-token="<%= user.token %>">🔄️</a>
<% } else { %>
<a title="Ban" href="#" class="ban" data-token="<%= user.token %>">🚫</a>
<% } %>
</td>
<td><%= user.disabledAt ? "Yes" : "No" %> <%= user.disabledReason ? `(${user.disabledReason})` : "" %></td>
</tr>
<% }); %>
</tbody>
</table>
<ul class="pagination">
<% if (page > 1) { %>
<li><a href="/admin/manage/list-users?sort=<%= sort %>&page=<%= page - 1 %>">&laquo;</a></li>
<% } %> <% for (var i = 1; i <= pageCount; i++) { %>
<li <% if (i === page) { %>class="active"<% } %>><a href="/admin/manage/list-users?sort=<%= sort %>&page=<%= i %>"><%= i %></a></li>
<% } %> <% if (page < pageCount) { %>
<li><a href="/admin/manage/list-users?sort=<%= sort %>&page=<%= page + 1 %>">&raquo;</a></li>
<% } %>
</ul>
<p>Showing <%= page * pageSize - pageSize + 1 %> to <%= users.length + page * pageSize - pageSize %> of <%= totalCount %> users.</p>
<%- include("../_partials/pagination") %>
<% } %>
<script>
document.querySelectorAll("td.actions a.ban").forEach(function (a) {
a.addEventListener("click", function (e) {
e.preventDefault();
var token = a.getAttribute("data-token");
if (confirm("Are you sure you want to ban this user?")) {
let reason = prompt("Reason for ban:");
fetch(
"/admin/manage/disable-user/" + token,
{
method: "POST",
credentials: "same-origin",
body: JSON.stringify({ reason, _csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content") }),
headers: { "Content-Type": "application/json" }
}).then(() => window.location.reload());
}
});
});
document.querySelectorAll("td.actions a.unban").forEach(function (a) {
a.addEventListener("click", function (e) {
e.preventDefault();
var token = a.getAttribute("data-token");
if (confirm("Are you sure you want to unban this user?")) {
fetch(
"/admin/manage/reactivate-user/" + token,
{
method: "POST",
credentials: "same-origin",
body: JSON.stringify({ _csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content") }),
headers: { "Content-Type": "application/json" }
}
).then(() => window.location.reload());
}
});
});
</script>
<script>
// Render 13-digit epoch-millisecond timestamps as readable UTC datetimes.
// (A zero/"never" timestamp is already handled in the template itself.)
document.querySelectorAll("td").forEach(function (td) {
  if (/^\d{13}$/.test(td.innerText)) {
    var date = new Date(parseInt(td.innerText, 10));
    td.innerText = date.toISOString().replace("T", " ").replace(/\.\d+Z$/, "");
  }
});
</script>
<%- include("../_partials/admin-footer") %>
-13
View File
@@ -1,13 +0,0 @@
<%- include("../_partials/admin-header", { title: "Login" }) %>
<h1>Login</h1>
<% if (failed) { %>
<p style="color: red;">Please try again.</p>
<% } %>
<form action="/admin/login" method="post">
<input type="hidden" name="_csrf" value="<%= csrfToken %>" />
<label for="token">Admin Key</label>
<input type="password" id="token" name="token" />
<input type="submit" value="Login" />
</form>
</body>
</html>
-64
View File
@@ -1,64 +0,0 @@
<%- include("../_partials/admin-header", { title: "View User - OAI Reverse Proxy Admin" }) %>
<h1>View User</h1>
<table class="table table-striped">
<thead>
<tr>
<th scope="col">Key</th>
<th scope="col">Value</th>
</tr>
</thead>
<tbody>
<tr>
<th scope="row">Token</th>
<td><%- user.token %></td>
</tr>
<tr>
<th scope="row">Type</th>
<td><%- user.type %></td>
</tr>
<tr>
<th scope="row">Prompt Count</th>
<td><%- user.promptCount %></td>
</tr>
<tr>
<th scope="row">Token Count</th>
<td><%- user.tokenCount %></td>
</tr>
<tr>
<th scope="row">Created At</th>
<td><%- user.createdAt %></td>
</tr>
<tr>
<th scope="row">Last Used At</th>
<td><%- user.lastUsedAt || "never" %></td>
</tr>
<tr>
<th scope="row">Disabled At</th>
<td><%- user.disabledAt %></td>
</tr>
<tr>
<th scope="row">Disabled Reason</th>
<td><%- user.disabledReason %></td>
</tr>
<tr>
<th scope="row">IPs</th>
<td>
<a href="#" id="ip-list-toggle">Show all (<%- user.ip.length %>)</a>
<ol id="ip-list" style="display:none; padding-left:1em; margin: 0;">
<% user.ip.forEach((ip) => { %>
<li><code><%- ip %></code></li>
<% }) %>
</ol>
</td>
</tr>
</tbody>
</table>
<script>
document.getElementById("ip-list-toggle").addEventListener("click", (e) => {
e.preventDefault();
document.getElementById("ip-list").style.display = "block";
document.getElementById("ip-list-toggle").style.display = "none";
});
</script>
<%- include("../_partials/admin-footer") %>
+1 -3
View File
@@ -9,9 +9,7 @@
"skipLibCheck": true,
"skipDefaultLibCheck": true,
"outDir": "build",
"sourceMap": true,
"resolveJsonModule": true,
"useUnknownInCatchVariables": false
"sourceMap": true
},
"include": ["src"],
"exclude": ["node_modules"],