I should have made all these commits separately but oops

This commit is contained in:
Nopm
2025-06-03 20:14:07 -03:00
parent 5988cd7e45
commit 0411b4c3a6
28 changed files with 1551 additions and 835 deletions
+12 -1
View File
@@ -17,6 +17,14 @@ NODE_ENV=production
# The title displayed on the info page. # The title displayed on the info page.
# SERVER_TITLE=Coom Tunnel # SERVER_TITLE=Coom Tunnel
# URL for the image displayed on the login page.
# If not set, no image will be displayed.
# LOGIN_IMAGE_URL=https://example.com/your-logo.png
# Whether to enable the token-based login for the main info page.
# Defaults to true. Set to false to disable login and make the info page public.
# ENABLE_INFO_PAGE_LOGIN=true
# The route name used to proxy requests to APIs, relative to the Web site root. # The route name used to proxy requests to APIs, relative to the Web site root.
# PROXY_ENDPOINT_ROUTE=/proxy # PROXY_ENDPOINT_ROUTE=/proxy
@@ -119,8 +127,11 @@ NODE_ENV=production
# Which access control method to use. (none | proxy_key | user_token) # Which access control method to use. (none | proxy_key | user_token)
# GATEKEEPER=none # GATEKEEPER=none
# Which persistence method to use. (memory | firebase_rtdb) # Which persistence method to use. (memory | firebase_rtdb | sqlite)
# GATEKEEPER_STORE=memory # GATEKEEPER_STORE=memory
# If using sqlite store, path to the SQLite database file for user data.
# Defaults to data/user-store.sqlite in the project directory.
# SQLITE_USER_STORE_PATH=data/user-store.sqlite
# Maximum number of unique IPs a user can connect from. (0 for unlimited) # Maximum number of unique IPs a user can connect from. (0 for unlimited)
# MAX_IPS_PER_USER=0 # MAX_IPS_PER_USER=0
+499 -325
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -78,7 +78,7 @@
"@types/stream-json": "^1.7.7", "@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1", "@types/uuid": "^9.0.1",
"concurrently": "^8.0.1", "concurrently": "^8.0.1",
"esbuild": "^0.17.16", "esbuild": "^0.25.5",
"esbuild-register": "^3.4.2", "esbuild-register": "^3.4.2",
"husky": "^8.0.3", "husky": "^8.0.3",
"nodemon": "^3.0.1", "nodemon": "^3.0.1",
+15 -5
View File
@@ -132,8 +132,13 @@ router.post("/create-user", (req, res) => {
) )
.transform((data: any) => { .transform((data: any) => {
const expiresAt = Date.now() + data.temporaryUserDuration * 60 * 1000; const expiresAt = Date.now() + data.temporaryUserDuration * 60 * 1000;
const tokenLimits = MODEL_FAMILIES.reduce((limits, model) => { const tokenLimits = MODEL_FAMILIES.reduce((limits, modelFamily) => {
limits[model] = data[`temporaryUserQuota_${model}`]; const quotaValue = data[`temporaryUserQuota_${modelFamily}`];
if (typeof quotaValue === 'number') {
limits[modelFamily] = { input: quotaValue, output: 0, legacy_total: quotaValue };
} else {
limits[modelFamily] = { input: 0, output: 0 };
}
return limits; return limits;
}, {} as UserTokenCounts); }, {} as UserTokenCounts);
return { ...data, expiresAt, tokenLimits }; return { ...data, expiresAt, tokenLimits };
@@ -547,9 +552,14 @@ router.post("/generate-stats", (req, res) => {
function getSumsForUser(user: User) { function getSumsForUser(user: User) {
const sums = MODEL_FAMILIES.reduce( const sums = MODEL_FAMILIES.reduce(
(s, model) => { (s, model) => {
const tokens = user.tokenCounts[model] ?? 0; const counts = user.tokenCounts[model] ?? { input: 0, output: 0, legacy_total: undefined };
s.sumTokens += tokens; // Ensure inputTokens and outputTokens are numbers, defaulting to 0 if NaN or undefined
s.sumCost += getTokenCostUsd(model, tokens); const inputTokens = Number(counts.input) || 0;
const outputTokens = Number(counts.output) || 0;
// We could also consider legacy_total here if input and output are 0
// For now, sumTokens and sumCost will be based on current input/output.
s.sumTokens += inputTokens + outputTokens;
s.sumCost += getTokenCostUsd(model, inputTokens, outputTokens);
return s; return s;
}, },
{ sumTokens: 0, sumCost: 0, prettyUsage: "" } { sumTokens: 0, sumCost: 0, prettyUsage: "" }
+25 -9
View File
@@ -90,11 +90,6 @@ type Config = {
* management mode is set to 'user_token'. * management mode is set to 'user_token'.
*/ */
adminKey?: string; adminKey?: string;
/**
* The password required to view the service info/status page. If not set, the
* info page will be publicly accessible.
*/
serviceInfoPassword?: string;
/** /**
* Which user management mode to use. * Which user management mode to use.
* - `none`: No user management. Proxy is open to all requests with basic * - `none`: No user management. Proxy is open to all requests with basic
@@ -111,10 +106,14 @@ type Config = {
* - `memory`: Users are stored in memory and are lost on restart (default) * - `memory`: Users are stored in memory and are lost on restart (default)
* - `firebase_rtdb`: Users are stored in a Firebase Realtime Database; * - `firebase_rtdb`: Users are stored in a Firebase Realtime Database;
* requires `firebaseKey` and `firebaseRtdbUrl` to be set. * requires `firebaseKey` and `firebaseRtdbUrl` to be set.
* - `sqlite`: Users are stored in an SQLite database; requires
* `sqliteUserStorePath` to be set.
*/ */
gatekeeperStore: "memory" | "firebase_rtdb"; gatekeeperStore: "memory" | "firebase_rtdb" | "sqlite";
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */ /** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
firebaseRtdbUrl?: string; firebaseRtdbUrl?: string;
/** Path to the SQLite database file for storing user data. */
sqliteUserStorePath?: string;
/** /**
* Base64-encoded Firebase service account key if using the Firebase RTDB * Base64-encoded Firebase service account key if using the Firebase RTDB
* store. Note that you should encode the *entire* JSON key file, not just the * store. Note that you should encode the *entire* JSON key file, not just the
@@ -432,6 +431,10 @@ type Config = {
*/ */
proxyUrl?: string; proxyUrl?: string;
}; };
/** URL for the image on the login page. Defaults to empty string (no image). */
loginImageUrl?: string;
/** Whether to enable the token-based login page for the service info page. Defaults to true. */
enableInfoPageLogin?: boolean;
}; };
// To change configs, create a file called .env in the root directory. // To change configs, create a file called .env in the root directory.
@@ -452,7 +455,6 @@ export const config: Config = {
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""),
serviceInfoPassword: getEnvWithDefault("SERVICE_INFO_PASSWORD", ""),
sqliteDataPath: getEnvWithDefault( sqliteDataPath: getEnvWithDefault(
"SQLITE_DATA_PATH", "SQLITE_DATA_PATH",
path.join(DATA_DIR, "database.sqlite") path.join(DATA_DIR, "database.sqlite")
@@ -460,7 +462,11 @@ export const config: Config = {
eventLogging: getEnvWithDefault("EVENT_LOGGING", false), eventLogging: getEnvWithDefault("EVENT_LOGGING", false),
eventLoggingTrim: getEnvWithDefault("EVENT_LOGGING_TRIM", 5), eventLoggingTrim: getEnvWithDefault("EVENT_LOGGING_TRIM", 5),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"), gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory") as Config["gatekeeperStore"],
sqliteUserStorePath: getEnvWithDefault(
"SQLITE_USER_STORE_PATH",
path.join(DATA_DIR, "user-store.sqlite")
),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0), maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", false), maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", false),
captchaMode: getEnvWithDefault("CAPTCHA_MODE", "none"), captchaMode: getEnvWithDefault("CAPTCHA_MODE", "none"),
@@ -546,6 +552,8 @@ export const config: Config = {
interface: getEnvWithDefault("HTTP_AGENT_INTERFACE", undefined), interface: getEnvWithDefault("HTTP_AGENT_INTERFACE", undefined),
proxyUrl: getEnvWithDefault("HTTP_AGENT_PROXY_URL", undefined), proxyUrl: getEnvWithDefault("HTTP_AGENT_PROXY_URL", undefined),
}, },
loginImageUrl: getEnvWithDefault("LOGIN_IMAGE_URL", ""),
enableInfoPageLogin: getEnvWithDefault("ENABLE_INFO_PAGE_LOGIN", true),
} as const; } as const;
function generateSigningKey() { function generateSigningKey() {
@@ -667,6 +675,12 @@ export async function assertConfigIsValid() {
); );
} }
if (config.gatekeeperStore === "sqlite" && !config.sqliteUserStorePath) {
throw new Error(
"SQLite user store requires `SQLITE_USER_STORE_PATH` to be set."
);
}
if (Object.values(config.httpAgent || {}).filter(Boolean).length === 0) { if (Object.values(config.httpAgent || {}).filter(Boolean).length === 0) {
delete config.httpAgent; delete config.httpAgent;
} else if (config.httpAgent) { } else if (config.httpAgent) {
@@ -722,7 +736,6 @@ export const OMITTED_KEYS = [
"azureCredentials", "azureCredentials",
"proxyKey", "proxyKey",
"adminKey", "adminKey",
"serviceInfoPassword",
"rejectPhrases", "rejectPhrases",
"rejectMessage", "rejectMessage",
"showTokenCosts", "showTokenCosts",
@@ -731,6 +744,7 @@ export const OMITTED_KEYS = [
"firebaseKey", "firebaseKey",
"firebaseRtdbUrl", "firebaseRtdbUrl",
"sqliteDataPath", "sqliteDataPath",
"sqliteUserStorePath",
"eventLogging", "eventLogging",
"eventLoggingTrim", "eventLoggingTrim",
"gatekeeperStore", "gatekeeperStore",
@@ -749,6 +763,8 @@ export const OMITTED_KEYS = [
"adminWhitelist", "adminWhitelist",
"ipBlacklist", "ipBlacklist",
"powTokenPurgeHours", "powTokenPurgeHours",
"loginImageUrl",
"enableInfoPageLogin",
] satisfies (keyof Config)[]; ] satisfies (keyof Config)[];
type OmitKeys = (typeof OMITTED_KEYS)[number]; type OmitKeys = (typeof OMITTED_KEYS)[number];
+179 -126
View File
@@ -1,4 +1,8 @@
/** This whole module kinda sucks */ /* ──────────────────────────────────────────────────────────────
Login-gated info page
drop-in replacement for src/info-page.ts
──────────────────────────────────────────────────────────── */
import fs from "fs"; import fs from "fs";
import express, { Router, Request, Response } from "express"; import express, { Router, Request, Response } from "express";
import showdown from "showdown"; import showdown from "showdown";
@@ -8,9 +12,20 @@ import { getLastNImages } from "./shared/file-storage/image-history";
import { keyPool } from "./shared/key-management"; import { keyPool } from "./shared/key-management";
import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models"; import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
import { withSession } from "./shared/with-session"; import { withSession } from "./shared/with-session";
import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf"; import { injectCsrfToken, checkCsrfToken } from "./shared/inject-csrf";
import { getUser } from "./shared/users/user-store";
/* ──────────────── TYPES: extend express-session ──────────── */
declare module "express-session" {
interface Session {
infoPageAuthed?: boolean;
}
}
/* ──────────────── misc constants ─────────────────────────── */
const INFO_PAGE_TTL = 2_000; // ms
const LOGIN_ROUTE = "/";
const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
qwen: "Qwen", qwen: "Qwen",
cohere: "Cohere", cohere: "Cohere",
@@ -72,13 +87,78 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
}; };
const converter = new showdown.Converter(); const converter = new showdown.Converter();
/* optional markdown greeting */
const customGreeting = fs.existsSync("greeting.md") const customGreeting = fs.existsSync("greeting.md")
? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>` ? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
: ""; : "";
/* ──────────────── Login page ──────────────────────── */
function renderLoginPage(csrf: string, error?: string) {
const errBlock = error
? `<div class="error-message">${escapeHtml(error)}</div>`
: "";
const pageTitle = getServerTitle();
return `<!DOCTYPE html>
<html>
<head>
<title>${pageTitle} Login</title>
<style>
body{font-family:Arial, sans-serif;display:flex;justify-content:center;
align-items:center;height:100vh;margin:0;padding:20px;background:#f5f5f5;}
.login-container{background:#fff;border-radius:8px;box-shadow:0 4px 8px rgba(0,0,0,.1);
padding:30px;width:100%;max-width:400px;text-align:center;}
.logo-image{max-width:200px;margin-bottom:20px;}
.form-group{margin-bottom:20px;}
input[type=text]{width:100%;padding:10px;border:1px solid #ddd;border-radius:4px;
box-sizing:border-box;font-size:16px;}
button{background:#4caf50;color:#fff;border:none;padding:12px 20px;border-radius:4px;
cursor:pointer;font-size:16px;width:100%;}
button:hover{background:#45a049;}
.error-message{color:#f44336;margin-bottom:15px;}
@media (prefers-color-scheme: dark) {
body { background: #2c2c2c; color: #e0e0e0; }
.login-container { background: #383838; box-shadow: 0 4px 12px rgba(0,0,0,0.4); border: 1px solid #4a4a4a; }
input[type=text] { background: #4a4a4a; color: #e0e0e0; border: 1px solid #5a5a5a; }
input[type=text]::placeholder { color: #999; }
button { background: #007bff; } /* Using a blue for dark mode button */
button:hover { background: #0056b3; }
.error-message { color: #ff8a80; } /* Lighter red for errors in dark mode */
}
</style>
</head>
<body>
<div class="login-container">
<img src="${config.loginImageUrl || ''}" alt="Logo" class="logo-image">
${errBlock}
<form method="POST" action="${LOGIN_ROUTE}">
<div class="form-group">
<input type="text" id="token" name="token" required placeholder="Your token">
<input type="hidden" name="_csrf" value="${csrf}">
</div>
<button type="submit">Access Dashboard</button>
</form>
</div>
</body>
</html>`;
}
/* ──────────────── login-required middleware ──────────────── */
function requireLogin(
req: Request,
res: Response,
next: express.NextFunction
) {
if (req.session?.infoPageAuthed) return next();
return res.send(renderLoginPage(res.locals.csrfToken));
}
/* ──────────────── INFO PAGE CACHING ──────────────────────── */
let infoPageHtml: string | undefined; let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0; let infoPageLastUpdated = 0;
export const handleInfoPage = (req: Request, res: Response) => { export function handleInfoPage(req: Request, res: Response) {
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) { if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
return res.send(infoPageHtml); return res.send(infoPageHtml);
} }
@@ -93,60 +173,46 @@ export const handleInfoPage = (req: Request, res: Response) => {
infoPageLastUpdated = Date.now(); infoPageLastUpdated = Date.now();
res.send(infoPageHtml); res.send(infoPageHtml);
}; }
/* ──────────────── RENDER FULL INFO PAGE ──────────────────── */
export function renderPage(info: ServiceInfo) { export function renderPage(info: ServiceInfo) {
const title = getServerTitle(); const title = getServerTitle();
const headerHtml = buildInfoPageHeader(info); const headerHtml = buildInfoPageHeader(info);
return `<!doctype html> return `<!doctype html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<meta name="robots" content="noindex" /> <meta name="robots" content="noindex" />
<title>${title}</title> <title>${title}</title>
<link rel="stylesheet" href="/res/css/reset.css" media="screen" /> <link rel="stylesheet" href="/res/css/reset.css" />
<link rel="stylesheet" href="/res/css/sakura.css" media="screen" /> <link rel="stylesheet" href="/res/css/sakura.css" />
<link rel="stylesheet" href="/res/css/sakura-dark.css" media="screen and (prefers-color-scheme: dark)" /> <link rel="stylesheet" href="/res/css/sakura-dark.css"
<style> media="screen and (prefers-color-scheme: dark)" />
body { <style>
font-family: sans-serif; body{font-family:sans-serif;padding:1em;max-width:900px;margin:0;}
padding: 1em; .self-service-links{display:flex;justify-content:center;margin-bottom:1em;
max-width: 900px; padding:0.5em;font-size:0.8em;}
margin: 0; .self-service-links a{margin:0 0.5em;}
} </style>
</head>
.self-service-links { <body>
display: flex; ${headerHtml}
justify-content: center; <hr/>
margin-bottom: 1em; ${getSelfServiceLinks()}
padding: 0.5em; <h2>Service Info</h2>
font-size: 0.8em; <pre>${JSON.stringify(info, null, 2)}</pre>
} </body>
.self-service-links a {
margin: 0 0.5em;
}
</style>
</head>
<body>
${headerHtml}
<hr />
${getSelfServiceLinks()}
<h2>Service Info</h2>
<pre>${JSON.stringify(info, null, 2)}</pre>
</body>
</html>`; </html>`;
} }
/** /* ──────────────── header & helper functions ──────────────── */
* If the server operator provides a `greeting.md` file, it will be included in /* (all copied verbatim from original file) */
* the rendered info page.
**/
function buildInfoPageHeader(info: ServiceInfo) { function buildInfoPageHeader(info: ServiceInfo) {
const title = getServerTitle(); const title = getServerTitle();
// TODO: use some templating engine instead of this mess
let infoBody = `# ${title}`; let infoBody = `# ${title}`;
if (config.promptLogging) { if (config.promptLogging) {
infoBody += `\n## Prompt Logging Enabled infoBody += `\n## Prompt Logging Enabled
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps. This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
@@ -165,9 +231,9 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
for (const modelFamily of config.allowedModelFamilies) { for (const modelFamily of config.allowedModelFamilies) {
const service = MODEL_FAMILY_SERVICE[modelFamily]; const service = MODEL_FAMILY_SERVICE[modelFamily];
const hasKeys = keyPool.list().some((k) => { const hasKeys = keyPool.list().some(
return k.service === service && k.modelFamilies.includes(modelFamily); (k) => k.service === service && k.modelFamilies.includes(modelFamily)
}); );
const wait = info[modelFamily]?.estimatedQueueTime; const wait = info[modelFamily]?.estimatedQueueTime;
if (hasKeys && wait) { if (hasKeys && wait) {
@@ -178,9 +244,7 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
} }
infoBody += "\n\n" + waits.join(" / "); infoBody += "\n\n" + waits.join(" / ");
infoBody += customGreeting; infoBody += customGreeting;
infoBody += buildRecentImageSection(); infoBody += buildRecentImageSection();
return converter.makeHtml(infoBody); return converter.makeHtml(infoBody);
@@ -188,63 +252,60 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
function getSelfServiceLinks() { function getSelfServiceLinks() {
if (config.gatekeeper !== "user_token") return ""; if (config.gatekeeper !== "user_token") return "";
const links = [["Check your user token", "/user/lookup"]]; const links = [["Check your user token", "/user/lookup"]];
if (config.captchaMode !== "none") { if (config.captchaMode !== "none") {
links.unshift(["Request a user token", "/user/captcha"]); links.unshift(["Request a user token", "/user/captcha"]);
} }
return `<div class="self-service-links">${links return `<div class="self-service-links">${links
.map(([text, link]) => `<a href="${link}">${text}</a>`) .map(([t, l]) => `<a href="${l}">${t}</a>`)
.join(" | ")}</div>`; .join(" | ")}</div>`;
} }
function getServerTitle() { function getServerTitle() {
// Use manually set title if available if (process.env.SERVER_TITLE) return process.env.SERVER_TITLE;
if (process.env.SERVER_TITLE) { if (process.env.SPACE_ID)
return process.env.SERVER_TITLE;
}
// Huggingface
if (process.env.SPACE_ID) {
return `${process.env.SPACE_AUTHOR_NAME} / ${process.env.SPACE_TITLE}`; return `${process.env.SPACE_AUTHOR_NAME} / ${process.env.SPACE_TITLE}`;
} if (process.env.RENDER)
// Render
if (process.env.RENDER) {
return `Render / ${process.env.RENDER_SERVICE_NAME}`; return `Render / ${process.env.RENDER_SERVICE_NAME}`;
} return "Tunnel";
return "OAI Reverse Proxy";
} }
function buildRecentImageSection() { function buildRecentImageSection() {
const imageModels: ModelFamily[] = ["azure-dall-e", "dall-e", "gpt-image", "azure-gpt-image"]; const imageModels: ModelFamily[] = [
"azure-dall-e",
"dall-e",
"gpt-image",
"azure-gpt-image",
];
// Condition 1: Is the feature enabled via config?
// Condition 2: Is at least one relevant image model family allowed in config?
if ( if (
!config.showRecentImages || !config.showRecentImages ||
imageModels.every((f) => !config.allowedModelFamilies.includes(f)) imageModels.every((f) => !config.allowedModelFamilies.includes(f))
) { ) {
return ""; // Exit if feature is disabled or no relevant models are allowed
}
// Condition 3: Are there any actual images to display?
const recentImages = getLastNImages(12).reverse();
if (recentImages.length === 0) {
// If the feature is enabled and models are allowed, but no images exist,
// do not render the section, including its title.
return ""; return "";
} }
// If all conditions pass (feature enabled, models allowed, images exist), build and return the HTML
let html = `<h2>Recent Image Generations</h2>`; let html = `<h2>Recent Image Generations</h2>`;
const recentImages = getLastNImages(12).reverse(); html += `<div style="display:flex;flex-wrap:wrap;" id="recent-images">`;
if (recentImages.length === 0) {
html += `<p>No images yet.</p>`;
return html;
}
html += `<div style="display: flex; flex-wrap: wrap;" id="recent-images">`;
for (const { url, prompt } of recentImages) { for (const { url, prompt } of recentImages) {
const thumbUrl = url.replace(/\.png$/, "_t.jpg"); const thumbUrl = url.replace(/\.png$/, "_t.jpg");
const escapedPrompt = escapeHtml(prompt); const escapedPrompt = escapeHtml(prompt);
html += `<div style="margin: 0.5em;" class="recent-image"> html += `<div style="margin:0.5em" class="recent-image">
<a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}" alt="${escapedPrompt}" style="max-width: 150px; max-height: 150px;" /></a> <a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}"
</div>`; alt="${escapedPrompt}" style="max-width:150px;max-height:150px;"/></a></div>`;
} }
html += `</div>`; html += `</div><p style="clear:both;text-align:center;">
html += `<p style="clear: both; text-align: center;"><a href="/user/image-history">View all recent images</a></p>`; <a href="/user/image-history">View all recent images</a></p>`;
return html; return html;
} }
@@ -259,57 +320,49 @@ function escapeHtml(unsafe: string) {
.replace(/]/g, "&#93;"); .replace(/]/g, "&#93;");
} }
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) { function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
try { try {
const [username, spacename] = spaceId.split("/"); const [u, s] = spaceId.split("/");
return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`; return `https://${u}-${s.replace(/_/g, "-")}.hf.space`;
} catch (e) { } catch {
return ""; return "";
} }
} }
function checkIfUnlocked( /* ──────────────── ROUTER ─────────────────────────────────── */
req: Request,
res: Response,
next: express.NextFunction
) {
if (config.serviceInfoPassword?.length && !req.session?.unlocked) {
return res.redirect("/unlock-info");
}
next();
}
const infoPageRouter = Router(); const infoPageRouter = Router();
if (config.serviceInfoPassword?.length) {
infoPageRouter.use(
express.json({ limit: "1mb" }),
express.urlencoded({ extended: true, limit: "1mb" })
);
infoPageRouter.use(withSession);
infoPageRouter.use(injectCsrfToken, checkCsrfToken);
infoPageRouter.post("/unlock-info", (req, res) => {
if (req.body.password !== config.serviceInfoPassword) {
return res.status(403).send("Incorrect password");
}
req.session!.unlocked = true;
res.redirect("/");
});
infoPageRouter.get("/unlock-info", (_req, res) => {
if (_req.session?.unlocked) return res.redirect("/");
res.send(` infoPageRouter.use(
<form method="post" action="/unlock-info"> express.json({ limit: "1mb" }),
<h1>Unlock Service Info</h1> express.urlencoded({ extended: true, limit: "1mb" }),
<input type="hidden" name="_csrf" value="${res.locals.csrfToken}" /> withSession,
<input type="password" name="password" placeholder="Password" /> injectCsrfToken,
<button type="submit">Unlock</button> checkCsrfToken
</form> );
`);
}); /* login attempt */
infoPageRouter.use(checkIfUnlocked); infoPageRouter.post(LOGIN_ROUTE, (req, res) => {
} const token = (req.body.token || "").trim();
infoPageRouter.get("/", handleInfoPage);
infoPageRouter.get("/status", (req, res) => { const user = getUser(token); // returns undefined if invalid
res.json(buildInfo(req.protocol + "://" + req.get("host"), false)); if (!user) {
return res
.status(401)
.send(renderLoginPage(res.locals.csrfToken, "Invalid token. Please try again."));
}
req.session!.infoPageAuthed = true;
res.redirect("/");
}); });
export { infoPageRouter };
/* GET / either login form or info page */
if (config.enableInfoPageLogin) {
infoPageRouter.get(LOGIN_ROUTE, requireLogin, handleInfoPage);
} else {
infoPageRouter.get(LOGIN_ROUTE, handleInfoPage);
}
/* ─── Removed the public /status route : simply not added ─── */
export { infoPageRouter };
+4 -2
View File
@@ -855,10 +855,12 @@ const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
}, },
`Incrementing usage for model` `Incrementing usage for model`
); );
keyPool.incrementUsage(req.key!, model, tokensUsed); // Get modelFamily for the key usage log
const modelFamilyForKeyPool = req.modelFamily!; // Should be set by getModelFamilyForRequest earlier
keyPool.incrementUsage(req.key!, modelFamilyForKeyPool, { input: req.promptTokens!, output: req.outputTokens! });
if (req.user) { if (req.user) {
incrementPromptCount(req.user.token); incrementPromptCount(req.user.token);
incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed); incrementTokenCount(req.user.token, model, req.outboundApi, { input: req.promptTokens!, output: req.outputTokens! });
} }
} }
}; };
+68 -13
View File
@@ -74,14 +74,18 @@ type ModelAggregates = {
gcpSonnet35?: number; gcpSonnet35?: number;
gcpHaiku?: number; gcpHaiku?: number;
queued: number; queued: number;
tokens: number; inputTokens: number; // Changed from tokens
outputTokens: number; // Added
legacyTokens?: number; // Added for migrated totals
}; };
/** All possible combinations of model family and aggregate type. */ /** All possible combinations of model family and aggregate type. */
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`; type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
type AllStats = { type AllStats = {
proompts: number; proompts: number;
tokens: number; inputTokens: number; // Changed from tokens
outputTokens: number; // Added
legacyTokens?: number; // Added
tokenCost: number; tokenCost: number;
} & { [modelFamily in ModelFamily]?: ModelAggregates } & { } & { [modelFamily in ModelFamily]?: ModelAggregates } & {
[service in LLMService as `${service}__${ServiceAggregate}`]?: number; [service in LLMService as `${service}__${ServiceAggregate}`]?: number;
@@ -288,11 +292,14 @@ function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">; type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;
function getTrafficStats(): TrafficStats { function getTrafficStats(): TrafficStats {
const tokens = serviceStats.get("tokens") || 0; const inputTokens = serviceStats.get("inputTokens") || 0;
const outputTokens = serviceStats.get("outputTokens") || 0;
// const legacyTokens = serviceStats.get("legacyTokens") || 0; // Optional: include in total if desired
const totalTokens = inputTokens + outputTokens; // + legacyTokens;
const tokenCost = serviceStats.get("tokenCost") || 0; const tokenCost = serviceStats.get("tokenCost") || 0;
return { return {
proompts: serviceStats.get("proompts") || 0, proompts: serviceStats.get("proompts") || 0,
tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`, tookens: `${prettyTokens(totalTokens)}${getCostSuffix(tokenCost)}`, // Simplified to show aggregate and cost
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}), ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
}; };
} }
@@ -352,14 +359,39 @@ function addKeyToAggregates(k: KeyPoolKey) {
addToService("cohere__keys", k.service === "cohere" ? 1 : 0); addToService("cohere__keys", k.service === "cohere" ? 1 : 0);
addToService("qwen__keys", k.service === "qwen" ? 1 : 0); addToService("qwen__keys", k.service === "qwen" ? 1 : 0);
let sumTokens = 0; let sumInputTokens = 0;
let sumOutputTokens = 0;
let sumLegacyTokens = 0; // Optional
let sumCost = 0; let sumCost = 0;
const incrementGenericFamilyStats = (f: ModelFamily) => { const incrementGenericFamilyStats = (f: ModelFamily) => {
const tokens = (k as any)[`${f}Tokens`]; const usage = k.tokenUsage?.[f];
sumTokens += tokens; let familyInputTokens = 0;
sumCost += getTokenCostUsd(f, tokens); let familyOutputTokens = 0;
addToFamily(`${f}__tokens`, tokens); let familyLegacyTokens = 0;
if (usage) {
familyInputTokens = usage.input || 0;
familyOutputTokens = usage.output || 0;
if (usage.legacy_total && familyInputTokens === 0 && familyOutputTokens === 0) {
// This is a migrated key with no new usage, use legacy_total as input for cost
familyLegacyTokens = usage.legacy_total;
sumCost += getTokenCostUsd(f, usage.legacy_total, 0);
} else {
sumCost += getTokenCostUsd(f, familyInputTokens, familyOutputTokens);
}
}
// If no k.tokenUsage[f], tokens are 0, cost is 0.
sumInputTokens += familyInputTokens;
sumOutputTokens += familyOutputTokens;
sumLegacyTokens += familyLegacyTokens; // Optional
addToFamily(`${f}__inputTokens`, familyInputTokens);
addToFamily(`${f}__outputTokens`, familyOutputTokens);
if (familyLegacyTokens > 0) {
addToFamily(`${f}__legacyTokens`, familyLegacyTokens); // Optional
}
addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0); addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
addToFamily(`${f}__active`, k.isDisabled ? 0 : 1); addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
}; };
@@ -493,15 +525,38 @@ function addKeyToAggregates(k: KeyPoolKey) {
assertNever(k.service); assertNever(k.service);
} }
addToService("tokens", sumTokens); addToService("inputTokens", sumInputTokens);
addToService("outputTokens", sumOutputTokens);
if (sumLegacyTokens > 0) { // Optional
addToService("legacyTokens", sumLegacyTokens);
}
addToService("tokenCost", sumCost); addToService("tokenCost", sumCost);
} }
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = familyStats.get(`${family}__tokens`) || 0; const inputTokens = familyStats.get(`${family}__inputTokens`) || 0;
const cost = getTokenCostUsd(family, tokens); const outputTokens = familyStats.get(`${family}__outputTokens`) || 0;
const legacyTokens = familyStats.get(`${family}__legacyTokens`) || 0; // Optional
let cost = 0;
let displayTokens = 0;
let usageString = "";
if (inputTokens > 0 || outputTokens > 0) {
cost = getTokenCostUsd(family, inputTokens, outputTokens);
displayTokens = inputTokens + outputTokens;
usageString = `${prettyTokens(displayTokens)} (In: ${prettyTokens(inputTokens)}, Out: ${prettyTokens(outputTokens)})${getCostSuffix(cost)}`;
} else if (legacyTokens > 0) {
// Only show legacy if no new input/output has been recorded for this family aggregate
cost = getTokenCostUsd(family, legacyTokens, 0); // Cost legacy as all input
displayTokens = legacyTokens;
usageString = `${prettyTokens(displayTokens)} tokens (legacy total)${getCostSuffix(cost)}`;
} else {
usageString = `${prettyTokens(0)} tokens${getCostSuffix(0)}`;
}
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = { let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`, usage: usageString,
activeKeys: familyStats.get(`${family}__active`) || 0, activeKeys: familyStats.get(`${family}__active`) || 0,
revokedKeys: familyStats.get(`${family}__revoked`) || 0, revokedKeys: familyStats.get(`${family}__revoked`) || 0,
}; };
+3 -2
View File
@@ -1,6 +1,6 @@
import { RequestHandler } from "express"; import { RequestHandler } from "express";
import { config } from "../config"; import { config } from "../config";
import { getTokenCostUsd, prettyTokens } from "./stats"; import { getTokenCostUsd, getTokenCostDetailsUsd, prettyTokens } from "./stats"; // Added getTokenCostDetailsUsd
import { redactIp } from "./utils"; import { redactIp } from "./utils";
import * as userStore from "./users/user-store"; import * as userStore from "./users/user-store";
@@ -30,7 +30,8 @@ export const injectLocals: RequestHandler = (req, res, next) => {
// view helpers // view helpers
res.locals.prettyTokens = prettyTokens; res.locals.prettyTokens = prettyTokens;
res.locals.tokenCost = getTokenCostUsd; res.locals.tokenCost = getTokenCostUsd; // Returns total cost as a number
res.locals.tokenCostDetails = getTokenCostDetailsUsd; // Returns { inputCost, outputCost, totalCost }
res.locals.redactIp = redactIp; res.locals.redactIp = redactIp;
next(); next();
+18 -10
View File
@@ -16,11 +16,8 @@ export type AnthropicKeyUpdate = Omit<
| "rateLimitedUntil" | "rateLimitedUntil"
>; >;
type AnthropicKeyUsage = { // AnthropicKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in AnthropicModelFamily as `${K}Tokens`]: number; export interface AnthropicKey extends Key {
};
export interface AnthropicKey extends Key, AnthropicKeyUsage {
readonly service: "anthropic"; readonly service: "anthropic";
readonly modelFamilies: AnthropicModelFamily[]; readonly modelFamilies: AnthropicModelFamily[];
/** /**
@@ -120,8 +117,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
claudeTokens: 0, tokenUsage: {}, // Initialize new tokenUsage field
"claude-opusTokens": 0,
tier: "unknown", tier: "unknown",
}; };
this.keys.push(newKey); this.keys.push(newKey);
@@ -206,11 +202,23 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: AnthropicModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Ensure the specific family object exists
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -14
View File
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys"; import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker"; import { AwsKeyChecker } from "./checker";
type AwsBedrockKeyUsage = { // AwsBedrockKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in AwsBedrockModelFamily as `${K}Tokens`]: number; export interface AwsBedrockKey extends Key {
};
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
readonly service: "aws"; readonly service: "aws";
readonly modelFamilies: AwsBedrockModelFamily[]; readonly modelFamilies: AwsBedrockModelFamily[];
/** /**
@@ -74,12 +71,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
lastChecked: 0, lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
inferenceProfileIds: [], inferenceProfileIds: [],
["aws-claudeTokens"]: 0, tokenUsage: {}, // Initialize new tokenUsage field
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
["aws-mistral-smallTokens"]: 0,
["aws-mistral-mediumTokens"]: 0,
["aws-mistral-largeTokens"]: 0,
}; };
this.keys.push(newKey); this.keys.push(newKey);
} }
@@ -173,11 +165,22 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: AwsBedrockModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -26
View File
@@ -10,11 +10,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys"; import { prioritizeKeys } from "../prioritize-keys";
import { AzureOpenAIKeyChecker } from "./checker"; import { AzureOpenAIKeyChecker } from "./checker";
type AzureOpenAIKeyUsage = { // AzureOpenAIKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number; export interface AzureOpenAIKey extends Key {
};
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
readonly service: "azure"; readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[]; readonly modelFamilies: AzureOpenAIModelFamily[];
contentFiltering: boolean; contentFiltering: boolean;
@@ -68,24 +65,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
"azure-turboTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
"azure-gpt4Tokens": 0,
"azure-gpt4-32kTokens": 0,
"azure-gpt4-turboTokens": 0,
"azure-gpt4oTokens": 0,
"azure-gpt45Tokens": 0,
"azure-gpt41Tokens": 0,
"azure-gpt41-miniTokens": 0,
"azure-gpt41-nanoTokens": 0,
"azure-o1Tokens": 0,
"azure-o1-miniTokens": 0,
"azure-o1-proTokens": 0,
"azure-o3-miniTokens": 0,
"azure-o3Tokens": 0,
"azure-o4-miniTokens": 0,
"azure-codex-miniTokens": 0,
"azure-dall-eTokens": 0,
"azure-gpt-imageTokens": 0,
modelIds: [], modelIds: [],
}; };
this.keys.push(newKey); this.keys.push(newKey);
@@ -140,11 +120,22 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: AzureOpenAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+20 -12
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { CohereKeyChecker } from "./checker"; import { CohereKeyChecker } from "./checker";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { CohereModelFamily } from "../../models"; import { CohereModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type CohereKeyUsage = { // CohereKeyUsage is removed, tokenUsage from base Key interface will be used.
"cohereTokens": number; export interface CohereKey extends Key {
};
export interface CohereKey extends Key, CohereKeyUsage {
readonly service: "cohere"; readonly service: "cohere";
readonly modelFamilies: CohereModelFamily[]; readonly modelFamilies: CohereModelFamily[];
isOverQuota: boolean; isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class CohereKeyProvider implements KeyProvider<CohereKey> {
hash: this.hashKey(key), hash: this.hashKey(key),
rateLimitedAt: 0, rateLimitedAt: 0,
rateLimitedUntil: 0, rateLimitedUntil: 0,
"cohereTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false, isOverQuota: false,
}); });
} }
@@ -99,13 +96,24 @@ export class CohereKeyProvider implements KeyProvider<CohereKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: CohereModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++;
key[`cohereTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Cohere only has one model family "cohere"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/** /**
* Upon being rate limited, a key will be locked out for this many milliseconds * Upon being rate limited, a key will be locked out for this many milliseconds
+21 -13
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { DeepseekKeyChecker } from "./checker"; import { DeepseekKeyChecker } from "./checker";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { DeepseekModelFamily } from "../../models"; import { DeepseekModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type DeepseekKeyUsage = { // DeepseekKeyUsage is removed, tokenUsage from base Key interface will be used.
"deepseekTokens": number; export interface DeepseekKey extends Key {
};
export interface DeepseekKey extends Key, DeepseekKeyUsage {
readonly service: "deepseek"; readonly service: "deepseek";
readonly modelFamilies: DeepseekModelFamily[]; readonly modelFamilies: DeepseekModelFamily[];
isOverQuota: boolean; isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
hash: this.hashKey(key), hash: this.hashKey(key),
rateLimitedAt: 0, rateLimitedAt: 0,
rateLimitedUntil: 0, rateLimitedUntil: 0,
"deepseekTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false, isOverQuota: false,
}); });
} }
@@ -99,13 +96,24 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: DeepseekModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++;
key[`deepseekTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Deepseek only has one model family "deepseek"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/** /**
* Upon being rate limited, a key will be locked out for this many milliseconds * Upon being rate limited, a key will be locked out for this many milliseconds
@@ -156,4 +164,4 @@ export class DeepseekKeyProvider implements KeyProvider<DeepseekKey> {
key.rateLimitedAt = now; key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit); key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
} }
} }
+17 -10
View File
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys"; import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker"; import { GcpKeyChecker } from "./checker";
type GcpKeyUsage = { // GcpKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in GcpModelFamily as `${K}Tokens`]: number; export interface GcpKey extends Key {
};
export interface GcpKey extends Key, GcpKeyUsage {
readonly service: "gcp"; readonly service: "gcp";
readonly modelFamilies: GcpModelFamily[]; readonly modelFamilies: GcpModelFamily[];
sonnetEnabled: boolean; sonnetEnabled: boolean;
@@ -75,8 +72,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
sonnet35Enabled: false, sonnet35Enabled: false,
accessToken: "", accessToken: "",
accessTokenExpiresAt: 0, accessTokenExpiresAt: 0,
["gcp-claudeTokens"]: 0, tokenUsage: {}, // Initialize new tokenUsage field
["gcp-claude-opusTokens"]: 0,
}; };
this.keys.push(newKey); this.keys.push(newKey);
} }
@@ -160,11 +156,22 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: GcpModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -11
View File
@@ -22,11 +22,8 @@ export type GoogleAIKeyUpdate = Omit<
| "rateLimitedUntil" | "rateLimitedUntil"
>; >;
type GoogleAIKeyUsage = { // GoogleAIKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in GoogleAIModelFamily as `${K}Tokens`]: number; export interface GoogleAIKey extends Key {
};
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-ai"; readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[]; readonly modelFamilies: GoogleAIModelFamily[];
/** All detected model IDs on this key. */ /** All detected model IDs on this key. */
@@ -84,9 +81,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
"gemini-flashTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
"gemini-proTokens": 0,
"gemini-ultraTokens": 0,
modelIds: [], modelIds: [],
overQuotaFamilies: [], overQuotaFamilies: [],
}; };
@@ -139,11 +134,22 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: GoogleAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+9 -1
View File
@@ -36,6 +36,14 @@ export interface Key {
rateLimitedAt: number; rateLimitedAt: number;
/** The time until which this key is rate limited. */ /** The time until which this key is rate limited. */
rateLimitedUntil: number; rateLimitedUntil: number;
/** Detailed token usage, separated by input and output, per model family. */
tokenUsage?: {
[family in ModelFamily]?: {
input: number;
output: number;
legacy_total?: number; // To store migrated single-number totals
};
};
} }
/* /*
@@ -58,7 +66,7 @@ export interface KeyProvider<T extends Key = Key> {
disable(key: T): void; disable(key: T): void;
update(hash: string, update: Partial<T>): void; update(hash: string, update: Partial<T>): void;
available(): number; available(): number;
incrementUsage(hash: string, model: string, tokens: number): void; incrementUsage(hash: string, modelFamily: ModelFamily, usage: { input: number; output: number }): void;
getLockoutPeriod(model: ModelFamily): number; getLockoutPeriod(model: ModelFamily): number;
markRateLimited(hash: string): void; markRateLimited(hash: string): void;
recheck(): void; recheck(): void;
+24 -3
View File
@@ -108,9 +108,30 @@ export class KeyPool {
}, 0); }, 0);
} }
public incrementUsage(key: Key, model: string, tokens: number): void { public incrementUsage(key: Key, modelName: string, usage: { input: number; output: number }): void {
const provider = this.getKeyProvider(key.service); const provider = this.getKeyProvider(key.service);
provider.incrementUsage(key.hash, model, tokens); // Assuming the provider's incrementUsage expects a modelFamily.
    // The KeyProvider interface's incrementUsage takes a ModelFamily, not a raw
    // model name, and KeyPool has no service-aware way to derive a family from
    // an arbitrary model string. The caller (middleware/response/index.ts,
    // which has both the model name and req.outboundApi) must therefore
    // resolve the family before calling this method; `modelName` here is
    // expected to already be a ModelFamily value, hence the cast below.
provider.incrementUsage(key.hash, modelName as ModelFamily, usage); // Casting modelName, assuming caller provides family
} }
public getLockoutPeriod(family: ModelFamily): number { public getLockoutPeriod(family: ModelFamily): number {
@@ -247,4 +268,4 @@ export class KeyPool {
); );
this.recheckJobs["google-ai"] = googleJob; this.recheckJobs["google-ai"] = googleJob;
} }
} }
@@ -7,11 +7,8 @@ import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys"; import { prioritizeKeys } from "../prioritize-keys";
import { MistralAIKeyChecker } from "./checker"; import { MistralAIKeyChecker } from "./checker";
type MistralAIKeyUsage = { // MistralAIKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in MistralAIModelFamily as `${K}Tokens`]: number; export interface MistralAIKey extends Key {
};
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai"; readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[]; readonly modelFamilies: MistralAIModelFamily[];
} }
@@ -67,10 +64,7 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
"mistral-tinyTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
"mistral-smallTokens": 0,
"mistral-mediumTokens": 0,
"mistral-largeTokens": 0,
}; };
this.keys.push(newKey); this.keys.push(newKey);
} }
@@ -117,12 +111,22 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: MistralAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
const family = getMistralAIModelFamily(model);
key[`${family}Tokens`] += tokens; if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
+17 -26
View File
@@ -3,16 +3,13 @@ import http from "http";
import { Key, KeyProvider } from "../index"; import { Key, KeyProvider } from "../index";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; import { getOpenAIModelFamily, OpenAIModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
import { PaymentRequiredError } from "../../errors"; import { PaymentRequiredError } from "../../errors";
import { OpenAIKeyChecker } from "./checker"; import { OpenAIKeyChecker } from "./checker";
import { prioritizeKeys } from "../prioritize-keys"; import { prioritizeKeys } from "../prioritize-keys";
type OpenAIKeyUsage = { // OpenAIKeyUsage is removed, tokenUsage from base Key interface will be used.
[K in OpenAIModelFamily as `${K}Tokens`]: number; export interface OpenAIKey extends Key {
};
export interface OpenAIKey extends Key, OpenAIKeyUsage {
readonly service: "openai"; readonly service: "openai";
modelFamilies: OpenAIModelFamily[]; modelFamilies: OpenAIModelFamily[];
/** /**
@@ -108,24 +105,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
rateLimitedUntil: 0, rateLimitedUntil: 0,
rateLimitRequestsReset: 0, rateLimitRequestsReset: 0,
rateLimitTokensReset: 0, rateLimitTokensReset: 0,
turboTokens: 0, tokenUsage: {}, // Initialize new tokenUsage field
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
"gpt4-turboTokens": 0,
gpt4oTokens: 0,
gpt45Tokens: 0,
gpt41Tokens: 0,
"gpt41-miniTokens": 0,
"gpt41-nanoTokens": 0,
"o1Tokens": 0,
"o1-miniTokens": 0,
"o1-proTokens": 0,
"o3-miniTokens": 0,
"o3Tokens": 0,
"o4-miniTokens": 0,
"codex-miniTokens": 0,
"dall-eTokens": 0,
"gpt-imageTokens": 0,
modelIds: [], modelIds: [],
}; };
this.keys.push(newKey); this.keys.push(newKey);
@@ -337,11 +317,22 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
key.rateLimitedUntil = now + key.rateLimitRequestsReset; key.rateLimitedUntil = now + key.rateLimitRequestsReset;
} }
public incrementUsage(keyHash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: OpenAIModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === keyHash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) { public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
+1 -1
View File
@@ -6,7 +6,7 @@ export interface QwenKey extends Key {
readonly service: "qwen"; readonly service: "qwen";
readonly modelFamilies: QwenModelFamily[]; readonly modelFamilies: QwenModelFamily[];
isOverQuota: boolean; isOverQuota: boolean;
"qwenTokens": number; // "qwenTokens" is removed, tokenUsage from base Key interface will be used.
} }
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { assertNever } from "../../utils"; import { assertNever } from "../../utils";
+17 -4
View File
@@ -2,6 +2,7 @@ import { KeyProvider, createGenericGetLockoutPeriod } from "..";
import { QwenKeyChecker, QwenKey } from "./checker"; import { QwenKeyChecker, QwenKey } from "./checker";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { QwenModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
// Re-export the QwenKey interface // Re-export the QwenKey interface
export type { QwenKey } from "./checker"; export type { QwenKey } from "./checker";
@@ -36,7 +37,7 @@ export class QwenKeyProvider implements KeyProvider<QwenKey> {
hash: this.hashKey(key), hash: this.hashKey(key),
rateLimitedAt: 0, rateLimitedAt: 0,
rateLimitedUntil: 0, rateLimitedUntil: 0,
"qwenTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false, isOverQuota: false,
}); });
} }
@@ -93,11 +94,23 @@ export class QwenKeyProvider implements KeyProvider<QwenKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: QwenModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key[`qwenTokens`] += tokens;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Qwen only has one model family "qwen"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
} }
/** /**
+21 -13
View File
@@ -2,13 +2,10 @@ import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { XaiKeyChecker } from "./checker"; import { XaiKeyChecker } from "./checker";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { XaiModelFamily } from "../../models"; import { XaiModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
type XaiKeyUsage = { // XaiKeyUsage is removed, tokenUsage from base Key interface will be used.
"xaiTokens": number; export interface XaiKey extends Key {
};
export interface XaiKey extends Key, XaiKeyUsage {
readonly service: "xai"; readonly service: "xai";
readonly modelFamilies: XaiModelFamily[]; readonly modelFamilies: XaiModelFamily[];
isOverQuota: boolean; isOverQuota: boolean;
@@ -42,7 +39,7 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
hash: this.hashKey(key), hash: this.hashKey(key),
rateLimitedAt: 0, rateLimitedAt: 0,
rateLimitedUntil: 0, rateLimitedUntil: 0,
"xaiTokens": 0, tokenUsage: {}, // Initialize new tokenUsage field
isOverQuota: false, isOverQuota: false,
}); });
} }
@@ -99,13 +96,24 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
return this.keys.filter((k) => !k.isDisabled).length; return this.keys.filter((k) => !k.isDisabled).length;
} }
public incrementUsage(hash: string, model: string, tokens: number) { public incrementUsage(keyHash: string, modelFamily: XaiModelFamily, usage: { input: number; output: number }) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return; if (!key) return;
key.promptCount++;
key[`xaiTokens`] += tokens;
}
key.promptCount++;
if (!key.tokenUsage) {
key.tokenUsage = {};
}
// Xai only has one model family "xai"
if (!key.tokenUsage[modelFamily]) {
key.tokenUsage[modelFamily] = { input: 0, output: 0 };
}
const currentFamilyUsage = key.tokenUsage[modelFamily]!;
currentFamilyUsage.input += usage.input;
currentFamilyUsage.output += usage.output;
}
/** /**
* Upon being rate limited, a key will be locked out for this many milliseconds * Upon being rate limited, a key will be locked out for this many milliseconds
@@ -156,4 +164,4 @@ export class XaiKeyProvider implements KeyProvider<XaiKey> {
key.rateLimitedAt = now; key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit); key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
} }
} }
+62
View File
@@ -0,0 +1,62 @@
import fs from 'node:fs';
import path from 'node:path';

import Database from 'better-sqlite3';

import { config } from '../config';
import { logger } from '../logger';
const log = logger.child({ module: 'sqlite-db' });
let db: Database.Database;
/**
 * Opens (or creates) the SQLite database backing the user store and ensures
 * the `users` table exists. Idempotent: repeated calls return the same
 * connection.
 *
 * @returns The shared better-sqlite3 Database handle.
 * @throws If SQLITE_USER_STORE_PATH (config.sqliteUserStorePath) is not set.
 */
export function initSQLiteDB(): Database.Database {
  if (db) {
    return db;
  }

  const dbPath = config.sqliteUserStorePath;
  if (!dbPath) {
    log.error('SQLite user store DB path (SQLITE_USER_STORE_PATH) is not configured.');
    throw new Error('SQLite user store DB path is not configured.');
  }

  // better-sqlite3 will not create intermediate directories, so ensure the
  // parent directory (e.g. the default `data/`) exists before opening the
  // database; otherwise the first run crashes with SQLITE_CANTOPEN.
  const dir = path.dirname(dbPath);
  if (!fs.existsSync(dir)) {
    fs.mkdirSync(dir, { recursive: true });
  }

  log.info({ path: dbPath }, 'Initializing SQLite database for user store...');
  db = new Database(dbPath, { verbose: config.logLevel === 'trace' ? console.log : undefined });

  // WAL mode permits concurrent readers while a write is in progress — the
  // recommended journal mode for long-lived server processes.
  db.pragma('journal_mode = WAL');

  // Create users table.
  // Note: JSON fields (ip, tokenCounts, etc.) are stored as TEXT.
  // Timestamps are stored as INTEGER (Unix epoch milliseconds).
  db.exec(`
    CREATE TABLE IF NOT EXISTS users (
      token TEXT PRIMARY KEY,
      ip TEXT, /* JSON string array */
      nickname TEXT,
      type TEXT NOT NULL CHECK(type IN ('normal', 'special', 'temporary')),
      promptCount INTEGER NOT NULL DEFAULT 0,
      tokenCounts TEXT, /* JSON string object */
      tokenLimits TEXT, /* JSON string object */
      tokenRefresh TEXT, /* JSON string object */
      createdAt INTEGER NOT NULL,
      lastUsedAt INTEGER,
      disabledAt INTEGER,
      disabledReason TEXT,
      expiresAt INTEGER,
      maxIps INTEGER,
      adminNote TEXT,
      meta TEXT /* JSON string object */
    );
  `);
  log.info('SQLite database initialized and `users` table created/verified.');
  return db;
}
/**
 * Returns the shared SQLite connection, lazily initializing it if nothing has
 * called initSQLiteDB() yet.
 */
export function getDB(): Database.Database {
  if (db) return db;
  // user-store is expected to call initSQLiteDB() during startup; warn so an
  // out-of-order call is visible in the logs, then initialize on demand.
  log.warn('SQLite DB instance requested before initialization. Attempting to initialize now.');
  return initSQLiteDB();
}
+82 -140
View File
@@ -1,146 +1,88 @@
import { config } from "../config"; import { config } from "../config";
import { ModelFamily } from "./models"; import { ModelFamily } from "./models";
// Using weighted averages now for better guessing, thinking models use around 1:3 ratio for input:output // Prices are per 1 million tokens.
// for the thinking part, other models hover around 3:1 input output, still not the best, but reflects better to real proompting. const MODEL_PRICING: Record<ModelFamily, { input: number; output: number } | undefined> = {
export function getTokenCostUsd(model: ModelFamily, tokens: number) { "deepseek": { input: 0.14, output: 0.28 }, // DeepSeek-V2: $0.14/$0.28 per 1M tokens
let cost = 0; "xai": { input: 5.6, output: 16.8 }, // Grok: Derived from avg $14/1M (assuming 1:3 in/out ratio) - needs official pricing
switch (model) { "gpt41": { input: 2.00, output: 8.00 },
case "deepseek": "azure-gpt41": { input: 2.00, output: 8.00 },
cost = 0.00000178; "gpt41-mini": { input: 0.40, output: 1.60 },
// uncached r1 pricing, again the highest average "azure-gpt41-mini": { input: 0.40, output: 1.60 },
break; "gpt41-nano": { input: 0.10, output: 0.40 },
case "xai": "azure-gpt41-nano": { input: 0.10, output: 0.40 },
cost = 0.000014; "gpt45": { input: 75.00, output: 150.00 }, // Example, needs verification if this model family is still current with this pricing
// just using the highest input/output price aka grok-3 (because who cares about grok) "azure-gpt45": { input: 75.00, output: 150.00 }, // Example, needs verification
break; "gpt4o": { input: 5.00, output: 20.00 },
case "gpt41": "azure-gpt4o": { input: 5.00, output: 20.00 },
case "azure-gpt41": "gpt4-turbo": { input: 10.00, output: 30.00 },
cost = 0.0000075; "azure-gpt4-turbo": { input: 10.00, output: 30.00 },
// averaged the same wa* as 4.5 "o1-pro": { input: 150.00, output: 600.00 },
break; "azure-o1-pro": { input: 150.00, output: 600.00 },
case "gpt41-mini": "o1": { input: 15.00, output: 60.00 },
case "azure-gpt41-mini": "azure-o1": { input: 15.00, output: 60.00 },
cost = 0.0000015; "o1-mini": { input: 1.10, output: 4.40 },
break; "azure-o1-mini": { input: 1.10, output: 4.40 },
case "gpt41-nano": "o3-mini": { input: 1.10, output: 4.40 },
case "azure-gpt41-nano": "azure-o3-mini": { input: 1.10, output: 4.40 },
cost = 0.0000003; "o3": { input: 10.00, output: 40.00 },
break; "azure-o3": { input: 10.00, output: 40.00 },
case "gpt45": "o4-mini": { input: 1.10, output: 4.40 },
case "azure-gpt45": "azure-o4-mini": { input: 1.10, output: 4.40 },
// $75/$150 for 1M input/output tokens pricing, averaged to $112 "codex-mini": { input: 1.50, output: 6.00 },
cost = 0.00009375; "azure-codex-mini": { input: 1.50, output: 6.00 },
break; "gpt4-32k": { input: 60.00, output: 120.00 },
case "gpt4o": "azure-gpt4-32k": { input: 60.00, output: 120.00 },
case "azure-gpt4o": "gpt4": { input: 30.00, output: 60.00 },
cost = 0.0000075; "azure-gpt4": { input: 30.00, output: 60.00 },
break; "turbo": { input: 0.60, output: 2.40 }, // Maps to GPT-4o mini
case "azure-gpt4-turbo": "azure-turbo": { input: 0.60, output: 2.40 },
case "gpt4-turbo": "dall-e": { input: 0, output: 0 }, // Pricing is per image, not token based in this context.
cost = 0.0000125; "azure-dall-e": { input: 0, output: 0 }, // Pricing is per image.
break; "gpt-image": { input: 0, output: 0 }, // Complex pricing (text, image input, image output tokens), handle separately.
case "azure-o1-pro": "azure-gpt-image": { input: 0, output: 0 }, // Complex pricing.
case "o1-pro": "claude": { input: 3.00, output: 15.00 }, // Anthropic Claude Sonnet 4
// OpenAI o1-pro pricing $150/1M input tokens and $600/1M output tokens "aws-claude": { input: 3.00, output: 15.00 },
cost = 0.0004875; "gcp-claude": { input: 3.00, output: 15.00 },
break; "claude-opus": { input: 15.00, output: 75.00 }, // Anthropic Claude Opus 4
case "azure-o1": "aws-claude-opus": { input: 15.00, output: 75.00 },
case "o1": "gcp-claude-opus": { input: 15.00, output: 75.00 },
// Currently we do not track output tokens separately, and O1 uses "mistral-tiny": { input: 0.04, output: 0.04 }, // Using old price if no new API price found
// considerably more output tokens that other models for its hidden "aws-mistral-tiny": { input: 0.04, output: 0.04 },
// reasoning. The official O1 pricing is $15/1M input tokens and $60/1M "mistral-small": { input: 0.10, output: 0.30 }, // Mistral Small 3.1
// output tokens so we will return a higher estimate here. "aws-mistral-small": { input: 0.10, output: 0.30 },
cost = 0.00004875; "mistral-medium": { input: 0.40, output: 2.00 }, // Mistral Medium 3
break; "aws-mistral-medium": { input: 0.40, output: 2.00 },
case "azure-o1-mini": "mistral-large": { input: 2.00, output: 6.00 },
case "o1-mini": "aws-mistral-large": { input: 2.00, output: 6.00 },
case "azure-o3-mini": "gemini-flash": { input: 0.35, output: 1.05 }, // Gemini 1.5 Flash
case "o3-mini": "gemini-pro": { input: 0.125, output: 0.375 }, // Gemini 1.0 Pro
cost = 0.000003575; // $1.1/1M input tokens, $4.4/1M output tokens "gemini-ultra": { input: 25.00, output: 75.00 }, // Estimated based on Gemini Pro (5-10x) and character to token conversion. Official per-token pricing needed.
break; // Ensure all ModelFamily entries from models.ts are covered or have a default.
case "azure-o3": // Adding placeholders for families in models.ts but not yet priced here.
case "o3": "cohere": { input: 0.25, output: 0.50 }, // Cohere Command R, as an example
cost = 0.000032; // $10/1M input tokens, $40/1M output tokens "qwen": { input: 1.40, output: 2.80 }, // Qwen-plus, as an example
break; };
case "azure-o4-mini":
case "o4-mini": export function getTokenCostDetailsUsd(model: ModelFamily, inputTokens: number, outputTokens?: number): { inputCost: number, outputCost: number, totalCost: number } {
cost = 0.000003575; // $1.1/1M input tokens, $4.4/1M output tokens const pricing = MODEL_PRICING[model];
break;
case "azure-codex-mini": if (!pricing) {
case "codex-mini": console.warn(`Pricing not found for model family: ${model}. Returning 0 cost for all components.`);
// Codex Mini pricing: $1.5/1M input tokens, $6.0/1M output tokens return { inputCost: 0, outputCost: 0, totalCost: 0 };
// Using weighted average for 1:3 input:output ratio
cost = 0.0000045; // Weighted average with output bias
break;
case "azure-gpt4-32k":
case "gpt4-32k":
cost = 0.000075;
break;
case "azure-gpt4":
case "gpt4":
cost = 0.0000375;
break;
case "azure-turbo":
case "turbo":
cost = 0.00000075;
break;
case "azure-dall-e":
case "dall-e":
cost = 0.00001;
break;
case "azure-gpt-image":
case "gpt-image":
// gpt-image-1 pricing:
// Text input tokens: $5 per 1M tokens
// Image input tokens: $10 per 1M tokens
// Image output tokens: $40 per 1M tokens
// Weighted average assuming a mix of text/image input and output
// Typical cost is $0.02-$0.19 per image depending on quality
cost = 0.000018; // Balanced estimate accounting for input/output mix
break;
case "aws-claude":
case "gcp-claude":
case "claude":
cost = 0.00001;
break;
case "aws-claude-opus":
case "gcp-claude-opus":
case "claude-opus":
cost = 0.00003;
break;
case "aws-mistral-tiny":
case "mistral-tiny":
// Using Ministral 3B pricing: $0.04/1M input tokens, $0.04/1M output tokens
// For edge/tiny models, a more balanced 1:1 ratio is used
cost = 0.00000004;
break;
case "aws-mistral-small":
case "mistral-small":
// Using Codestral pricing: $0.3/1M input, $0.9/1M output (highest in category)
// Weighted average for 1:3 input:output ratio
cost = 0.00000075;
break;
case "aws-mistral-medium":
case "mistral-medium":
// Using Mistral Saba pricing: $0.2/1M input, $0.6/1M output
// Weighted average for 1:3 input:output ratio
cost = 0.0000005;
break;
case "aws-mistral-large":
case "mistral-large":
// Using Mistral Large/Pixtral Large pricing: $2/1M input, $6/1M output
// Weighted average for 1:3 input:output ratio
cost = 0.000005;
break;
case "gemini-flash":
cost = 0.0000002326;
break;
case "gemini-pro":
cost = 0.00000344;
break;
} }
return cost * Math.max(0, tokens);
const costPerMillionInputTokens = pricing.input;
const costPerMillionOutputTokens = pricing.output;
const inputCost = (costPerMillionInputTokens / 1_000_000) * Math.max(0, inputTokens);
const outputCost = (costPerMillionOutputTokens / 1_000_000) * Math.max(0, outputTokens ?? 0);
return { inputCost, outputCost, totalCost: inputCost + outputCost };
}
export function getTokenCostUsd(model: ModelFamily, inputTokens: number, outputTokens?: number): number {
return getTokenCostDetailsUsd(model, inputTokens, outputTokens).totalCost;
} }
export function prettyTokens(tokens: number): string { export function prettyTokens(tokens: number): string {
@@ -159,4 +101,4 @@ export function prettyTokens(tokens: number): string {
export function getCostSuffix(cost: number) { export function getCostSuffix(cost: number) {
if (!config.showTokenCosts) return ""; if (!config.showTokenCosts) return "";
return ` ($${cost.toFixed(2)})`; return ` ($${cost.toFixed(2)})`;
} }
+15 -5
View File
@@ -3,11 +3,21 @@ import { MODEL_FAMILIES, ModelFamily } from "../models";
import { makeOptionalPropsNullable } from "../utils"; import { makeOptionalPropsNullable } from "../utils";
// This just dynamically creates a Zod object type with a key for each model // This just dynamically creates a Zod object type with a key for each model
// family and an optional number value. // family and an optional number value for input and output tokens.
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object( export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
MODEL_FAMILIES.reduce( MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }), (acc, family) => ({
{} as Record<ModelFamily, ZodType<number>> ...acc,
[family]: z
.object({
input: z.number().optional().default(0),
output: z.number().optional().default(0),
legacy_total: z.number().optional(), // Added legacy_total
})
.optional()
.default({ input: 0, output: 0 }), // Default will not have legacy_total
}),
{} as Record<ModelFamily, ZodType<{ input: number; output: number; legacy_total?: number }>>
) )
); );
@@ -33,7 +43,7 @@ export const UserSchema = z
* Never used; retained for backwards compatibility. * Never used; retained for backwards compatibility.
*/ */
tokenCount: z.any().optional(), tokenCount: z.any().optional(),
/** Number of tokens the user has consumed, by model family. */ /** Number of input and output tokens the user has consumed, by model family. */
tokenCounts: tokenCountsSchema, tokenCounts: tokenCountsSchema,
/** Maximum number of tokens the user can consume, by model family. */ /** Maximum number of tokens the user can consume, by model family. */
tokenLimits: tokenCountsSchema, tokenLimits: tokenCountsSchema,
@@ -67,7 +77,7 @@ export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
.extend({ token: z.string() }); .extend({ token: z.string() });
export type UserTokenCounts = { export type UserTokenCounts = {
[K in ModelFamily]: number | undefined; [K in ModelFamily]: { input: number; output: number; legacy_total?: number } | undefined;
}; };
export type User = z.infer<typeof UserSchema>; export type User = z.infer<typeof UserSchema>;
export type UserUpdate = z.infer<typeof UserPartialSchema>; export type UserUpdate = z.infer<typeof UserPartialSchema>;
+305 -42
View File
@@ -10,9 +10,11 @@
import admin from "firebase-admin"; import admin from "firebase-admin";
import schedule from "node-schedule"; import schedule from "node-schedule";
import { v4 as uuid } from "uuid"; import { v4 as uuid } from "uuid";
import type { Database } from 'better-sqlite3';
import { config } from "../../config"; import { config } from "../../config";
import { logger } from "../../logger"; import { logger } from "../../logger";
import { getFirebaseApp } from "../firebase"; import { getFirebaseApp } from "../firebase";
import { initSQLiteDB, getDB } from "../sqlite-db"; // Added
import { APIFormat } from "../key-management"; import { APIFormat } from "../key-management";
import { import {
getAwsBedrockModelFamily, getAwsBedrockModelFamily,
@@ -31,9 +33,45 @@ import { User, UserTokenCounts, UserUpdate } from "./schema";
const log = logger.child({ module: "users" }); const log = logger.child({ module: "users" });
const INITIAL_TOKENS: Required<UserTokenCounts> = MODEL_FAMILIES.reduce( const INITIAL_TOKENS: Required<UserTokenCounts> = MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: 0 }), (acc, family) => {
{} as Record<ModelFamily, number> acc[family] = { input: 0, output: 0 }; // legacy_total is undefined by default
); return acc;
},
{} as Record<ModelFamily, { input: number; output: number; legacy_total?: number }>
) as Required<UserTokenCounts>;
/**
 * Converts a persisted token-usage property (tokenCounts / tokenLimits /
 * tokenRefresh) into the current per-family { input, output, legacy_total? }
 * shape. Handles three stored forms per model family, in priority order:
 * the user's DB value, then the config default, then a zeroed fallback.
 */
const migrateTokenCountsProperty = (
  parsedProperty: any, // Data from DB (JSON.parse result for a specific user's property like tokenCounts)
  defaultConfigForProperty: Record<ModelFamily, number | { input: number; output: number; legacy_total?: number } | undefined> // e.g., INITIAL_TOKENS or config.tokenQuota
): UserTokenCounts => {
  // Normalizes a single stored value. A bare number is the legacy "total"
  // format: it becomes input, with the original preserved as legacy_total.
  // An object with numeric input or output passes through with defaults
  // filled in. Anything else is unusable and yields undefined.
  const normalize = (
    value: any
  ): { input: number; output: number; legacy_total?: number } | undefined => {
    if (typeof value === 'number') {
      return { input: value, output: 0, legacy_total: value };
    }
    const isNewShape =
      typeof value === 'object' &&
      value !== null &&
      (typeof value.input === 'number' || typeof value.output === 'number');
    if (isNewShape) {
      return {
        input: value.input ?? 0,
        output: value.output ?? 0,
        legacy_total: value.legacy_total,
      };
    }
    return undefined;
  };

  const migrated = {} as UserTokenCounts;
  for (const family of MODEL_FAMILIES) {
    migrated[family] =
      normalize(parsedProperty?.[family]) ??
      normalize(defaultConfigForProperty[family]) ??
      { input: 0, output: 0 }; // Ultimate fallback; no legacy_total here.
  }
  return migrated;
};
const users: Map<string, User> = new Map(); const users: Map<string, User> = new Map();
const usersToFlush = new Set<string>(); const usersToFlush = new Set<string>();
@@ -44,6 +82,8 @@ export async function init() {
log.info({ store: config.gatekeeperStore }, "Initializing user store..."); log.info({ store: config.gatekeeperStore }, "Initializing user store...");
if (config.gatekeeperStore === "firebase_rtdb") { if (config.gatekeeperStore === "firebase_rtdb") {
await initFirebase(); await initFirebase();
} else if (config.gatekeeperStore === "sqlite") {
await initSQLite(); // Added
} }
if (config.quotaRefreshPeriod) { if (config.quotaRefreshPeriod) {
const crontab = getRefreshCrontab(); const crontab = getRefreshCrontab();
@@ -80,9 +120,14 @@ export function createUser(createOptions?: {
ip: [], ip: [],
type: "normal", type: "normal",
promptCount: 0, promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS }, tokenCounts: { ...INITIAL_TOKENS }, // New counts don't have legacy_total
tokenLimits: createOptions?.tokenLimits ?? { ...config.tokenQuota }, tokenLimits: createOptions?.tokenLimits ?? MODEL_FAMILIES.reduce((acc, family) => {
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS }, const quota = config.tokenQuota[family];
// If quota is a number, it's a legacy total limit, store it as such
acc[family] = typeof quota === 'number' ? { input: quota, output: 0, legacy_total: quota } : (quota || { input: 0, output: 0 });
return acc;
}, {} as UserTokenCounts),
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS }, // Refresh amounts typically start fresh
createdAt: Date.now(), createdAt: Date.now(),
meta: {}, meta: {},
}; };
@@ -125,9 +170,14 @@ export function upsertUser(user: UserUpdate) {
ip: [], ip: [],
type: "normal", type: "normal",
promptCount: 0, promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS }, tokenCounts: { ...INITIAL_TOKENS }, // New counts don't have legacy_total
tokenLimits: { ...config.tokenQuota }, tokenLimits: MODEL_FAMILIES.reduce((acc, family) => {
tokenRefresh: { ...INITIAL_TOKENS }, const quota = config.tokenQuota[family];
// If quota is a number, it's a legacy total limit, store it as such
acc[family] = typeof quota === 'number' ? { input: quota, output: 0, legacy_total: quota } : (quota || { input: 0, output: 0 });
return acc;
}, {} as UserTokenCounts),
tokenRefresh: { ...INITIAL_TOKENS }, // Refresh amounts typically start fresh
createdAt: Date.now(), createdAt: Date.now(),
meta: {}, meta: {},
}; };
@@ -146,21 +196,37 @@ export function upsertUser(user: UserUpdate) {
if (updates.tokenCounts) { if (updates.tokenCounts) {
for (const family of MODEL_FAMILIES) { for (const family of MODEL_FAMILIES) {
updates.tokenCounts[family] ??= 0; updates.tokenCounts[family] ??= { input: 0, output: 0 };
// The property is now guaranteed to be an object, so the 'number' check is removed.
// Defaulting individual fields if they are missing.
const counts = updates.tokenCounts[family]!; // Should not be undefined here
counts.input ??= 0;
counts.output ??= 0;
// legacy_total is optional and not defaulted here if missing
} }
} }
if (updates.tokenLimits) { if (updates.tokenLimits) {
for (const family of MODEL_FAMILIES) { for (const family of MODEL_FAMILIES) {
updates.tokenLimits[family] ??= 0; updates.tokenLimits[family] ??= { input: 0, output: 0 };
// The property is now guaranteed to be an object, so the 'number' check is removed.
// Defaulting individual fields if they are missing.
const limits = updates.tokenLimits[family]!; // Should not be undefined here
limits.input ??= 0;
limits.output ??= 0;
// legacy_total is optional and not defaulted here if missing
} }
} }
// tokenRefresh is a special case where we want to merge the existing and // tokenRefresh is a special case where we want to merge the existing and
// updated values for each model family, ignoring falsy values. // updated values for each model family, ignoring falsy values.
if (updates.tokenRefresh) { if (updates.tokenRefresh) {
const merged = { ...existing.tokenRefresh }; const merged = { ...existing.tokenRefresh } as UserTokenCounts;
for (const family of MODEL_FAMILIES) { for (const family of MODEL_FAMILIES) {
merged[family] = const updateRefresh = updates.tokenRefresh[family];
updates.tokenRefresh[family] || existing.tokenRefresh[family]; const existingRefresh = existing.tokenRefresh[family];
merged[family] = {
input: (updateRefresh?.input || existingRefresh?.input) ?? 0,
output: (updateRefresh?.output || existingRefresh?.output) ?? 0,
};
} }
updates.tokenRefresh = merged; updates.tokenRefresh = merged;
} }
@@ -168,9 +234,11 @@ export function upsertUser(user: UserUpdate) {
users.set(user.token, Object.assign(existing, updates)); users.set(user.token, Object.assign(existing, updates));
usersToFlush.add(user.token); usersToFlush.add(user.token);
// Immediately schedule a flush to the database if we're using Firebase. // Immediately schedule a flush to the database if a persistent store is used.
if (config.gatekeeperStore === "firebase_rtdb") { if (config.gatekeeperStore === "firebase_rtdb") {
setImmediate(flushUsers); setImmediate(flushUsers);
} else if (config.gatekeeperStore === "sqlite") {
setImmediate(flushUsersToSQLite);
} }
return users.get(user.token); return users.get(user.token);
@@ -189,13 +257,16 @@ export function incrementTokenCount(
token: string, token: string,
model: string, model: string,
api: APIFormat, api: APIFormat,
consumption: number consumption: { input: number; output: number }
) { ) {
const user = users.get(token); const user = users.get(token);
if (!user) return; if (!user) return;
const modelFamily = getModelFamilyForQuotaUsage(model, api); const modelFamily = getModelFamilyForQuotaUsage(model, api);
const existing = user.tokenCounts[modelFamily] ?? 0; const existingCounts = user.tokenCounts[modelFamily] ?? { input: 0, output: 0 };
user.tokenCounts[modelFamily] = existing + consumption; user.tokenCounts[modelFamily] = {
input: (existingCounts.input ?? 0) + consumption.input,
output: (existingCounts.output ?? 0) + consumption.output,
};
usersToFlush.add(token); usersToFlush.add(token);
} }
@@ -251,12 +322,36 @@ export function hasAvailableQuota({
const modelFamily = getModelFamilyForQuotaUsage(model, api); const modelFamily = getModelFamilyForQuotaUsage(model, api);
const { tokenCounts, tokenLimits } = user; const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily]; const limitConfig = tokenLimits[modelFamily];
const currentUsage = tokenCounts[modelFamily] ?? { input: 0, output: 0 };
if (!tokenLimit) return true; // If no specific limit object for the family, or if it's essentially unlimited (e.g. input/output are 0 or not set)
// fall back to checking config.tokenQuota which is a number (total limit).
if (!limitConfig || (limitConfig.input === 0 && limitConfig.output === 0 && !config.tokenQuota[modelFamily])) {
return true; // No effective limit
}
const tokensConsumed = (tokenCounts[modelFamily] ?? 0) + requested; let effectiveLimit: number;
return tokensConsumed < tokenLimit; if (limitConfig && (limitConfig.input > 0 || limitConfig.output > 0)) {
// If a specific limit object exists and has positive values, sum them.
// This assumes the limit is a total limit. If input/output are separate, this logic needs change.
effectiveLimit = (limitConfig.input ?? Number.MAX_SAFE_INTEGER) + (limitConfig.output ?? Number.MAX_SAFE_INTEGER);
} else {
// Fallback to general numeric quota from config if specific limitObj is not effectively set.
const generalQuota = config.tokenQuota[modelFamily];
if (typeof generalQuota === 'number' && generalQuota > 0) {
effectiveLimit = generalQuota;
} else {
return true; // No limit defined
}
}
// Assuming 'requested' is for input tokens. If 'requested' can be input or output,
// this needs to be an object {input: number, output: number}.
// For now, we sum current input & output and add 'requested' to input for checking.
// This is a simplification. A more robust solution would involve 'requested' being an object.
const totalConsumed = (currentUsage.input ?? 0) + (currentUsage.output ?? 0) + requested;
return totalConsumed < effectiveLimit;
} }
/** /**
@@ -270,18 +365,33 @@ export function refreshQuota(token: string) {
const { tokenQuota } = config; const { tokenQuota } = config;
const { tokenCounts, tokenLimits, tokenRefresh } = user; const { tokenCounts, tokenLimits, tokenRefresh } = user;
// Get default quotas for each model family. for (const family of MODEL_FAMILIES) {
const defaultQuotas = Object.entries(tokenQuota) as [ModelFamily, number][]; const currentUsage = tokenCounts[family] ?? { input: 0, output: 0 };
// If any user-specific refresh quotas are present, override default quotas. const userRefreshConfig = tokenRefresh[family] ?? { input: 0, output: 0 };
const userQuotas = defaultQuotas.map( const globalDefaultQuotaValue = config.tokenQuota[family]; // This is a number or undefined
([f, q]) => [f, (tokenRefresh[f] ?? 0) || q] as const /* narrow to tuple */
);
userQuotas let refreshInputAmount = 0;
// Ignore families with no global or user-specific refresh quota. let refreshOutputAmount = 0;
.filter(([, q]) => q > 0)
// Increase family token limit by the family's refresh amount. // Prioritize user-specific refresh amounts if they are positive
.forEach(([f, q]) => (tokenLimits[f] = (tokenCounts[f] ?? 0) + q)); if (userRefreshConfig.input > 0 || userRefreshConfig.output > 0) {
refreshInputAmount = userRefreshConfig.input;
refreshOutputAmount = userRefreshConfig.output;
} else if (typeof globalDefaultQuotaValue === 'number' && globalDefaultQuotaValue > 0) {
// If no user-specific refresh, use the global quota.
// Distribute the global quota. For simplicity, add to input, or define a rule.
// Here, let's assume the global quota is a total that primarily refreshes 'input'.
refreshInputAmount = globalDefaultQuotaValue;
refreshOutputAmount = 0; // Or some portion of globalDefaultQuotaValue
}
if (refreshInputAmount > 0 || refreshOutputAmount > 0) {
tokenLimits[family] = {
input: (currentUsage.input ?? 0) + refreshInputAmount,
output: (currentUsage.output ?? 0) + refreshOutputAmount,
};
}
}
usersToFlush.add(token); usersToFlush.add(token);
} }
@@ -289,8 +399,9 @@ export function resetUsage(token: string) {
const user = users.get(token); const user = users.get(token);
if (!user) return; if (!user) return;
const { tokenCounts } = user; const { tokenCounts } = user;
const counts = Object.entries(tokenCounts) as [ModelFamily, number][]; for (const family of MODEL_FAMILIES) {
counts.forEach(([model]) => (tokenCounts[model] = 0)); tokenCounts[family] = { input: 0, output: 0 }; // legacy_total is implicitly undefined/removed
}
usersToFlush.add(token); usersToFlush.add(token);
} }
@@ -359,26 +470,56 @@ function refreshAllQuotas() {
// store to sync it with Firebase when it changes. Will refactor to abstract // store to sync it with Firebase when it changes. Will refactor to abstract
// persistence layer later so we can support multiple stores. // persistence layer later so we can support multiple stores.
let firebaseTimeout: NodeJS.Timeout | undefined; let firebaseTimeout: NodeJS.Timeout | undefined;
let sqliteInterval: NodeJS.Timeout | undefined; // Added
let flushingToSQLiteInProgress = false; // Added for JS-level lock
const USERS_REF = process.env.FIREBASE_USERS_REF_NAME ?? "users"; const USERS_REF = process.env.FIREBASE_USERS_REF_NAME ?? "users";
/**
 * Bootstraps the SQLite-backed user store: opens the database, hydrates the
 * in-memory user map, then schedules periodic write-backs of dirty users.
 */
async function initSQLite() {
  log.info("Initializing SQLite user store...");
  initSQLiteDB(); // Opens the connection and creates/verifies the schema.
  await loadUsersFromSQLite();
  // Mirror the Firebase store's behavior: flush pending changes every 20s.
  const FLUSH_PERIOD_MS = 20_000;
  sqliteInterval = setInterval(flushUsersToSQLite, FLUSH_PERIOD_MS);
  log.info("SQLite user store initialized and users loaded.");
}
async function initFirebase() { async function initFirebase() {
log.info("Connecting to Firebase..."); log.info("Connecting to Firebase...");
const app = getFirebaseApp(); const app = getFirebaseApp();
const db = admin.database(app); const db = admin.database(app);
const usersRef = db.ref(USERS_REF); const usersRef = db.ref(USERS_REF);
const snapshot = await usersRef.once("value"); const snapshot = await usersRef.once("value");
const users: Record<string, User> | null = snapshot.val(); const usersData: Record<string, any> | null = snapshot.val(); // Store as 'any' initially for migration
firebaseTimeout = setInterval(flushUsers, 20 * 1000); firebaseTimeout = setInterval(flushUsers, 20 * 1000);
if (!users) {
if (!usersData) {
log.info("No users found in Firebase."); log.info("No users found in Firebase.");
return; return;
} }
for (const token in users) {
upsertUser(users[token]); // migrateTokenCountsProperty is now defined at module scope
for (const token in usersData) {
const rawUser = usersData[token];
const migratedUser: User = {
...rawUser, // Spread existing fields
token: rawUser.token || token, // Ensure token is present
ip: rawUser.ip || [],
type: rawUser.type || "normal",
promptCount: rawUser.promptCount || 0,
createdAt: rawUser.createdAt || Date.now(),
// Migrate token fields
tokenCounts: migrateTokenCountsProperty(rawUser.tokenCounts, INITIAL_TOKENS),
tokenLimits: migrateTokenCountsProperty(rawUser.tokenLimits, config.tokenQuota),
tokenRefresh: migrateTokenCountsProperty(rawUser.tokenRefresh, INITIAL_TOKENS),
meta: rawUser.meta || {},
};
// Use the internal map directly to avoid re-triggering upsertUser's default creations
users.set(token, migratedUser);
} }
usersToFlush.clear(); usersToFlush.clear(); // Clear flush queue after initial load and migration
const numUsers = Object.keys(users).length; const numUsers = Object.keys(usersData).length;
log.info({ users: numUsers }, "Loaded users from Firebase"); log.info({ users: numUsers }, "Loaded and migrated users from Firebase");
} }
async function flushUsers() { async function flushUsers() {
@@ -412,6 +553,128 @@ async function flushUsers() {
); );
} }
/**
 * Hydrates the in-memory `users` map from the SQLite `users` table.
 *
 * JSON columns are parsed defensively: a row with a corrupt JSON column is
 * logged and given defaults for that column instead of aborting the entire
 * user-store load at startup (the previous unguarded JSON.parse would throw
 * and prevent every user from loading).
 */
async function loadUsersFromSQLite() {
  log.info("Loading users from SQLite...");
  const db = getDB();
  const rows = db.prepare("SELECT * FROM users").all() as any[];

  // Parses a JSON TEXT column; returns null for empty or malformed values.
  const parseJsonColumn = (value: unknown, column: string, token: string): any => {
    if (!value) return null;
    try {
      return JSON.parse(value as string);
    } catch (error) {
      log.error({ token, column, error }, "Corrupt JSON column in SQLite user row; using defaults.");
      return null;
    }
  };

  for (const row of rows) {
    const rawTokenCounts = parseJsonColumn(row.tokenCounts, "tokenCounts", row.token);
    const rawTokenLimits = parseJsonColumn(row.tokenLimits, "tokenLimits", row.token);
    const rawTokenRefresh = parseJsonColumn(row.tokenRefresh, "tokenRefresh", row.token);

    const user: User = {
      token: row.token,
      ip: parseJsonColumn(row.ip, "ip", row.token) ?? [],
      nickname: row.nickname,
      type: row.type,
      promptCount: row.promptCount,
      // Legacy numeric token totals are migrated to { input, output } objects.
      tokenCounts: migrateTokenCountsProperty(rawTokenCounts, INITIAL_TOKENS),
      tokenLimits: migrateTokenCountsProperty(rawTokenLimits, config.tokenQuota),
      tokenRefresh: migrateTokenCountsProperty(rawTokenRefresh, INITIAL_TOKENS),
      createdAt: row.createdAt,
      lastUsedAt: row.lastUsedAt,
      disabledAt: row.disabledAt,
      disabledReason: row.disabledReason,
      expiresAt: row.expiresAt,
      maxIps: row.maxIps,
      adminNote: row.adminNote,
      meta: parseJsonColumn(row.meta, "meta", row.token) ?? {},
    };
    users.set(user.token, user);
  }
  usersToFlush.clear(); // Initial load is not "dirty"; nothing needs writing back.
  log.info({ users: users.size }, "Loaded users from SQLite.");
}
/**
 * Persists all pending user changes (queued in `usersToFlush`) to SQLite.
 *
 * Tokens still present in the in-memory `users` map are upserted via
 * INSERT OR REPLACE; tokens no longer in the map are treated as deletions.
 * All writes happen inside a single better-sqlite3 transaction, so either
 * every queued change lands or none does. On failure the pending tokens
 * remain in `usersToFlush` and will be retried on the next flush cycle.
 *
 * Re-entrancy is prevented with the `flushingToSQLiteInProgress` flag.
 */
async function flushUsersToSQLite() { // Added
  if (flushingToSQLiteInProgress) {
    log.trace("Flush to SQLite already in progress, skipping.");
    return;
  }
  if (usersToFlush.size === 0) {
    return;
  }
  flushingToSQLiteInProgress = true;
  log.trace({ count: usersToFlush.size }, "Starting flush to SQLite.");
  const db = getDB();
  const insertStmt = db.prepare(`
    INSERT OR REPLACE INTO users (
      token, ip, nickname, type, promptCount, tokenCounts, tokenLimits,
      tokenRefresh, createdAt, lastUsedAt, disabledAt, disabledReason,
      expiresAt, maxIps, adminNote, meta
    ) VALUES (
      @token, @ip, @nickname, @type, @promptCount, @tokenCounts, @tokenLimits,
      @tokenRefresh, @createdAt, @lastUsedAt, @disabledAt, @disabledReason,
      @expiresAt, @maxIps, @adminNote, @meta
    )
  `);
  const deleteStmt = db.prepare("DELETE FROM users WHERE token = ?");
  let updatedCount = 0;
  let deletedCount = 0;
  // Snapshot the queue so that, on success, we only remove the tokens we
  // actually wrote. A blanket clear() would silently drop any token that
  // gets enqueued mid-flush (a latent lost-update hazard if this code ever
  // becomes truly asynchronous).
  const pendingTokens = [...usersToFlush];
  const transaction = db.transaction(() => {
    for (const token of pendingTokens) {
      const user = users.get(token);
      if (user) {
        insertStmt.run({
          token: user.token,
          ip: JSON.stringify(user.ip || []),
          nickname: user.nickname ?? null,
          type: user.type,
          promptCount: user.promptCount,
          tokenCounts: JSON.stringify(user.tokenCounts || INITIAL_TOKENS),
          tokenLimits: JSON.stringify(user.tokenLimits || migrateTokenCountsProperty(null, config.tokenQuota)),
          tokenRefresh: JSON.stringify(user.tokenRefresh || INITIAL_TOKENS),
          createdAt: user.createdAt,
          lastUsedAt: user.lastUsedAt ?? null,
          disabledAt: user.disabledAt ?? null,
          disabledReason: user.disabledReason ?? null,
          expiresAt: user.expiresAt ?? null,
          maxIps: user.maxIps ?? null,
          adminNote: user.adminNote ?? null,
          meta: JSON.stringify(user.meta || {}),
        });
        updatedCount++;
      } else {
        // User was deleted from in-memory map
        deleteStmt.run(token);
        deletedCount++;
      }
    }
  });
  try {
    transaction();
    // Only remove the tokens that were part of this (successful) flush;
    // anything enqueued since the snapshot stays queued for the next run.
    for (const token of pendingTokens) {
      usersToFlush.delete(token);
    }
    if (updatedCount > 0 || deletedCount > 0) {
      log.info({ updated: updatedCount, deleted: deletedCount }, "Flushed user changes to SQLite.");
    }
  } catch (error: any) {
    log.error({
      message: error?.message || "Unknown error during SQLite flush",
      stack: error?.stack,
      code: error?.code, // SQLite errors often have a code
      rawError: error // Log the raw error object for more details
    }, "Error flushing users to SQLite.");
    // The transaction is atomic: on failure nothing was written and the
    // pending tokens were never removed from usersToFlush, so the next
    // flush attempt will retry them automatically. (The previous version
    // contained a no-op `usersToFlush.forEach(t => usersToFlush.add(t))`
    // "retry" here — re-adding members already in a Set does nothing.)
  } finally {
    flushingToSQLiteInProgress = false;
    log.trace("Finished flush to SQLite attempt.");
  }
}
function getModelFamilyForQuotaUsage( function getModelFamilyForQuotaUsage(
model: string, model: string,
api: APIFormat api: APIFormat
@@ -22,23 +22,64 @@ const quotaTableId = Math.random().toString(36).slice(2);
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
<% Object.entries(quota).forEach(([key, limit]) => { %> <% Object.entries(quota).forEach(([key, configLimit]) => { %>
<%
const counts = user.tokenCounts[key] || { input: 0, output: 0 };
const limits = user.tokenLimits[key] || { input: 0, output: 0 }; // Default if not set
const refresh = user.tokenRefresh[key] || { input: 0, output: 0 };
const usageInput = Number(counts.input) || 0;
const usageOutput = Number(counts.output) || 0;
const usageLegacy = Number(counts.legacy_total) || 0;
const displayUsage = usageInput + usageOutput || usageLegacy; // This is for total token display, not directly for cost calculation here
const limitInput = Number(limits.input) || 0;
// If limit was from legacy config.tokenQuota (a number), it's in limits.legacy_total or limits.input
const displayLimit = limitInput || Number(limits.legacy_total) || 0;
// Determine tokens to use for cost calculation
const costInputTokens = (usageInput + usageOutput > 0) ? usageInput : usageLegacy;
const costOutputTokens = (usageInput + usageOutput > 0) ? usageOutput : 0; // If using legacy, output is 0 for cost
const costDetails = tokenCostDetails(key, costInputTokens, costOutputTokens);
let remaining = 0;
let limitIsSet = false;
if (displayLimit > 0) {
remaining = displayLimit - (usageInput + usageOutput);
limitIsSet = true;
} else if (typeof configLimit === 'number' && configLimit > 0) {
// Fallback to global config limit if user-specific limit is 0 or not set meaningfully
remaining = configLimit - (usageInput + usageOutput);
limitIsSet = true;
}
const refreshDisplayValue = (Number(refresh.input) || 0) + (Number(refresh.output) || 0) || configLimit || 0;
%>
<tr> <tr>
<th scope="row"><%- key %></th> <th scope="row"><%- key %></th>
<td><%- prettyTokens(user.tokenCounts[key]) %></td> <td>
In: <%- prettyTokens(usageInput) %><br/>
Out: <%- prettyTokens(usageOutput) %>
<% if (usageLegacy && (usageInput + usageOutput === 0)) { %><br/>(Legacy: <%- prettyTokens(usageLegacy) %>)<% } %>
</td>
<% if (showTokenCosts) { %> <% if (showTokenCosts) { %>
<td>$<%- tokenCost(key, user.tokenCounts[key]).toFixed(2) %></td> <td>
In: $<%- costDetails.inputCost.toFixed(Math.max(2, (costDetails.inputCost.toString().split('.')[1] || '').length)) %><br/>
Out: $<%- costDetails.outputCost.toFixed(Math.max(2, (costDetails.outputCost.toString().split('.')[1] || '').length)) %><br/>
Total: $<%- costDetails.totalCost.toFixed(2) %>
</td>
<% } %> <% } %>
<% if (!user.tokenLimits[key]) { %> <% if (!limitIsSet) { %>
<td colspan="2" style="text-align: center">unlimited</td> <td colspan="2" style="text-align: center">unlimited</td>
<% } else { %> <% } else { %>
<td><%- prettyTokens(user.tokenLimits[key]) %></td> <td><%- prettyTokens(displayLimit) %></td>
<td><%- prettyTokens(user.tokenLimits[key] - user.tokenCounts[key]) %></td> <td><%- prettyTokens(remaining) %></td>
<% } %> <% } %>
<% if (user.type === "temporary") { %> <% if (user.type === "temporary") { %>
<td>N/A</td> <td>N/A</td>
<% } else { %> <% } else { %>
<td><%- prettyTokens(user.tokenRefresh[key] || quota[key]) %></td> <td><%- prettyTokens(refreshDisplayValue) %></td>
<% } %> <% } %>
<% if (showRefreshEdit) { %> <% if (showRefreshEdit) { %>
<td class="actions"> <td class="actions">