I should have made all these commits separately but oops

This commit is contained in:
Nopm
2025-06-03 20:14:07 -03:00
parent 5988cd7e45
commit 0411b4c3a6
28 changed files with 1551 additions and 835 deletions
+12 -1
View File
@@ -17,6 +17,14 @@ NODE_ENV=production
# The title displayed on the info page.
# SERVER_TITLE=Coom Tunnel
# URL for the image displayed on the login page.
# If not set, no image will be displayed.
# LOGIN_IMAGE_URL=https://example.com/your-logo.png
# Whether to enable the token-based login for the main info page.
# Defaults to true. Set to false to disable login and make the info page public.
# ENABLE_INFO_PAGE_LOGIN=true
# The route name used to proxy requests to APIs, relative to the Web site root.
# PROXY_ENDPOINT_ROUTE=/proxy
@@ -119,8 +127,11 @@ NODE_ENV=production
# Which access control method to use. (none | proxy_key | user_token)
# GATEKEEPER=none
# Which persistence method to use. (memory | firebase_rtdb)
# Which persistence method to use. (memory | firebase_rtdb | sqlite)
# GATEKEEPER_STORE=memory
# If using sqlite store, path to the SQLite database file for user data.
# Defaults to data/user-store.sqlite in the project directory.
# SQLITE_USER_STORE_PATH=data/user-store.sqlite
# Maximum number of unique IPs a user can connect from. (0 for unlimited)
# MAX_IPS_PER_USER=0
+499 -325
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -78,7 +78,7 @@
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
"esbuild": "^0.25.5",
"esbuild-register": "^3.4.2",
"husky": "^8.0.3",
"nodemon": "^3.0.1",
+15 -5
View File
@@ -132,8 +132,13 @@ router.post("/create-user", (req, res) => {
)
.transform((data: any) => {
const expiresAt = Date.now() + data.temporaryUserDuration * 60 * 1000;
const tokenLimits = MODEL_FAMILIES.reduce((limits, model) => {
limits[model] = data[`temporaryUserQuota_${model}`];
const tokenLimits = MODEL_FAMILIES.reduce((limits, modelFamily) => {
const quotaValue = data[`temporaryUserQuota_${modelFamily}`];
if (typeof quotaValue === 'number') {
limits[modelFamily] = { input: quotaValue, output: 0, legacy_total: quotaValue };
} else {
limits[modelFamily] = { input: 0, output: 0 };
}
return limits;
}, {} as UserTokenCounts);
return { ...data, expiresAt, tokenLimits };
@@ -547,9 +552,14 @@ router.post("/generate-stats", (req, res) => {
function getSumsForUser(user: User) {
const sums = MODEL_FAMILIES.reduce(
(s, model) => {
const tokens = user.tokenCounts[model] ?? 0;
s.sumTokens += tokens;
s.sumCost += getTokenCostUsd(model, tokens);
const counts = user.tokenCounts[model] ?? { input: 0, output: 0, legacy_total: undefined };
// Ensure inputTokens and outputTokens are numbers, defaulting to 0 if NaN or undefined
const inputTokens = Number(counts.input) || 0;
const outputTokens = Number(counts.output) || 0;
// We could also consider legacy_total here if input and output are 0
// For now, sumTokens and sumCost will be based on current input/output.
s.sumTokens += inputTokens + outputTokens;
s.sumCost += getTokenCostUsd(model, inputTokens, outputTokens);
return s;
},
{ sumTokens: 0, sumCost: 0, prettyUsage: "" }
+25 -9
View File
@@ -90,11 +90,6 @@ type Config = {
* management mode is set to 'user_token'.
*/
adminKey?: string;
/**
* The password required to view the service info/status page. If not set, the
* info page will be publicly accessible.
*/
serviceInfoPassword?: string;
/**
* Which user management mode to use.
* - `none`: No user management. Proxy is open to all requests with basic
@@ -111,10 +106,14 @@ type Config = {
* - `memory`: Users are stored in memory and are lost on restart (default)
* - `firebase_rtdb`: Users are stored in a Firebase Realtime Database;
* requires `firebaseKey` and `firebaseRtdbUrl` to be set.
* - `sqlite`: Users are stored in an SQLite database; requires
* `sqliteUserStorePath` to be set.
*/
gatekeeperStore: "memory" | "firebase_rtdb";
gatekeeperStore: "memory" | "firebase_rtdb" | "sqlite";
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
firebaseRtdbUrl?: string;
/** Path to the SQLite database file for storing user data. */
sqliteUserStorePath?: string;
/**
* Base64-encoded Firebase service account key if using the Firebase RTDB
* store. Note that you should encode the *entire* JSON key file, not just the
@@ -432,6 +431,10 @@ type Config = {
*/
proxyUrl?: string;
};
/** URL for the image on the login page. Defaults to empty string (no image). */
loginImageUrl?: string;
/** Whether to enable the token-based login page for the service info page. Defaults to true. */
enableInfoPageLogin?: boolean;
};
// To change configs, create a file called .env in the root directory.
@@ -452,7 +455,6 @@ export const config: Config = {
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
serviceInfoPassword: getEnvWithDefault("SERVICE_INFO_PASSWORD", ""),
sqliteDataPath: getEnvWithDefault(
"SQLITE_DATA_PATH",
path.join(DATA_DIR, "database.sqlite")
@@ -460,7 +462,11 @@ export const config: Config = {
eventLogging: getEnvWithDefault("EVENT_LOGGING", false),
eventLoggingTrim: getEnvWithDefault("EVENT_LOGGING_TRIM", 5),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory") as Config["gatekeeperStore"],
sqliteUserStorePath: getEnvWithDefault(
"SQLITE_USER_STORE_PATH",
path.join(DATA_DIR, "user-store.sqlite")
),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", false),
captchaMode: getEnvWithDefault("CAPTCHA_MODE", "none"),
@@ -546,6 +552,8 @@ export const config: Config = {
interface: getEnvWithDefault("HTTP_AGENT_INTERFACE", undefined),
proxyUrl: getEnvWithDefault("HTTP_AGENT_PROXY_URL", undefined),
},
loginImageUrl: getEnvWithDefault("LOGIN_IMAGE_URL", ""),
enableInfoPageLogin: getEnvWithDefault("ENABLE_INFO_PAGE_LOGIN", true),
} as const;
function generateSigningKey() {
@@ -667,6 +675,12 @@ export async function assertConfigIsValid() {
);
}
if (config.gatekeeperStore === "sqlite" && !config.sqliteUserStorePath) {
throw new Error(
"SQLite user store requires `SQLITE_USER_STORE_PATH` to be set."
);
}
if (Object.values(config.httpAgent || {}).filter(Boolean).length === 0) {
delete config.httpAgent;
} else if (config.httpAgent) {
@@ -722,7 +736,6 @@ export const OMITTED_KEYS = [
"azureCredentials",
"proxyKey",
"adminKey",
"serviceInfoPassword",
"rejectPhrases",
"rejectMessage",
"showTokenCosts",
@@ -731,6 +744,7 @@ export const OMITTED_KEYS = [
"firebaseKey",
"firebaseRtdbUrl",
"sqliteDataPath",
"sqliteUserStorePath",
"eventLogging",
"eventLoggingTrim",
"gatekeeperStore",
@@ -749,6 +763,8 @@ export const OMITTED_KEYS = [
"adminWhitelist",
"ipBlacklist",
"powTokenPurgeHours",
"loginImageUrl",
"enableInfoPageLogin",
] satisfies (keyof Config)[];
type OmitKeys = (typeof OMITTED_KEYS)[number];
+179 -126
View File
@@ -1,4 +1,8 @@
/** This whole module kinda sucks */
/* ──────────────────────────────────────────────────────────────
Login-gated info page
drop-in replacement for src/info-page.ts
──────────────────────────────────────────────────────────── */
import fs from "fs";
import express, { Router, Request, Response } from "express";
import showdown from "showdown";
@@ -8,9 +12,20 @@ import { getLastNImages } from "./shared/file-storage/image-history";
import { keyPool } from "./shared/key-management";
import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
import { withSession } from "./shared/with-session";
import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
import { injectCsrfToken, checkCsrfToken } from "./shared/inject-csrf";
import { getUser } from "./shared/users/user-store";
/* ──────────────── TYPES: extend express-session ──────────── */
declare module "express-session" {
interface Session {
infoPageAuthed?: boolean;
}
}
/* ──────────────── misc constants ─────────────────────────── */
const INFO_PAGE_TTL = 2_000; // ms
const LOGIN_ROUTE = "/";
const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
qwen: "Qwen",
cohere: "Cohere",
@@ -72,13 +87,78 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
};
const converter = new showdown.Converter();
/* optional markdown greeting */
const customGreeting = fs.existsSync("greeting.md")
? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
: "";
/* ──────────────── Login page ──────────────────────── */
function renderLoginPage(csrf: string, error?: string) {
const errBlock = error
? `<div class="error-message">${escapeHtml(error)}</div>`
: "";
const pageTitle = getServerTitle();
return `<!DOCTYPE html>
<html>
<head>
<title>${pageTitle} Login</title>
<style>
body{font-family:Arial, sans-serif;display:flex;justify-content:center;
align-items:center;height:100vh;margin:0;padding:20px;background:#f5f5f5;}
.login-container{background:#fff;border-radius:8px;box-shadow:0 4px 8px rgba(0,0,0,.1);
padding:30px;width:100%;max-width:400px;text-align:center;}
.logo-image{max-width:200px;margin-bottom:20px;}
.form-group{margin-bottom:20px;}
input[type=text]{width:100%;padding:10px;border:1px solid #ddd;border-radius:4px;
box-sizing:border-box;font-size:16px;}
button{background:#4caf50;color:#fff;border:none;padding:12px 20px;border-radius:4px;
cursor:pointer;font-size:16px;width:100%;}
button:hover{background:#45a049;}
.error-message{color:#f44336;margin-bottom:15px;}
@media (prefers-color-scheme: dark) {
body { background: #2c2c2c; color: #e0e0e0; }
.login-container { background: #383838; box-shadow: 0 4px 12px rgba(0,0,0,0.4); border: 1px solid #4a4a4a; }
input[type=text] { background: #4a4a4a; color: #e0e0e0; border: 1px solid #5a5a5a; }
input[type=text]::placeholder { color: #999; }
button { background: #007bff; } /* Using a blue for dark mode button */
button:hover { background: #0056b3; }
.error-message { color: #ff8a80; } /* Lighter red for errors in dark mode */
}
</style>
</head>
<body>
<div class="login-container">
<img src="${config.loginImageUrl || ''}" alt="Logo" class="logo-image">
${errBlock}
<form method="POST" action="${LOGIN_ROUTE}">
<div class="form-group">
<input type="text" id="token" name="token" required placeholder="Your token">
<input type="hidden" name="_csrf" value="${csrf}">
</div>
<button type="submit">Access Dashboard</button>
</form>
</div>
</body>
</html>`;
}
/* ──────────────── login-required middleware ──────────────── */
function requireLogin(
req: Request,
res: Response,
next: express.NextFunction
) {
if (req.session?.infoPageAuthed) return next();
return res.send(renderLoginPage(res.locals.csrfToken));
}
/* ──────────────── INFO PAGE CACHING ──────────────────────── */
let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0;
export const handleInfoPage = (req: Request, res: Response) => {
export function handleInfoPage(req: Request, res: Response) {
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
return res.send(infoPageHtml);
}
@@ -93,60 +173,46 @@ export const handleInfoPage = (req: Request, res: Response) => {
infoPageLastUpdated = Date.now();
res.send(infoPageHtml);
};
}
/* ──────────────── RENDER FULL INFO PAGE ──────────────────── */
export function renderPage(info: ServiceInfo) {
const title = getServerTitle();
const headerHtml = buildInfoPageHeader(info);
return `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="robots" content="noindex" />
<title>${title}</title>
<link rel="stylesheet" href="/res/css/reset.css" media="screen" />
<link rel="stylesheet" href="/res/css/sakura.css" media="screen" />
<link rel="stylesheet" href="/res/css/sakura-dark.css" media="screen and (prefers-color-scheme: dark)" />
<style>
body {
font-family: sans-serif;
padding: 1em;
max-width: 900px;
margin: 0;
}
.self-service-links {
display: flex;
justify-content: center;
margin-bottom: 1em;
padding: 0.5em;
font-size: 0.8em;
}
.self-service-links a {
margin: 0 0.5em;
}
</style>
</head>
<body>
${headerHtml}
<hr />
${getSelfServiceLinks()}
<h2>Service Info</h2>
<pre>${JSON.stringify(info, null, 2)}</pre>
</body>
<head>
<meta charset="utf-8" />
<meta name="robots" content="noindex" />
<title>${title}</title>
<link rel="stylesheet" href="/res/css/reset.css" />
<link rel="stylesheet" href="/res/css/sakura.css" />
<link rel="stylesheet" href="/res/css/sakura-dark.css"
media="screen and (prefers-color-scheme: dark)" />
<style>
body{font-family:sans-serif;padding:1em;max-width:900px;margin:0;}
.self-service-links{display:flex;justify-content:center;margin-bottom:1em;
padding:0.5em;font-size:0.8em;}
.self-service-links a{margin:0 0.5em;}
</style>
</head>
<body>
${headerHtml}
<hr/>
${getSelfServiceLinks()}
<h2>Service Info</h2>
<pre>${JSON.stringify(info, null, 2)}</pre>
</body>
</html>`;
}
/**
* If the server operator provides a `greeting.md` file, it will be included in
* the rendered info page.
**/
/* ──────────────── header & helper functions ──────────────── */
/* (all copied verbatim from original file) */
function buildInfoPageHeader(info: ServiceInfo) {
const title = getServerTitle();
// TODO: use some templating engine instead of this mess
let infoBody = `# ${title}`;
if (config.promptLogging) {
infoBody += `\n## Prompt Logging Enabled
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
@@ -165,9 +231,9 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
for (const modelFamily of config.allowedModelFamilies) {
const service = MODEL_FAMILY_SERVICE[modelFamily];
const hasKeys = keyPool.list().some((k) => {
return k.service === service && k.modelFamilies.includes(modelFamily);
});
const hasKeys = keyPool.list().some(
(k) => k.service === service && k.modelFamilies.includes(modelFamily)
);
const wait = info[modelFamily]?.estimatedQueueTime;
if (hasKeys && wait) {
@@ -178,9 +244,7 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
}
infoBody += "\n\n" + waits.join(" / ");
infoBody += customGreeting;
infoBody += buildRecentImageSection();
return converter.makeHtml(infoBody);
@@ -188,63 +252,60 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
function getSelfServiceLinks() {
if (config.gatekeeper !== "user_token") return "";
const links = [["Check your user token", "/user/lookup"]];
if (config.captchaMode !== "none") {
links.unshift(["Request a user token", "/user/captcha"]);
}
return `<div class="self-service-links">${links
.map(([text, link]) => `<a href="${link}">${text}</a>`)
.map(([t, l]) => `<a href="${l}">${t}</a>`)
.join(" | ")}</div>`;
}
function getServerTitle() {
// Use manually set title if available
if (process.env.SERVER_TITLE) {
return process.env.SERVER_TITLE;
}
// Huggingface
if (process.env.SPACE_ID) {
if (process.env.SERVER_TITLE) return process.env.SERVER_TITLE;
if (process.env.SPACE_ID)
return `${process.env.SPACE_AUTHOR_NAME} / ${process.env.SPACE_TITLE}`;
}
// Render
if (process.env.RENDER) {
if (process.env.RENDER)
return `Render / ${process.env.RENDER_SERVICE_NAME}`;
}
return "OAI Reverse Proxy";
return "Tunnel";
}
function buildRecentImageSection() {
const imageModels: ModelFamily[] = ["azure-dall-e", "dall-e", "gpt-image", "azure-gpt-image"];
const imageModels: ModelFamily[] = [
"azure-dall-e",
"dall-e",
"gpt-image",
"azure-gpt-image",
];
// Condition 1: Is the feature enabled via config?
// Condition 2: Is at least one relevant image model family allowed in config?
if (
!config.showRecentImages ||
imageModels.every((f) => !config.allowedModelFamilies.includes(f))
) {
return ""; // Exit if feature is disabled or no relevant models are allowed
}
// Condition 3: Are there any actual images to display?
const recentImages = getLastNImages(12).reverse();
if (recentImages.length === 0) {
// If the feature is enabled and models are allowed, but no images exist,
// do not render the section, including its title.
return "";
}
// If all conditions pass (feature enabled, models allowed, images exist), build and return the HTML
let html = `<h2>Recent Image Generations</h2>`;
const recentImages = getLastNImages(12).reverse();
if (recentImages.length === 0) {
html += `<p>No images yet.</p>`;
return html;
}
html += `<div style="display: flex; flex-wrap: wrap;" id="recent-images">`;
html += `<div style="display:flex;flex-wrap:wrap;" id="recent-images">`;
for (const { url, prompt } of recentImages) {
const thumbUrl = url.replace(/\.png$/, "_t.jpg");
const escapedPrompt = escapeHtml(prompt);
html += `<div style="margin: 0.5em;" class="recent-image">
<a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}" alt="${escapedPrompt}" style="max-width: 150px; max-height: 150px;" /></a>
</div>`;
html += `<div style="margin:0.5em" class="recent-image">
<a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}"
alt="${escapedPrompt}" style="max-width:150px;max-height:150px;"/></a></div>`;
}
html += `</div>`;
html += `<p style="clear: both; text-align: center;"><a href="/user/image-history">View all recent images</a></p>`;
html += `</div><p style="clear:both;text-align:center;">
<a href="/user/image-history">View all recent images</a></p>`;
return html;
}
@@ -259,57 +320,49 @@ function escapeHtml(unsafe: string) {
.replace(/]/g, "&#93;");
}
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
try {
const [username, spacename] = spaceId.split("/");
return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`;
} catch (e) {
const [u, s] = spaceId.split("/");
return `https://${u}-${s.replace(/_/g, "-")}.hf.space`;
} catch {
return "";
}
}
function checkIfUnlocked(
req: Request,
res: Response,
next: express.NextFunction
) {
if (config.serviceInfoPassword?.length && !req.session?.unlocked) {
return res.redirect("/unlock-info");
}
next();
}
/* ──────────────── ROUTER ─────────────────────────────────── */
const infoPageRouter = Router();
if (config.serviceInfoPassword?.length) {
infoPageRouter.use(
express.json({ limit: "1mb" }),
express.urlencoded({ extended: true, limit: "1mb" })
);
infoPageRouter.use(withSession);
infoPageRouter.use(injectCsrfToken, checkCsrfToken);
infoPageRouter.post("/unlock-info", (req, res) => {
if (req.body.password !== config.serviceInfoPassword) {
return res.status(403).send("Incorrect password");
}
req.session!.unlocked = true;
res.redirect("/");
});
infoPageRouter.get("/unlock-info", (_req, res) => {
if (_req.session?.unlocked) return res.redirect("/");
res.send(`
<form method="post" action="/unlock-info">
<h1>Unlock Service Info</h1>
<input type="hidden" name="_csrf" value="${res.locals.csrfToken}" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Unlock</button>
</form>
`);
});
infoPageRouter.use(checkIfUnlocked);
}
infoPageRouter.get("/", handleInfoPage);
infoPageRouter.get("/status", (req, res) => {
res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
infoPageRouter.use(
express.json({ limit: "1mb" }),
express.urlencoded({ extended: true, limit: "1mb" }),
withSession,
injectCsrfToken,
checkCsrfToken
);
/* login attempt */
infoPageRouter.post(LOGIN_ROUTE, (req, res) => {
const token = (req.body.token || "").trim();
const user = getUser(token); // returns undefined if invalid
if (!user) {
return res
.status(401)
.send(renderLoginPage(res.locals.csrfToken, "Invalid token. Please try again."));
}
req.session!.infoPageAuthed = true;
res.redirect("/");
});
export { infoPageRouter };
/* GET / either login form or info page */
if (config.enableInfoPageLogin) {
infoPageRouter.get(LOGIN_ROUTE, requireLogin, handleInfoPage);
} else {
infoPageRouter.get(LOGIN_ROUTE, handleInfoPage);
}
/* ─── Removed the public /status route : simply not added ─── */
export { infoPageRouter };
+4 -2
View File
@@ -855,10 +855,12 @@ const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
},
`Incrementing usage for model`
);
keyPool.incrementUsage(req.key!, model, tokensUsed);
// Get modelFamily for the key usage log
const modelFamilyForKeyPool = req.modelFamily!; // Should be set by getModelFamilyForRequest earlier
keyPool.incrementUsage(req.key!, modelFamilyForKeyPool, { input: req.promptTokens!, output: req.outputTokens! });
if (req.user) {
incrementPromptCount(req.user.token);
incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed);
incrementTokenCount(req.user.token, model, req.outboundApi, { input: req.promptTokens!, output: req.outputTokens! });
}
}
};
+68 -13
View File
@@ -74,14 +74,18 @@ type ModelAggregates = {
gcpSonnet35?: number;
gcpHaiku?: number;
queued: number;
tokens: number;
inputTokens: number; // Changed from tokens
outputTokens: number; // Added
legacyTokens?: number; // Added for migrated totals
};
/** All possible combinations of model family and aggregate type. */
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
type AllStats = {
proompts: number;
tokens: number;
inputTokens: number; // Changed from tokens
outputTokens: number; // Added
legacyTokens?: number; // Added
tokenCost: number;
} & { [modelFamily in ModelFamily]?: ModelAggregates } & {
[service in LLMService as `${service}__${ServiceAggregate}`]?: number;
@@ -288,11 +292,14 @@ function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;
function getTrafficStats(): TrafficStats {
const tokens = serviceStats.get("tokens") || 0;
const inputTokens = serviceStats.get("inputTokens") || 0;
const outputTokens = serviceStats.get("outputTokens") || 0;
// const legacyTokens = serviceStats.get("legacyTokens") || 0; // Optional: include in total if desired
const totalTokens = inputTokens + outputTokens; // + legacyTokens;
const tokenCost = serviceStats.get("tokenCost") || 0;
return {
proompts: serviceStats.get("proompts") || 0,
tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`,
tookens: `${prettyTokens(totalTokens)}${getCostSuffix(tokenCost)}`, // Simplified to show aggregate and cost
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
}
@@ -352,14 +359,39 @@ function addKeyToAggregates(k: KeyPoolKey) {
addToService("cohere__keys", k.service === "cohere" ? 1 : 0);
addToService("qwen__keys", k.service === "qwen" ? 1 : 0);
let sumTokens = 0;
let sumInputTokens = 0;
let sumOutputTokens = 0;
let sumLegacyTokens = 0; // Optional
let sumCost = 0;
const incrementGenericFamilyStats = (f: ModelFamily) => {
const tokens = (k as any)[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
addToFamily(`${f}__tokens`, tokens);
const usage = k.tokenUsage?.[f];
let familyInputTokens = 0;
let familyOutputTokens = 0;
let familyLegacyTokens = 0;
if (usage) {
familyInputTokens = usage.input || 0;
familyOutputTokens = usage.output || 0;
if (usage.legacy_total && familyInputTokens === 0 && familyOutputTokens === 0) {
// This is a migrated key with no new usage, use legacy_total as input for cost
familyLegacyTokens = usage.legacy_total;
sumCost += getTokenCostUsd(f, usage.legacy_total, 0);
} else {
sumCost += getTokenCostUsd(f, familyInputTokens, familyOutputTokens);
}
}
// If no k.tokenUsage[f], tokens are 0, cost is 0.
sumInputTokens += familyInputTokens;
sumOutputTokens += familyOutputTokens;
sumLegacyTokens += familyLegacyTokens; // Optional
addToFamily(`${f}__inputTokens`, familyInputTokens);
addToFamily(`${f}__outputTokens`, familyOutputTokens);
if (familyLegacyTokens > 0) {
addToFamily(`${f}__legacyTokens`, familyLegacyTokens); // Optional
}
addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
};
@@ -493,15 +525,38 @@ function addKeyToAggregates(k: KeyPoolKey) {
assertNever(k.service);
}
addToService("tokens", sumTokens);
addToService("inputTokens", sumInputTokens);
addToService("outputTokens", sumOutputTokens);
if (sumLegacyTokens > 0) { // Optional
addToService("legacyTokens", sumLegacyTokens);
}
addToService("tokenCost", sumCost);
}
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = familyStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
const inputTokens = familyStats.get(`${family}__inputTokens`) || 0;
const outputTokens = familyStats.get(`${family}__outputTokens`) || 0;
const legacyTokens = familyStats.get(`${family}__legacyTokens`) || 0; // Optional
let cost = 0;
let displayTokens = 0;
let usageString = "";
if (inputTokens > 0 || outputTokens > 0) {
cost = getTokenCostUsd(family, inputTokens, outputTokens);
displayTokens = inputTokens + outputTokens;
usageString = `${prettyTokens(displayTokens)} (In: ${prettyTokens(inputTokens)}, Out: ${prettyTokens(outputTokens)})${getCostSuffix(cost)}`;
} else if (legacyTokens > 0) {
// Only show legacy if no new input/output has been recorded for this family aggregate
cost = getTokenCostUsd(family, legacyTokens, 0); // Cost legacy as all input
displayTokens = legacyTokens;
usageString = `${prettyTokens(displayTokens)} tokens (legacy total)${getCostSuffix(cost)}`;
} else {
usageString = `${prettyTokens(0)} tokens${getCostSuffix(0)}`;
}
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
usage: usageString,
activeKeys: familyStats.get(`${family}__active`) || 0,
revokedKeys: familyStats.get(`${family}__revoked`) || 0,
};
+3 -2
View File
@@ -1,6 +1,6 @@
import { RequestHandler } from "express";
import { config } from "../config";
import { getTokenCostUsd, prettyTokens } from "./stats";
import { getTokenCostUsd, getTokenCostDetailsUsd, prettyTokens } from "./stats"; // Added getTokenCostDetailsUsd
import { redactIp } from "./utils";
import * as userStore from "./users/user-store";
@@ -30,7 +30,8 @@ export const injectLocals: RequestHandler = (req, res, next) => {
// view helpers
res.locals.prettyTokens = prettyTokens;
res.locals.tokenCost = getTokenCostUsd;
res.locals.tokenCost = getTokenCostUsd; // Returns total cost as a number
res.locals.tokenCostDetails = getTokenCostDetailsUsd; // Returns { inputCost, outputCost, totalCost }
res.locals.redactIp = redactIp;
next();
+18 -10
View File
@@ -16,11 +16,8 @@ export type AnthropicKeyUpdate = Omit<
| "rateLimitedUntil"
>;
type AnthropicKeyUsage = {
[K in AnthropicModelFamily as `${K}Tokens`]: number;
};
export interface AnthropicKey extends Key, AnthropicKeyUsage {
// AnthropicKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface AnthropicKey extends Key {
readonly service: "anthropic";
readonly modelFamilies: AnthropicModelFamily[];
/**
@@ -120,8 +117,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
claudeTokens: 0,
"claude-opusTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
tier: "unknown",
};
this.keys.push(newKey);
@@ -206,11 +202,23 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: AnthropicModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Ensure the specific family object exists
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -14
View File
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";
type AwsBedrockKeyUsage = {
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
};
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
// AwsBedrockKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface AwsBedrockKey extends Key {
readonly service: "aws";
readonly modelFamilies: AwsBedrockModelFamily[];
/**
@@ -74,12 +71,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
inferenceProfileIds: [],
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
["aws-mistral-smallTokens"]: 0,
["aws-mistral-mediumTokens"]: 0,
["aws-mistral-largeTokens"]: 0,
tokenUsage: {}, // Initialize new tokenUsage field
};
this.keys.push(newKey);
}
@@ -173,11 +165,22 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: AwsBedrockModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -26
View File
@@ -10,11 +10,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AzureOpenAIKeyChecker } from "./checker";
type AzureOpenAIKeyUsage = {
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
};
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
// AzureOpenAIKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface AzureOpenAIKey extends Key {
readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[];
contentFiltering: boolean;
@@ -68,24 +65,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"azure-turboTokens": 0,
"azure-gpt4Tokens": 0,
"azure-gpt4-32kTokens": 0,
"azure-gpt4-turboTokens": 0,
"azure-gpt4oTokens": 0,
"azure-gpt45Tokens": 0,
"azure-gpt41Tokens": 0,
"azure-gpt41-miniTokens": 0,
"azure-gpt41-nanoTokens": 0,
"azure-o1Tokens": 0,
"azure-o1-miniTokens": 0,
"azure-o1-proTokens": 0,
"azure-o3-miniTokens": 0,
"azure-o3Tokens": 0,
"azure-o4-miniTokens": 0,
"azure-codex-miniTokens": 0,
"azure-dall-eTokens": 0,
"azure-gpt-imageTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
modelIds: [],
};
this.keys.push(newKey);
@@ -140,11 +120,22 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: AzureOpenAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+20 -12
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { CohereKeyChecker } from "./checker";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { CohereModelFamily } from "../../models";
import { CohereModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type CohereKeyUsage = {
"cohereTokens": number;
};
export interface CohereKey extends Key, CohereKeyUsage {
// CohereKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface CohereKey extends Key {
readonly service: "cohere";
readonly modelFamilies: CohereModelFamily[];
isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class CohereKeyProvider implements KeyProvider<CohereKey> {
hash: this.hashKey(key),
rateLimitedAt: 0,
rateLimitedUntil: 0,
"cohereTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false,
});
}
@@ -99,13 +96,24 @@ export class CohereKeyProvider implements KeyProvider<CohereKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: CohereModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`cohereTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Cohere only has one model family "cohere"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
+21 -13
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { DeepseekKeyChecker } from "./checker";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { DeepseekModelFamily } from "../../models";
import { DeepseekModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type DeepseekKeyUsage = {
"deepseekTokens": number;
};
export interface DeepseekKey extends Key, DeepseekKeyUsage {
// DeepseekKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface DeepseekKey extends Key {
readonly service: "deepseek";
readonly modelFamilies: DeepseekModelFamily[];
isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
hash: this.hashKey(key),
rateLimitedAt: 0,
rateLimitedUntil: 0,
"deepseekTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false,
});
}
@@ -99,13 +96,24 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: DeepseekModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`deepseekTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Deepseek only has one model family "deepseek"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
@@ -156,4 +164,4 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
}
+17 -10
View File
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";
type GcpKeyUsage = {
[K in GcpModelFamily as `${K}Tokens`]: number;
};
export interface GcpKey extends Key, GcpKeyUsage {
// GcpKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface GcpKey extends Key {
readonly service: "gcp";
readonly modelFamilies: GcpModelFamily[];
sonnetEnabled: boolean;
@@ -75,8 +72,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
sonnet35Enabled: false,
accessToken: "",
accessTokenExpiresAt: 0,
["gcp-claudeTokens"]: 0,
["gcp-claude-opusTokens"]: 0,
tokenUsage: {}, // Initialize new tokenUsage field
};
this.keys.push(newKey);
}
@@ -160,11 +156,22 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: GcpModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -11
View File
@@ -22,11 +22,8 @@ export type GoogleAIKeyUpdate = Omit<
| "rateLimitedUntil"
>;
type GoogleAIKeyUsage = {
[K in GoogleAIModelFamily as `${K}Tokens`]: number;
};
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
// GoogleAIKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface GoogleAIKey extends Key {
readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[];
/** All detected model IDs on this key. */
@@ -84,9 +81,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"gemini-flashTokens": 0,
"gemini-proTokens": 0,
"gemini-ultraTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
modelIds: [],
overQuotaFamilies: [],
};
@@ -139,11 +134,22 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: GoogleAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+9 -1
View File
@@ -36,6 +36,14 @@ export interface Key {
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/** Detailed token usage, separated by input and output, per model family. */
tokenUsage?: {
[family in ModelFamily]?: {
input: number;
output: number;
legacy_total?: number; // To store migrated single-number totals
};
};
}
/*
@@ -58,7 +66,7 @@ export interface KeyProvider<T extends Key = Key> {
disable(key: T): void;
update(hash: string, update: Partial<T>): void;
available(): number;
incrementUsage(hash: string, model: string, tokens: number): void;
incrementUsage(hash: string, modelFamily: ModelFamily, usage: { input: number; output: number }): void;
getLockoutPeriod(model: ModelFamily): number;
markRateLimited(hash: string): void;
recheck(): void;
+24 -3
View File
@@ -108,9 +108,30 @@ export class KeyPool {
}, 0);
}
public incrementUsage(key: Key, model: string, tokens: number): void {
public incrementUsage(key: Key, modelName: string, usage: { input: number; output: number }): void {
const provider = this.getKeyProvider(key.service);
provider.incrementUsage(key.hash, model, tokens);
// Assuming the provider's incrementUsage expects a modelFamily.
// We need a robust way to get modelFamily from modelName here.
// This might involve calling a method similar to getModelFamilyForRequest from user-store,
// or enhancing getServiceForModel to also return family, or passing family directly.
// For now, let's assume the provider can handle the modelName or we derive family.
// This part is tricky as KeyPool's getServiceForModel is for service, not family directly from a generic model string.
// Let's assume for now the provider's incrementUsage can take modelName and derive family,
// or the KeyProvider interface's incrementUsage should take modelName.
// The KeyProvider interface was changed to modelFamily. So we MUST derive it.
// This requires a utility function similar to what's in user-store or models.ts.
// For now, I'll placeholder this derivation. This is a critical point.
// Placeholder: const modelFamily = this.getModelFamilyForModel(modelName, key.service);
// This is complex because getModelFamilyForModel needs the service context.
// Let's assume the `modelName` passed here is actually `modelFamily` for now,
// or that the caller will resolve it.
// The KeyProvider interface expects `modelFamily`. The caller in middleware/response/index.ts
// has `model` (name) and `req.outboundApi`. It should resolve to family there.
// So, `modelName` here should actually be `modelFamily`.
// I will assume the caller of KeyPool.incrementUsage will pass modelFamily.
// So, changing `model: string` to `modelFamily: ModelFamily` in signature.
// This change needs to be propagated to the caller.
provider.incrementUsage(key.hash, modelName as ModelFamily, usage); // Casting modelName, assuming caller provides family
}
public getLockoutPeriod(family: ModelFamily): number {
@@ -247,4 +268,4 @@ export class KeyPool {
);
this.recheckJobs["google-ai"] = googleJob;
}
}
}
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { MistralAIKeyChecker } from "./checker";
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
};
export interface MistralAIKey extends Key, MistralAIKeyUsage {
// MistralAIKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface MistralAIKey extends Key {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
}
@@ -67,10 +64,7 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"mistral-tinyTokens": 0,
"mistral-smallTokens": 0,
"mistral-mediumTokens": 0,
"mistral-largeTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
};
this.keys.push(newKey);
}
@@ -117,12 +111,22 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: MistralAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
const family = getMistralAIModelFamily(model);
key[`${family}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -26
View File
@@ -3,16 +3,13 @@ import http from "http";
import { Key, KeyProvider } from "../index";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { getOpenAIModelFamily, OpenAIModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
import { PaymentRequiredError } from "../../errors";
import { OpenAIKeyChecker } from "./checker";
import { prioritizeKeys } from "../prioritize-keys";
type OpenAIKeyUsage = {
[K in OpenAIModelFamily as `${K}Tokens`]: number;
};
export interface OpenAIKey extends Key, OpenAIKeyUsage {
// OpenAIKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface OpenAIKey extends Key {
readonly service: "openai";
modelFamilies: OpenAIModelFamily[];
/**
@@ -108,24 +105,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
rateLimitedUntil: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
"gpt4-turboTokens": 0,
gpt4oTokens: 0,
gpt45Tokens: 0,
gpt41Tokens: 0,
"gpt41-miniTokens": 0,
"gpt41-nanoTokens": 0,
"o1Tokens": 0,
"o1-miniTokens": 0,
"o1-proTokens": 0,
"o3-miniTokens": 0,
"o3Tokens": 0,
"o4-miniTokens": 0,
"codex-miniTokens": 0,
"dall-eTokens": 0,
"gpt-imageTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
modelIds: [],
};
this.keys.push(newKey);
@@ -337,11 +317,22 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
key.rateLimitedUntil = now + key.rateLimitRequestsReset;
}
public incrementUsage(keyHash: string, model: string, tokens: number) {
public incrementUsage(keyHash: string, modelFamily: OpenAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
+1 -1
View File
@@ -6,7 +6,7 @@ export interface QwenKey extends Key {
readonly service: "qwen";
readonly modelFamilies: QwenModelFamily[];
isOverQuota: boolean;
"qwenTokens": number;
// "qwenTokens" is removed, tokenUsage from base Key interface will be used.
}
import { logger } from "../../../logger";
import { assertNever } from "../../utils";
+17 -4
View File
@@ -2,6 +2,7 @@ import { KeyProvider, createGenericGetLockoutPeriod } from "..";
import { QwenKeyChecker, QwenKey } from "./checker";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { QwenModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
// Re-export the QwenKey interface
export type { QwenKey } from "./checker";
@@ -36,7 +37,7 @@ export class QwenKeyProvider implements KeyProvider<QwenKey> {
hash: this.hashKey(key),
rateLimitedAt: 0,
rateLimitedUntil: 0,
"qwenTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false,
});
}
@@ -93,11 +94,23 @@ export class QwenKeyProvider implements KeyProvider<QwenKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: QwenModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`qwenTokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Qwen only has one model family "qwen"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/**
+21 -13
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { XaiKeyChecker } from "./checker";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { XaiModelFamily } from "../../models";
import { XaiModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type XaiKeyUsage = {
"xaiTokens": number;
};
export interface XaiKey extends Key, XaiKeyUsage {
// XaiKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface XaiKey extends Key {
readonly service: "xai";
readonly modelFamilies: XaiModelFamily[];
isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
hash: this.hashKey(key),
rateLimitedAt: 0,
rateLimitedUntil: 0,
"xaiTokens": 0,
tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false,
});
}
@@ -99,13 +96,24 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
public incrementUsage(keyHash: string, modelFamily: XaiModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
key[`xaiTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Xai only has one model family "xai"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
@@ -156,4 +164,4 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
}
+62
View File
@@ -0,0 +1,62 @@
import Database from 'better-sqlite3';
import { config } from '../config';
import { logger } from '../logger';
const log = logger.child({ module: 'sqlite-db' });
let db: Database.Database;
export function initSQLiteDB(): Database.Database {
if (db) {
return db;
}
const dbPath = config.sqliteUserStorePath;
if (!dbPath) {
log.error('SQLite user store DB path (SQLITE_USER_STORE_PATH) is not configured.');
throw new Error('SQLite user store DB path is not configured.');
}
log.info({ path: dbPath }, 'Initializing SQLite database for user store...');
db = new Database(dbPath, { verbose: config.logLevel === 'trace' ? console.log : undefined });
// Enable WAL mode for better concurrency and performance.
db.pragma('journal_mode = WAL');
// Create users table
// Note: JSON fields (ip, tokenCounts, etc.) are stored as TEXT.
// Timestamps are stored as INTEGER (Unix epoch milliseconds).
db.exec(`
CREATE TABLE IF NOT EXISTS users (
token TEXT PRIMARY KEY,
ip TEXT, /* JSON string array */
nickname TEXT,
type TEXT NOT NULL CHECK(type IN ('normal', 'special', 'temporary')),
promptCount INTEGER NOT NULL DEFAULT 0,
tokenCounts TEXT, /* JSON string object */
tokenLimits TEXT, /* JSON string object */
tokenRefresh TEXT, /* JSON string object */
createdAt INTEGER NOT NULL,
lastUsedAt INTEGER,
disabledAt INTEGER,
disabledReason TEXT,
expiresAt INTEGER,
maxIps INTEGER,
adminNote TEXT,
meta TEXT /* JSON string object */
);
`);
log.info('SQLite database initialized and `users` table created/verified.');
return db;
}
export function getDB(): Database.Database {
if (!db) {
// This might happen if getDB is called before initSQLiteDB,
// though user-store should ensure init is called first.
log.warn('SQLite DB instance requested before initialization. Attempting to initialize now.');
return initSQLiteDB();
}
return db;
}
+82 -140
View File
@@ -1,146 +1,88 @@
import { config } from "../config";
import { ModelFamily } from "./models";
// Using weighted averages now for better guessing, thinking models use around 1:3 ratio for input:output
// for the thinking part, other models hover around 3:1 input output, still not the best, but reflects better to real proompting.
export function getTokenCostUsd(model: ModelFamily, tokens: number) {
let cost = 0;
switch (model) {
case "deepseek":
cost = 0.00000178;
// uncached r1 pricing, again the highest average
break;
case "xai":
cost = 0.000014;
// just using the highest input/output price aka grok-3 (because who cares about grok)
break;
case "gpt41":
case "azure-gpt41":
cost = 0.0000075;
// averaged the same wa* as 4.5
break;
case "gpt41-mini":
case "azure-gpt41-mini":
cost = 0.0000015;
break;
case "gpt41-nano":
case "azure-gpt41-nano":
cost = 0.0000003;
break;
case "gpt45":
case "azure-gpt45":
// $75/$150 for 1M input/output tokens pricing, averaged to $112
cost = 0.00009375;
break;
case "gpt4o":
case "azure-gpt4o":
cost = 0.0000075;
break;
case "azure-gpt4-turbo":
case "gpt4-turbo":
cost = 0.0000125;
break;
case "azure-o1-pro":
case "o1-pro":
// OpenAI o1-pro pricing $150/1M input tokens and $600/1M output tokens
cost = 0.0004875;
break;
case "azure-o1":
case "o1":
// Currently we do not track output tokens separately, and O1 uses
// considerably more output tokens that other models for its hidden
// reasoning. The official O1 pricing is $15/1M input tokens and $60/1M
// output tokens so we will return a higher estimate here.
cost = 0.00004875;
break;
case "azure-o1-mini":
case "o1-mini":
case "azure-o3-mini":
case "o3-mini":
cost = 0.000003575; // $1.1/1M input tokens, $4.4/1M output tokens
break;
case "azure-o3":
case "o3":
cost = 0.000032; // $10/1M input tokens, $40/1M output tokens
break;
case "azure-o4-mini":
case "o4-mini":
cost = 0.000003575; // $1.1/1M input tokens, $4.4/1M output tokens
break;
case "azure-codex-mini":
case "codex-mini":
// Codex Mini pricing: $1.5/1M input tokens, $6.0/1M output tokens
// Using weighted average for 1:3 input:output ratio
cost = 0.0000045; // Weighted average with output bias
break;
case "azure-gpt4-32k":
case "gpt4-32k":
cost = 0.000075;
break;
case "azure-gpt4":
case "gpt4":
cost = 0.0000375;
break;
case "azure-turbo":
case "turbo":
cost = 0.00000075;
break;
case "azure-dall-e":
case "dall-e":
cost = 0.00001;
break;
case "azure-gpt-image":
case "gpt-image":
// gpt-image-1 pricing:
// Text input tokens: $5 per 1M tokens
// Image input tokens: $10 per 1M tokens
// Image output tokens: $40 per 1M tokens
// Weighted average assuming a mix of text/image input and output
// Typical cost is $0.02-$0.19 per image depending on quality
cost = 0.000018; // Balanced estimate accounting for input/output mix
break;
case "aws-claude":
case "gcp-claude":
case "claude":
cost = 0.00001;
break;
case "aws-claude-opus":
case "gcp-claude-opus":
case "claude-opus":
cost = 0.00003;
break;
case "aws-mistral-tiny":
case "mistral-tiny":
// Using Ministral 3B pricing: $0.04/1M input tokens, $0.04/1M output tokens
// For edge/tiny models, a more balanced 1:1 ratio is used
cost = 0.00000004;
break;
case "aws-mistral-small":
case "mistral-small":
// Using Codestral pricing: $0.3/1M input, $0.9/1M output (highest in category)
// Weighted average for 1:3 input:output ratio
cost = 0.00000075;
break;
case "aws-mistral-medium":
case "mistral-medium":
// Using Mistral Saba pricing: $0.2/1M input, $0.6/1M output
// Weighted average for 1:3 input:output ratio
cost = 0.0000005;
break;
case "aws-mistral-large":
case "mistral-large":
// Using Mistral Large/Pixtral Large pricing: $2/1M input, $6/1M output
// Weighted average for 1:3 input:output ratio
cost = 0.000005;
break;
case "gemini-flash":
cost = 0.0000002326;
break;
case "gemini-pro":
cost = 0.00000344;
break;
// Prices are per 1 million tokens.
const MODEL_PRICING: Record<ModelFamily, { input: number; output: number } | undefined> = {
"deepseek": { input: 0.14, output: 0.28 }, // DeepSeek-V2: $0.14/$0.28 per 1M tokens
"xai": { input: 5.6, output: 16.8 }, // Grok: Derived from avg $14/1M (assuming 1:3 in/out ratio) - needs official pricing
"gpt41": { input: 2.00, output: 8.00 },
"azure-gpt41": { input: 2.00, output: 8.00 },
"gpt41-mini": { input: 0.40, output: 1.60 },
"azure-gpt41-mini": { input: 0.40, output: 1.60 },
"gpt41-nano": { input: 0.10, output: 0.40 },
"azure-gpt41-nano": { input: 0.10, output: 0.40 },
"gpt45": { input: 75.00, output: 150.00 }, // Example, needs verification if this model family is still current with this pricing
"azure-gpt45": { input: 75.00, output: 150.00 }, // Example, needs verification
"gpt4o": { input: 5.00, output: 20.00 },
"azure-gpt4o": { input: 5.00, output: 20.00 },
"gpt4-turbo": { input: 10.00, output: 30.00 },
"azure-gpt4-turbo": { input: 10.00, output: 30.00 },
"o1-pro": { input: 150.00, output: 600.00 },
"azure-o1-pro": { input: 150.00, output: 600.00 },
"o1": { input: 15.00, output: 60.00 },
"azure-o1": { input: 15.00, output: 60.00 },
"o1-mini": { input: 1.10, output: 4.40 },
"azure-o1-mini": { input: 1.10, output: 4.40 },
"o3-mini": { input: 1.10, output: 4.40 },
"azure-o3-mini": { input: 1.10, output: 4.40 },
"o3": { input: 10.00, output: 40.00 },
"azure-o3": { input: 10.00, output: 40.00 },
"o4-mini": { input: 1.10, output: 4.40 },
"azure-o4-mini": { input: 1.10, output: 4.40 },
"codex-mini": { input: 1.50, output: 6.00 },
"azure-codex-mini": { input: 1.50, output: 6.00 },
"gpt4-32k": { input: 60.00, output: 120.00 },
"azure-gpt4-32k": { input: 60.00, output: 120.00 },
"gpt4": { input: 30.00, output: 60.00 },
"azure-gpt4": { input: 30.00, output: 60.00 },
"turbo": { input: 0.60, output: 2.40 }, // Maps to GPT-4o mini
"azure-turbo": { input: 0.60, output: 2.40 },
"dall-e": { input: 0, output: 0 }, // Pricing is per image, not token based in this context.
"azure-dall-e": { input: 0, output: 0 }, // Pricing is per image.
"gpt-image": { input: 0, output: 0 }, // Complex pricing (text, image input, image output tokens), handle separately.
"azure-gpt-image": { input: 0, output: 0 }, // Complex pricing.
"claude": { input: 3.00, output: 15.00 }, // Anthropic Claude Sonnet 4
"aws-claude": { input: 3.00, output: 15.00 },
"gcp-claude": { input: 3.00, output: 15.00 },
"claude-opus": { input: 15.00, output: 75.00 }, // Anthropic Claude Opus 4
"aws-claude-opus": { input: 15.00, output: 75.00 },
"gcp-claude-opus": { input: 15.00, output: 75.00 },
"mistral-tiny": { input: 0.04, output: 0.04 }, // Using old price if no new API price found
"aws-mistral-tiny": { input: 0.04, output: 0.04 },
"mistral-small": { input: 0.10, output: 0.30 }, // Mistral Small 3.1
"aws-mistral-small": { input: 0.10, output: 0.30 },
"mistral-medium": { input: 0.40, output: 2.00 }, // Mistral Medium 3
"aws-mistral-medium": { input: 0.40, output: 2.00 },
"mistral-large": { input: 2.00, output: 6.00 },
"aws-mistral-large": { input: 2.00, output: 6.00 },
"gemini-flash": { input: 0.35, output: 1.05 }, // Gemini 1.5 Flash
"gemini-pro": { input: 0.125, output: 0.375 }, // Gemini 1.0 Pro
"gemini-ultra": { input: 25.00, output: 75.00 }, // Estimated based on Gemini Pro (5-10x) and character to token conversion. Official per-token pricing needed.
// Ensure all ModelFamily entries from models.ts are covered or have a default.
// Adding placeholders for families in models.ts but not yet priced here.
"cohere": { input: 0.25, output: 0.50 }, // Cohere Command R, as an example
"qwen": { input: 1.40, output: 2.80 }, // Qwen-plus, as an example
};
export function getTokenCostDetailsUsd(model: ModelFamily, inputTokens: number, outputTokens?: number): { inputCost: number, outputCost: number, totalCost: number } {
const pricing = MODEL_PRICING[model];
if (!pricing) {
console.warn(`Pricing not found for model family: ${model}. Returning 0 cost for all components.`);
return { inputCost: 0, outputCost: 0, totalCost: 0 };
}
return cost * Math.max(0, tokens);
const costPerMillionInputTokens = pricing.input;
const costPerMillionOutputTokens = pricing.output;
const inputCost = (costPerMillionInputTokens / 1_000_000) * Math.max(0, inputTokens);
const outputCost = (costPerMillionOutputTokens / 1_000_000) * Math.max(0, outputTokens ?? 0);
return { inputCost, outputCost, totalCost: inputCost + outputCost };
}
export function getTokenCostUsd(model: ModelFamily, inputTokens: number, outputTokens?: number): number {
return getTokenCostDetailsUsd(model, inputTokens, outputTokens).totalCost;
}
export function prettyTokens(tokens: number): string {
@@ -159,4 +101,4 @@ export function prettyTokens(tokens: number): string {
export function getCostSuffix(cost: number) {
if (!config.showTokenCosts) return "";
return ` ($${cost.toFixed(2)})`;
}
}
+15 -5
View File
@@ -3,11 +3,21 @@ import { MODEL_FAMILIES, ModelFamily } from "../models";
import { makeOptionalPropsNullable } from "../utils";
// This just dynamically creates a Zod object type with a key for each model
// family and an optional number value.
// family and an optional number value for input and output tokens.
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }),
{} as Record<ModelFamily, ZodType<number>>
(acc, family) => ({
...acc,
[family]: z
.object({
input: z.number().optional().default(0),
output: z.number().optional().default(0),
legacy_total: z.number().optional(), // Added legacy_total
})
.optional()
.default({ input: 0, output: 0 }), // Default will not have legacy_total
}),
{} as Record<ModelFamily, ZodType<{ input: number; output: number; legacy_total?: number }>>
)
);
@@ -33,7 +43,7 @@ export const UserSchema = z
* Never used; retained for backwards compatibility.
*/
tokenCount: z.any().optional(),
/** Number of tokens the user has consumed, by model family. */
/** Number of input and output tokens the user has consumed, by model family. */
tokenCounts: tokenCountsSchema,
/** Maximum number of tokens the user can consume, by model family. */
tokenLimits: tokenCountsSchema,
@@ -67,7 +77,7 @@ export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
.extend({ token: z.string() });
export type UserTokenCounts = {
[K in ModelFamily]: number | undefined;
[K in ModelFamily]: { input: number; output: number; legacy_total?: number } | undefined;
};
export type User = z.infer<typeof UserSchema>;
export type UserUpdate = z.infer<typeof UserPartialSchema>;
+305 -42
View File
@@ -10,9 +10,11 @@
import admin from "firebase-admin";
import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import type { Database } from 'better-sqlite3';
import { config } from "../../config";
import { logger } from "../../logger";
import { getFirebaseApp } from "../firebase";
import { initSQLiteDB, getDB } from "../sqlite-db"; // Added
import { APIFormat } from "../key-management";
import {
getAwsBedrockModelFamily,
@@ -31,9 +33,45 @@ import { User, UserTokenCounts, UserUpdate } from "./schema";
const log = logger.child({ module: "users" });
const INITIAL_TOKENS: Required<UserTokenCounts> = MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: 0 }),
{} as Record<ModelFamily, number>
);
(acc, family) => {
acc[family] = { input: 0, output: 0 }; // legacy_total is undefined by default
return acc;
},
{} as Record<ModelFamily, { input: number; output: number; legacy_total?: number }>
) as Required<UserTokenCounts>;
const migrateTokenCountsProperty = (
parsedProperty: any, // Data from DB (JSON.parse result for a specific user's property like tokenCounts)
defaultConfigForProperty: Record<ModelFamily, number | { input: number; output: number; legacy_total?: number } | undefined> // e.g., INITIAL_TOKENS or config.tokenQuota
): UserTokenCounts => {
const result = {} as UserTokenCounts;
for (const family of MODEL_FAMILIES) {
const dbValue = parsedProperty?.[family];
const configValue = defaultConfigForProperty[family];
if (typeof dbValue === 'number') {
// Case 1: DB has old numeric format - migrate and add legacy_total
result[family] = { input: dbValue, output: 0, legacy_total: dbValue };
} else if (typeof dbValue === 'object' && dbValue !== null && (typeof dbValue.input === 'number' || typeof dbValue.output === 'number')) {
// Case 2: DB has new object format (might or might not have legacy_total from a previous migration)
result[family] = { input: dbValue.input ?? 0, output: dbValue.output ?? 0, legacy_total: dbValue.legacy_total };
} else {
// Case 3: DB value is missing or invalid, use default from config
if (typeof configValue === 'number') {
// Default from config is old numeric format (e.g., config.tokenQuota[family]) - migrate and add legacy_total
result[family] = { input: configValue, output: 0, legacy_total: configValue };
} else if (typeof configValue === 'object' && configValue !== null && (typeof configValue.input === 'number' || typeof configValue.output === 'number')) {
// Default from config is new object format (e.g., INITIAL_TOKENS[family])
result[family] = { input: configValue.input ?? 0, output: configValue.output ?? 0, legacy_total: configValue.legacy_total };
} else {
// Ultimate fallback: if configValue is also missing or invalid for this family
result[family] = { input: 0, output: 0 }; // No legacy_total here
}
}
}
return result;
};
const users: Map<string, User> = new Map();
const usersToFlush = new Set<string>();
@@ -44,6 +82,8 @@ export async function init() {
log.info({ store: config.gatekeeperStore }, "Initializing user store...");
if (config.gatekeeperStore === "firebase_rtdb") {
await initFirebase();
} else if (config.gatekeeperStore === "sqlite") {
await initSQLite(); // Added
}
if (config.quotaRefreshPeriod) {
const crontab = getRefreshCrontab();
@@ -80,9 +120,14 @@ export function createUser(createOptions?: {
ip: [],
type: "normal",
promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS },
tokenLimits: createOptions?.tokenLimits ?? { ...config.tokenQuota },
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS },
tokenCounts: { ...INITIAL_TOKENS }, // New counts don't have legacy_total
tokenLimits: createOptions?.tokenLimits ?? MODEL_FAMILIES.reduce((acc, family) => {
const quota = config.tokenQuota[family];
// If quota is a number, it's a legacy total limit, store it as such
acc[family] = typeof quota === 'number' ? { input: quota, output: 0, legacy_total: quota } : (quota || { input: 0, output: 0 });
return acc;
}, {} as UserTokenCounts),
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS }, // Refresh amounts typically start fresh
createdAt: Date.now(),
meta: {},
};
@@ -125,9 +170,14 @@ export function upsertUser(user: UserUpdate) {
ip: [],
type: "normal",
promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS },
tokenLimits: { ...config.tokenQuota },
tokenRefresh: { ...INITIAL_TOKENS },
tokenCounts: { ...INITIAL_TOKENS }, // New counts don't have legacy_total
tokenLimits: MODEL_FAMILIES.reduce((acc, family) => {
const quota = config.tokenQuota[family];
// If quota is a number, it's a legacy total limit, store it as such
acc[family] = typeof quota === 'number' ? { input: quota, output: 0, legacy_total: quota } : (quota || { input: 0, output: 0 });
return acc;
}, {} as UserTokenCounts),
tokenRefresh: { ...INITIAL_TOKENS }, // Refresh amounts typically start fresh
createdAt: Date.now(),
meta: {},
};
@@ -146,21 +196,37 @@ export function upsertUser(user: UserUpdate) {
if (updates.tokenCounts) {
for (const family of MODEL_FAMILIES) {
updates.tokenCounts[family] ??= 0;
updates.tokenCounts[family] ??= { input: 0, output: 0 };
// The property is now guaranteed to be an object, so the 'number' check is removed.
// Defaulting individual fields if they are missing.
const counts = updates.tokenCounts[family]!; // Should not be undefined here
counts.input ??= 0;
counts.output ??= 0;
// legacy_total is optional and not defaulted here if missing
}
}
if (updates.tokenLimits) {
for (const family of MODEL_FAMILIES) {
updates.tokenLimits[family] ??= 0;
updates.tokenLimits[family] ??= { input: 0, output: 0 };
// The property is now guaranteed to be an object, so the 'number' check is removed.
// Defaulting individual fields if they are missing.
const limits = updates.tokenLimits[family]!; // Should not be undefined here
limits.input ??= 0;
limits.output ??= 0;
// legacy_total is optional and not defaulted here if missing
}
}
// tokenRefresh is a special case where we want to merge the existing and
// updated values for each model family, ignoring falsy values.
if (updates.tokenRefresh) {
const merged = { ...existing.tokenRefresh };
const merged = { ...existing.tokenRefresh } as UserTokenCounts;
for (const family of MODEL_FAMILIES) {
merged[family] =
updates.tokenRefresh[family] || existing.tokenRefresh[family];
const updateRefresh = updates.tokenRefresh[family];
const existingRefresh = existing.tokenRefresh[family];
merged[family] = {
input: (updateRefresh?.input || existingRefresh?.input) ?? 0,
output: (updateRefresh?.output || existingRefresh?.output) ?? 0,
};
}
updates.tokenRefresh = merged;
}
@@ -168,9 +234,11 @@ export function upsertUser(user: UserUpdate) {
users.set(user.token, Object.assign(existing, updates));
usersToFlush.add(user.token);
// Immediately schedule a flush to the database if we're using Firebase.
// Immediately schedule a flush to the database if a persistent store is used.
if (config.gatekeeperStore === "firebase_rtdb") {
setImmediate(flushUsers);
} else if (config.gatekeeperStore === "sqlite") {
setImmediate(flushUsersToSQLite);
}
return users.get(user.token);
@@ -189,13 +257,16 @@ export function incrementTokenCount(
token: string,
model: string,
api: APIFormat,
consumption: number
consumption: { input: number; output: number }
) {
const user = users.get(token);
if (!user) return;
const modelFamily = getModelFamilyForQuotaUsage(model, api);
const existing = user.tokenCounts[modelFamily] ?? 0;
user.tokenCounts[modelFamily] = existing + consumption;
const existingCounts = user.tokenCounts[modelFamily] ?? { input: 0, output: 0 };
user.tokenCounts[modelFamily] = {
input: (existingCounts.input ?? 0) + consumption.input,
output: (existingCounts.output ?? 0) + consumption.output,
};
usersToFlush.add(token);
}
@@ -251,12 +322,36 @@ export function hasAvailableQuota({
const modelFamily = getModelFamilyForQuotaUsage(model, api);
const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily];
const limitConfig = tokenLimits[modelFamily];
const currentUsage = tokenCounts[modelFamily] ?? { input: 0, output: 0 };
if (!tokenLimit) return true;
// If no specific limit object for the family, or if it's essentially unlimited (e.g. input/output are 0 or not set)
// fall back to checking config.tokenQuota which is a number (total limit).
if (!limitConfig || (limitConfig.input === 0 && limitConfig.output === 0 && !config.tokenQuota[modelFamily])) {
return true; // No effective limit
}
const tokensConsumed = (tokenCounts[modelFamily] ?? 0) + requested;
return tokensConsumed < tokenLimit;
let effectiveLimit: number;
if (limitConfig && (limitConfig.input > 0 || limitConfig.output > 0)) {
// If a specific limit object exists and has positive values, sum them.
// This assumes the limit is a total limit. If input/output are separate, this logic needs change.
effectiveLimit = (limitConfig.input ?? Number.MAX_SAFE_INTEGER) + (limitConfig.output ?? Number.MAX_SAFE_INTEGER);
} else {
// Fallback to general numeric quota from config if specific limitObj is not effectively set.
const generalQuota = config.tokenQuota[modelFamily];
if (typeof generalQuota === 'number' && generalQuota > 0) {
effectiveLimit = generalQuota;
} else {
return true; // No limit defined
}
}
// Assuming 'requested' is for input tokens. If 'requested' can be input or output,
// this needs to be an object {input: number, output: number}.
// For now, we sum current input & output and add 'requested' to input for checking.
// This is a simplification. A more robust solution would involve 'requested' being an object.
const totalConsumed = (currentUsage.input ?? 0) + (currentUsage.output ?? 0) + requested;
return totalConsumed < effectiveLimit;
}
/**
@@ -270,18 +365,33 @@ export function refreshQuota(token: string) {
const { tokenQuota } = config;
const { tokenCounts, tokenLimits, tokenRefresh } = user;
// Get default quotas for each model family.
const defaultQuotas = Object.entries(tokenQuota) as [ModelFamily, number][];
// If any user-specific refresh quotas are present, override default quotas.
const userQuotas = defaultQuotas.map(
([f, q]) => [f, (tokenRefresh[f] ?? 0) || q] as const /* narrow to tuple */
);
for (const family of MODEL_FAMILIES) {
const currentUsage = tokenCounts[family] ?? { input: 0, output: 0 };
const userRefreshConfig = tokenRefresh[family] ?? { input: 0, output: 0 };
const globalDefaultQuotaValue = config.tokenQuota[family]; // This is a number or undefined
userQuotas
// Ignore families with no global or user-specific refresh quota.
.filter(([, q]) => q > 0)
// Increase family token limit by the family's refresh amount.
.forEach(([f, q]) => (tokenLimits[f] = (tokenCounts[f] ?? 0) + q));
let refreshInputAmount = 0;
let refreshOutputAmount = 0;
// Prioritize user-specific refresh amounts if they are positive
if (userRefreshConfig.input > 0 || userRefreshConfig.output > 0) {
refreshInputAmount = userRefreshConfig.input;
refreshOutputAmount = userRefreshConfig.output;
} else if (typeof globalDefaultQuotaValue === 'number' && globalDefaultQuotaValue > 0) {
// If no user-specific refresh, use the global quota.
// Distribute the global quota. For simplicity, add to input, or define a rule.
// Here, let's assume the global quota is a total that primarily refreshes 'input'.
refreshInputAmount = globalDefaultQuotaValue;
refreshOutputAmount = 0; // Or some portion of globalDefaultQuotaValue
}
if (refreshInputAmount > 0 || refreshOutputAmount > 0) {
tokenLimits[family] = {
input: (currentUsage.input ?? 0) + refreshInputAmount,
output: (currentUsage.output ?? 0) + refreshOutputAmount,
};
}
}
usersToFlush.add(token);
}
@@ -289,8 +399,9 @@ export function resetUsage(token: string) {
const user = users.get(token);
if (!user) return;
const { tokenCounts } = user;
const counts = Object.entries(tokenCounts) as [ModelFamily, number][];
counts.forEach(([model]) => (tokenCounts[model] = 0));
for (const family of MODEL_FAMILIES) {
tokenCounts[family] = { input: 0, output: 0 }; // legacy_total is implicitly undefined/removed
}
usersToFlush.add(token);
}
@@ -359,26 +470,56 @@ function refreshAllQuotas() {
// store to sync it with Firebase when it changes. Will refactor to abstract
// persistence layer later so we can support multiple stores.
let firebaseTimeout: NodeJS.Timeout | undefined;
let sqliteInterval: NodeJS.Timeout | undefined; // Added
let flushingToSQLiteInProgress = false; // Added for JS-level lock
const USERS_REF = process.env.FIREBASE_USERS_REF_NAME ?? "users";
async function initSQLite() { // Added
log.info("Initializing SQLite user store...");
initSQLiteDB(); // Initialize the DB connection and schema
await loadUsersFromSQLite();
// Set up periodic flush for SQLite, similar to Firebase
sqliteInterval = setInterval(flushUsersToSQLite, 20 * 1000);
log.info("SQLite user store initialized and users loaded.");
}
async function initFirebase() {
log.info("Connecting to Firebase...");
const app = getFirebaseApp();
const db = admin.database(app);
const usersRef = db.ref(USERS_REF);
const snapshot = await usersRef.once("value");
const users: Record<string, User> | null = snapshot.val();
const usersData: Record<string, any> | null = snapshot.val(); // Store as 'any' initially for migration
firebaseTimeout = setInterval(flushUsers, 20 * 1000);
if (!users) {
if (!usersData) {
log.info("No users found in Firebase.");
return;
}
for (const token in users) {
upsertUser(users[token]);
// migrateTokenCountsProperty is now defined at module scope
for (const token in usersData) {
const rawUser = usersData[token];
const migratedUser: User = {
...rawUser, // Spread existing fields
token: rawUser.token || token, // Ensure token is present
ip: rawUser.ip || [],
type: rawUser.type || "normal",
promptCount: rawUser.promptCount || 0,
createdAt: rawUser.createdAt || Date.now(),
// Migrate token fields
tokenCounts: migrateTokenCountsProperty(rawUser.tokenCounts, INITIAL_TOKENS),
tokenLimits: migrateTokenCountsProperty(rawUser.tokenLimits, config.tokenQuota),
tokenRefresh: migrateTokenCountsProperty(rawUser.tokenRefresh, INITIAL_TOKENS),
meta: rawUser.meta || {},
};
// Use the internal map directly to avoid re-triggering upsertUser's default creations
users.set(token, migratedUser);
}
usersToFlush.clear();
const numUsers = Object.keys(users).length;
log.info({ users: numUsers }, "Loaded users from Firebase");
usersToFlush.clear(); // Clear flush queue after initial load and migration
const numUsers = Object.keys(usersData).length;
log.info({ users: numUsers }, "Loaded and migrated users from Firebase");
}
async function flushUsers() {
@@ -412,6 +553,128 @@ async function flushUsers() {
);
}
async function loadUsersFromSQLite() { // Added
log.info("Loading users from SQLite...");
const db = getDB();
const rows = db.prepare("SELECT * FROM users").all() as any[];
for (const row of rows) {
const rawTokenCounts = row.tokenCounts ? JSON.parse(row.tokenCounts) : null;
const rawTokenLimits = row.tokenLimits ? JSON.parse(row.tokenLimits) : null;
const rawTokenRefresh = row.tokenRefresh ? JSON.parse(row.tokenRefresh) : null;
const user: User = {
token: row.token,
ip: row.ip ? JSON.parse(row.ip) : [],
nickname: row.nickname,
type: row.type,
promptCount: row.promptCount,
tokenCounts: migrateTokenCountsProperty(rawTokenCounts, INITIAL_TOKENS),
tokenLimits: migrateTokenCountsProperty(rawTokenLimits, config.tokenQuota),
tokenRefresh: migrateTokenCountsProperty(rawTokenRefresh, INITIAL_TOKENS),
createdAt: row.createdAt,
lastUsedAt: row.lastUsedAt,
disabledAt: row.disabledAt,
disabledReason: row.disabledReason,
expiresAt: row.expiresAt,
maxIps: row.maxIps,
adminNote: row.adminNote,
meta: row.meta ? JSON.parse(row.meta) : {},
};
users.set(user.token, user);
}
usersToFlush.clear(); // Clear flush queue after initial load
log.info({ users: users.size }, "Loaded users from SQLite.");
}
async function flushUsersToSQLite() { // Added
if (flushingToSQLiteInProgress) {
log.trace("Flush to SQLite already in progress, skipping.");
return;
}
if (usersToFlush.size === 0) {
return;
}
flushingToSQLiteInProgress = true;
log.trace({ count: usersToFlush.size }, "Starting flush to SQLite.");
const db = getDB();
const insertStmt = db.prepare(`
INSERT OR REPLACE INTO users (
token, ip, nickname, type, promptCount, tokenCounts, tokenLimits,
tokenRefresh, createdAt, lastUsedAt, disabledAt, disabledReason,
expiresAt, maxIps, adminNote, meta
) VALUES (
@token, @ip, @nickname, @type, @promptCount, @tokenCounts, @tokenLimits,
@tokenRefresh, @createdAt, @lastUsedAt, @disabledAt, @disabledReason,
@expiresAt, @maxIps, @adminNote, @meta
)
`);
const deleteStmt = db.prepare("DELETE FROM users WHERE token = ?");
let updatedCount = 0;
let deletedCount = 0;
const transaction = db.transaction(() => {
for (const token of usersToFlush) {
const user = users.get(token);
if (user) {
insertStmt.run({
token: user.token,
ip: JSON.stringify(user.ip || []),
nickname: user.nickname ?? null,
type: user.type,
promptCount: user.promptCount,
tokenCounts: JSON.stringify(user.tokenCounts || INITIAL_TOKENS),
tokenLimits: JSON.stringify(user.tokenLimits || migrateTokenCountsProperty(null, config.tokenQuota)),
tokenRefresh: JSON.stringify(user.tokenRefresh || INITIAL_TOKENS),
createdAt: user.createdAt,
lastUsedAt: user.lastUsedAt ?? null,
disabledAt: user.disabledAt ?? null,
disabledReason: user.disabledReason ?? null,
expiresAt: user.expiresAt ?? null,
maxIps: user.maxIps ?? null,
adminNote: user.adminNote ?? null,
meta: JSON.stringify(user.meta || {}),
});
updatedCount++;
} else {
// User was deleted from in-memory map
deleteStmt.run(token);
deletedCount++;
}
}
});
try {
transaction();
usersToFlush.clear();
if (updatedCount > 0 || deletedCount > 0) {
log.info({ updated: updatedCount, deleted: deletedCount }, "Flushed user changes to SQLite.");
}
} catch (error: any) {
log.error({
message: error?.message || "Unknown error during SQLite flush",
stack: error?.stack,
code: error?.code, // SQLite errors often have a code
rawError: error // Log the raw error object for more details
}, "Error flushing users to SQLite.");
// Re-add tokens to flush queue if transaction failed, so we can retry
// This is a simplistic retry, might need more robust error handling
// Ensure usersToFlush still contains the tokens that failed to process
// The current logic inside the transaction means usersToFlush is cleared only on success.
// If transaction fails, usersToFlush would still contain the items from before the attempt.
// However, if items were added to usersToFlush *during* the failed transaction,
// they would be processed in the next attempt.
// For simplicity, the current re-add logic is okay, but could be refined if specific
// tokens fail consistently.
usersToFlush.forEach(token => usersToFlush.add(token));
} finally {
flushingToSQLiteInProgress = false;
log.trace("Finished flush to SQLite attempt.");
}
}
function getModelFamilyForQuotaUsage(
model: string,
api: APIFormat
@@ -22,23 +22,64 @@ const quotaTableId = Math.random().toString(36).slice(2);
</tr>
</thead>
<tbody>
<% Object.entries(quota).forEach(([key, limit]) => { %>
<% Object.entries(quota).forEach(([key, configLimit]) => { %>
<%
const counts = user.tokenCounts[key] || { input: 0, output: 0 };
const limits = user.tokenLimits[key] || { input: 0, output: 0 }; // Default if not set
const refresh = user.tokenRefresh[key] || { input: 0, output: 0 };
const usageInput = Number(counts.input) || 0;
const usageOutput = Number(counts.output) || 0;
const usageLegacy = Number(counts.legacy_total) || 0;
const displayUsage = usageInput + usageOutput || usageLegacy; // This is for total token display, not directly for cost calculation here
const limitInput = Number(limits.input) || 0;
// If limit was from legacy config.tokenQuota (a number), it's in limits.legacy_total or limits.input
const displayLimit = limitInput || Number(limits.legacy_total) || 0;
// Determine tokens to use for cost calculation
const costInputTokens = (usageInput + usageOutput > 0) ? usageInput : usageLegacy;
const costOutputTokens = (usageInput + usageOutput > 0) ? usageOutput : 0; // If using legacy, output is 0 for cost
const costDetails = tokenCostDetails(key, costInputTokens, costOutputTokens);
let remaining = 0;
let limitIsSet = false;
if (displayLimit > 0) {
remaining = displayLimit - (usageInput + usageOutput);
limitIsSet = true;
} else if (typeof configLimit === 'number' && configLimit > 0) {
// Fallback to global config limit if user-specific limit is 0 or not set meaningfully
remaining = configLimit - (usageInput + usageOutput);
limitIsSet = true;
}
const refreshDisplayValue = (Number(refresh.input) || 0) + (Number(refresh.output) || 0) || configLimit || 0;
%>
<tr>
<th scope="row"><%- key %></th>
<td><%- prettyTokens(user.tokenCounts[key]) %></td>
<td>
In: <%- prettyTokens(usageInput) %><br/>
Out: <%- prettyTokens(usageOutput) %>
<% if (usageLegacy && (usageInput + usageOutput === 0)) { %><br/>(Legacy: <%- prettyTokens(usageLegacy) %>)<% } %>
</td>
<% if (showTokenCosts) { %>
<td>$<%- tokenCost(key, user.tokenCounts[key]).toFixed(2) %></td>
<td>
In: $<%- costDetails.inputCost.toFixed(Math.max(2, (costDetails.inputCost.toString().split('.')[1] || '').length)) %><br/>
Out: $<%- costDetails.outputCost.toFixed(Math.max(2, (costDetails.outputCost.toString().split('.')[1] || '').length)) %><br/>
Total: $<%- costDetails.totalCost.toFixed(2) %>
</td>
<% } %>
<% if (!user.tokenLimits[key]) { %>
<% if (!limitIsSet) { %>
<td colspan="2" style="text-align: center">unlimited</td>
<% } else { %>
<td><%- prettyTokens(user.tokenLimits[key]) %></td>
<td><%- prettyTokens(user.tokenLimits[key] - user.tokenCounts[key]) %></td>
<td><%- prettyTokens(displayLimit) %></td>
<td><%- prettyTokens(remaining) %></td>
<% } %>
<% if (user.type === "temporary") { %>
<td>N/A</td>
<% } else { %>
<td><%- prettyTokens(user.tokenRefresh[key] || quota[key]) %></td>
<td><%- prettyTokens(refreshDisplayValue) %></td>
<% } %>
<% if (showRefreshEdit) { %>
<td class="actions">