Add temporary user tokens (khanon/oai-reverse-proxy!42)
This commit is contained in:
@@ -9,3 +9,15 @@ export class UserInputError extends HttpError {
|
||||
super(400, message);
|
||||
}
|
||||
}
|
||||
|
||||
export class ForbiddenError extends HttpError {
|
||||
constructor(message: string) {
|
||||
super(403, message);
|
||||
}
|
||||
}
|
||||
|
||||
export class NotFoundError extends HttpError {
|
||||
constructor(message: string) {
|
||||
super(404, message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { RequestHandler } from "express";
|
||||
import sanitize from "sanitize-html";
|
||||
import { config } from "../config";
|
||||
import { getTokenCostUsd, prettyTokens } from "./stats";
|
||||
import * as userStore from "./users/user-store";
|
||||
@@ -13,23 +12,17 @@ export const injectLocals: RequestHandler = (req, res, next) => {
|
||||
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
|
||||
res.locals.persistenceEnabled = config.gatekeeperStore !== "memory";
|
||||
res.locals.showTokenCosts = config.showTokenCosts;
|
||||
res.locals.maxIps = config.maxIpsPerUser;
|
||||
|
||||
// flash message
|
||||
if (req.query.flash) {
|
||||
const content = sanitize(String(req.query.flash))
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">");
|
||||
const match = content.match(/^([a-z]+):(.*)/);
|
||||
if (match) {
|
||||
res.locals.flash = { type: match[1], message: match[2] };
|
||||
} else {
|
||||
res.locals.flash = { type: "error", message: content };
|
||||
}
|
||||
// flash messages
|
||||
if (req.session.flash) {
|
||||
res.locals.flash = req.session.flash;
|
||||
delete req.session.flash;
|
||||
} else {
|
||||
res.locals.flash = null;
|
||||
}
|
||||
|
||||
// utils
|
||||
// view helpers
|
||||
res.locals.prettyTokens = prettyTokens;
|
||||
res.locals.tokenCost = getTokenCostUsd;
|
||||
|
||||
|
||||
+12
-1
@@ -3,7 +3,10 @@ import { logger } from "../logger";
|
||||
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type ModelFamily = OpenAIModelFamily | AnthropicModelFamily;
|
||||
export type ModelFamilyMap = { [regex: string]: ModelFamily };
|
||||
|
||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||
) => arr)(["turbo", "gpt4", "gpt4-32k", "claude"] as const);
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
|
||||
@@ -25,3 +28,11 @@ export function getOpenAIModelFamily(model: string): OpenAIModelFamily {
|
||||
export function getClaudeModelFamily(_model: string): ModelFamily {
|
||||
return "claude";
|
||||
}
|
||||
|
||||
export function assertIsKnownModelFamily(
|
||||
modelFamily: string
|
||||
): asserts modelFamily is ModelFamily {
|
||||
if (!MODEL_FAMILIES.includes(modelFamily as ModelFamily)) {
|
||||
throw new Error(`Unknown model family: ${modelFamily}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { config } from "../config";
|
||||
import { ModelFamily } from "./models";
|
||||
|
||||
// technically slightly underestimates, because completion tokens cost more
|
||||
// than prompt tokens but we don't track those separately right now
|
||||
export function getTokenCostUsd(model: ModelFamily, tokens: number) {
|
||||
if (!config.showTokenCosts) return 0;
|
||||
|
||||
let cost = 0;
|
||||
switch (model) {
|
||||
case "gpt4-32k":
|
||||
|
||||
+26
-18
@@ -1,5 +1,6 @@
|
||||
import { ZodType, z } from "zod";
|
||||
import type { ModelFamily } from "../models";
|
||||
import { makeOptionalPropsNullable } from "../utils";
|
||||
|
||||
export const tokenCountsSchema: ZodType<UserTokenCounts> = z
|
||||
.object({
|
||||
@@ -15,44 +16,51 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z
|
||||
|
||||
export const UserSchema = z
|
||||
.object({
|
||||
/** The user's personal access token. */
|
||||
/** User's personal access token. */
|
||||
token: z.string(),
|
||||
/** The IP addresses the user has connected from. */
|
||||
/** IP addresses the user has connected from. */
|
||||
ip: z.array(z.string()),
|
||||
/** The user's nickname. */
|
||||
nickname: z.string().max(80).nullish(),
|
||||
/** User's nickname. */
|
||||
nickname: z.string().max(80).optional(),
|
||||
/**
|
||||
* The user's privilege level.
|
||||
* - `normal`: Default role. Subject to usual rate limits and quotas.
|
||||
* - `special`: Special role. Higher quotas and exempt from
|
||||
* auto-ban/lockout.
|
||||
**/
|
||||
type: z.enum(["normal", "special"]),
|
||||
/** The number of prompts the user has made. */
|
||||
type: z.enum(["normal", "special", "temporary"]),
|
||||
/** Number of prompts the user has made. */
|
||||
promptCount: z.number(),
|
||||
/**
|
||||
* @deprecated Use `tokenCounts` instead.
|
||||
* Never used; retained for backwards compatibility.
|
||||
*/
|
||||
tokenCount: z.any().optional(),
|
||||
/** The number of tokens the user has consumed, by model family. */
|
||||
/** Number of tokens the user has consumed, by model family. */
|
||||
tokenCounts: tokenCountsSchema,
|
||||
/** The maximum number of tokens the user can consume, by model family. */
|
||||
/** Maximum number of tokens the user can consume, by model family. */
|
||||
tokenLimits: tokenCountsSchema,
|
||||
/** The time at which the user was created. */
|
||||
/** Time at which the user was created. */
|
||||
createdAt: z.number(),
|
||||
/** The time at which the user last connected. */
|
||||
lastUsedAt: z.number().nullish(),
|
||||
/** The time at which the user was disabled, if applicable. */
|
||||
disabledAt: z.number().nullish(),
|
||||
/** The reason for which the user was disabled, if applicable. */
|
||||
disabledReason: z.string().nullish(),
|
||||
/** Time at which the user last connected. */
|
||||
lastUsedAt: z.number().optional(),
|
||||
/** Time at which the user was disabled, if applicable. */
|
||||
disabledAt: z.number().optional(),
|
||||
/** Reason for which the user was disabled, if applicable. */
|
||||
disabledReason: z.string().optional(),
|
||||
/** Time at which the user will expire and be disabled (for temp users). */
|
||||
expiresAt: z.number().optional(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
export const UserPartialSchema = UserSchema.partial().extend({
|
||||
token: z.string(),
|
||||
});
|
||||
/**
|
||||
* Variant of `UserSchema` which allows for partial updates, and makes any
|
||||
* optional properties on the base schema nullable. Null values are used to
|
||||
* indicate that the property should be deleted from the user object.
|
||||
*/
|
||||
export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
|
||||
.partial()
|
||||
.extend({ token: z.string() });
|
||||
|
||||
// gpt4-32k was added after the initial release, so this tries to allow for
|
||||
// data imported from older versions of the app which may be missing the
|
||||
|
||||
@@ -22,6 +22,7 @@ const MAX_IPS_PER_USER = config.maxIpsPerUser;
|
||||
const users: Map<string, User> = new Map();
|
||||
const usersToFlush = new Set<string>();
|
||||
let quotaRefreshJob: schedule.Job | null = null;
|
||||
let userCleanupJob: schedule.Job | null = null;
|
||||
|
||||
export async function init() {
|
||||
log.info({ store: config.gatekeeperStore }, "Initializing user store...");
|
||||
@@ -29,16 +30,8 @@ export async function init() {
|
||||
await initFirebase();
|
||||
}
|
||||
if (config.quotaRefreshPeriod) {
|
||||
quotaRefreshJob = schedule.scheduleJob(getRefreshCrontab(), () => {
|
||||
for (const user of users.values()) {
|
||||
refreshQuota(user.token);
|
||||
}
|
||||
log.info(
|
||||
{ users: users.size, nextRefresh: quotaRefreshJob!.nextInvocation() },
|
||||
"Token quotas refreshed."
|
||||
);
|
||||
});
|
||||
|
||||
const crontab = getRefreshCrontab();
|
||||
quotaRefreshJob = schedule.scheduleJob(crontab, refreshAllQuotas);
|
||||
if (!quotaRefreshJob) {
|
||||
throw new Error(
|
||||
"Unable to schedule quota refresh. Is QUOTA_REFRESH_PERIOD set correctly?"
|
||||
@@ -49,26 +42,42 @@ export async function init() {
|
||||
"Scheduled token quota refresh."
|
||||
);
|
||||
}
|
||||
|
||||
userCleanupJob = schedule.scheduleJob("* * * * *", cleanupExpiredTokens);
|
||||
|
||||
log.info("User store initialized.");
|
||||
}
|
||||
|
||||
export function getNextQuotaRefresh() {
|
||||
if (!quotaRefreshJob) return "never (manual refresh only)";
|
||||
return quotaRefreshJob.nextInvocation().getTime();
|
||||
}
|
||||
|
||||
/** Creates a new user and returns their token. */
|
||||
export function createUser() {
|
||||
/**
|
||||
* Creates a new user and returns their token. Optionally accepts parameters
|
||||
* for setting an expiry date and/or token limits for temporary users.
|
||||
**/
|
||||
export function createUser(createOptions?: {
|
||||
type?: User["type"];
|
||||
expiresAt?: number;
|
||||
tokenLimits?: User["tokenLimits"];
|
||||
}) {
|
||||
const token = uuid();
|
||||
users.set(token, {
|
||||
const newUser: User = {
|
||||
token,
|
||||
ip: [],
|
||||
type: "normal",
|
||||
promptCount: 0,
|
||||
tokenCounts: { turbo: 0, gpt4: 0, "gpt4-32k": 0, claude: 0 },
|
||||
tokenLimits: { ...config.tokenQuota },
|
||||
tokenLimits: createOptions?.tokenLimits ?? { ...config.tokenQuota },
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
};
|
||||
|
||||
if (createOptions?.type === "temporary") {
|
||||
Object.assign(newUser, {
|
||||
type: "temporary",
|
||||
expiresAt: createOptions.expiresAt,
|
||||
});
|
||||
} else {
|
||||
Object.assign(newUser, { type: createOptions?.type ?? "normal" });
|
||||
}
|
||||
|
||||
users.set(token, newUser);
|
||||
usersToFlush.add(token);
|
||||
return token;
|
||||
}
|
||||
@@ -114,6 +123,14 @@ export function upsertUser(user: UserUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Write firebase migration to backfill gpt4-32k token counts
|
||||
if (updates.tokenCounts) {
|
||||
updates.tokenCounts["gpt4-32k"] ??= 0;
|
||||
}
|
||||
if (updates.tokenLimits) {
|
||||
updates.tokenLimits["gpt4-32k"] ??= 0;
|
||||
}
|
||||
|
||||
users.set(user.token, Object.assign(existing, updates));
|
||||
usersToFlush.add(user.token);
|
||||
|
||||
@@ -161,7 +178,7 @@ export function authenticate(token: string, ip: string) {
|
||||
const ipLimit =
|
||||
user.type === "special" || !MAX_IPS_PER_USER ? Infinity : MAX_IPS_PER_USER;
|
||||
if (user.ip.length > ipLimit) {
|
||||
disableUser(token, "Too many IP addresses associated with this token.");
|
||||
disableUser(token, "IP address limit exceeded.");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -223,6 +240,48 @@ export function disableUser(token: string, reason?: string) {
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
export function getNextQuotaRefresh() {
|
||||
if (!quotaRefreshJob) return "never (manual refresh only)";
|
||||
return quotaRefreshJob.nextInvocation().getTime();
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleans up expired temporary tokens by disabling tokens past their access
|
||||
* expiry date and permanently deleting tokens three days after their access
|
||||
* expiry date.
|
||||
*/
|
||||
function cleanupExpiredTokens() {
|
||||
const now = Date.now();
|
||||
let disabled = 0;
|
||||
let deleted = 0;
|
||||
for (const user of users.values()) {
|
||||
if (user.type !== "temporary") continue;
|
||||
if (user.expiresAt && user.expiresAt < now && !user.disabledAt) {
|
||||
disableUser(user.token, "Temporary token expired.");
|
||||
disabled++;
|
||||
}
|
||||
if (user.disabledAt && user.disabledAt + 72 * 60 * 60 * 1000 < now) {
|
||||
users.delete(user.token);
|
||||
usersToFlush.add(user.token);
|
||||
deleted++;
|
||||
}
|
||||
}
|
||||
log.debug({ disabled, deleted }, "Expired tokens cleaned up.");
|
||||
}
|
||||
|
||||
function refreshAllQuotas() {
|
||||
let count = 0;
|
||||
for (const user of users.values()) {
|
||||
if (user.type === "temporary") continue;
|
||||
refreshQuota(user.token);
|
||||
count++;
|
||||
}
|
||||
log.info(
|
||||
{ refreshed: count, nextRefresh: quotaRefreshJob!.nextInvocation() },
|
||||
"Token quotas refreshed."
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: Firebase persistence is pretend right now and just polls the in-memory
|
||||
// store to sync it with Firebase when it changes. Will refactor to abstract
|
||||
// persistence layer later so we can support multiple stores.
|
||||
@@ -253,10 +312,12 @@ async function flushUsers() {
|
||||
const db = admin.database(app);
|
||||
const usersRef = db.ref("users");
|
||||
const updates: Record<string, User> = {};
|
||||
const deletions = [];
|
||||
|
||||
for (const token of usersToFlush) {
|
||||
const user = users.get(token);
|
||||
if (!user) {
|
||||
deletions.push(token);
|
||||
continue;
|
||||
}
|
||||
updates[token] = user;
|
||||
@@ -264,13 +325,17 @@ async function flushUsers() {
|
||||
|
||||
usersToFlush.clear();
|
||||
|
||||
const numUpdates = Object.keys(updates).length;
|
||||
const numUpdates = Object.keys(updates).length + deletions.length;
|
||||
if (numUpdates === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
await usersRef.update(updates);
|
||||
log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase");
|
||||
await Promise.all(deletions.map((token) => usersRef.child(token).remove()));
|
||||
log.info(
|
||||
{ users: Object.keys(updates).length, deletions: deletions.length },
|
||||
"Flushed changes to Firebase"
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: use key-management/models.ts for family mapping
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Query } from "express-serve-static-core";
|
||||
import sanitize from "sanitize-html";
|
||||
import { z } from "zod";
|
||||
|
||||
export function parseSort(sort: Query["sort"]) {
|
||||
if (!sort) return null;
|
||||
@@ -49,3 +50,28 @@ export function sanitizeAndTrim(
|
||||
) {
|
||||
return sanitize((input ?? "").trim(), options);
|
||||
}
|
||||
|
||||
// https://github.com/colinhacks/zod/discussions/2050#discussioncomment-5018870
|
||||
export function makeOptionalPropsNullable<Schema extends z.AnyZodObject>(
|
||||
schema: Schema
|
||||
) {
|
||||
const entries = Object.entries(schema.shape) as [
|
||||
keyof Schema["shape"],
|
||||
z.ZodTypeAny
|
||||
][];
|
||||
const newProps = entries.reduce(
|
||||
(acc, [key, value]) => {
|
||||
acc[key] =
|
||||
value instanceof z.ZodOptional ? value.unwrap().nullable() : value;
|
||||
return acc;
|
||||
},
|
||||
{} as {
|
||||
[key in keyof Schema["shape"]]: Schema["shape"][key] extends z.ZodOptional<
|
||||
infer T
|
||||
>
|
||||
? z.ZodNullable<T>
|
||||
: Schema["shape"][key];
|
||||
}
|
||||
);
|
||||
return z.object(newProps);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
<% if (flashData) {
|
||||
let flashStyle = { title: "", style: "" };
|
||||
switch (flashData.type) {
|
||||
case "success":
|
||||
flashStyle.title = "✅ Success:";
|
||||
flashStyle.style = "color: green; background-color: #ddffee; padding: 1em";
|
||||
break;
|
||||
case "error":
|
||||
flashStyle.title = "⚠️ Error:";
|
||||
flashStyle.style = "color: red; background-color: #eedddd; padding: 1em";
|
||||
break;
|
||||
case "warning":
|
||||
flashStyle.title = "⚠️ Alert:";
|
||||
flashStyle.style = "color: darkorange; background-color: #ffeecc; padding: 1em";
|
||||
break;
|
||||
case "info":
|
||||
flashStyle.title = "ℹ️ Notice:";
|
||||
flashStyle.style = "color: blue; background-color: #ddeeff; padding: 1em";
|
||||
break;
|
||||
}
|
||||
%>
|
||||
<p style="<%= flashStyle.style %>">
|
||||
<strong><%= flashStyle.title %></strong> <%= flashData.message %>
|
||||
</p>
|
||||
<% } %>
|
||||
@@ -63,15 +63,6 @@
|
||||
</style>
|
||||
</head>
|
||||
<body style="font-family: sans-serif; background-color: #f0f0f0; padding: 1em;">
|
||||
<% if (flash && flash.type === "error") { %>
|
||||
<p style="color: red; background-color: #eedddd; padding: 1em">
|
||||
<strong>⚠️ Error:</strong> <%= flash.message %>
|
||||
</p>
|
||||
<% } %>
|
||||
<% if (flash && flash.type === "success") { %>
|
||||
<p style="color: green; background-color: #ddffee; padding: 1em">
|
||||
<strong>✅ Success:</strong> <%= flash.message %>
|
||||
</p>
|
||||
<% } %>
|
||||
<%- include("partials/shared_flash", { flashData: flash }) %>
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,11 @@
|
||||
<td><%- prettyTokens(user.tokenLimits[key]) %></td>
|
||||
<td><%- prettyTokens(user.tokenLimits[key] - user.tokenCounts[key]) %></td>
|
||||
<% } %>
|
||||
<% if (user.type === "temporary") { %>
|
||||
<td>N/A</td>
|
||||
<% } else { %>
|
||||
<td><%- prettyTokens(quota[key]) %></td>
|
||||
<% } %>
|
||||
</tr>
|
||||
<% }) %>
|
||||
</tbody>
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
<a href="#" id="ip-list-toggle">Show all (<%- user.ip.length %>)</a>
|
||||
<ol id="ip-list" style="display: none; padding-left: 1em; margin: 0">
|
||||
<% user.ip.forEach((ip) => { %>
|
||||
<li><code><%- ip %></code></li>
|
||||
<% }) %>
|
||||
</ol>
|
||||
|
||||
<script>
|
||||
document.getElementById("ip-list-toggle").addEventListener("click", (e) => {
|
||||
e.preventDefault();
|
||||
document.getElementById("ip-list").style.display = "block";
|
||||
document.getElementById("ip-list-toggle").style.display = "none";
|
||||
});
|
||||
</script>
|
||||
Reference in New Issue
Block a user