Implement user persistence via Firebase (khanon/oai-reverse-proxy!8)
This commit is contained in:
@@ -1,204 +0,0 @@
|
||||
# Shat out by GPT-4, I did not check for correctness beyond a cursory glance
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: User Management API
|
||||
paths:
|
||||
/admin/users:
|
||||
get:
|
||||
summary: List all users
|
||||
operationId: getUsers
|
||||
responses:
|
||||
"200":
|
||||
description: A list of users
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
users:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/User"
|
||||
count:
|
||||
type: integer
|
||||
format: int32
|
||||
post:
|
||||
summary: Create a new user
|
||||
operationId: createUser
|
||||
responses:
|
||||
"200":
|
||||
description: The created user's token
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
put:
|
||||
summary: Bulk upsert users
|
||||
operationId: bulkUpsertUsers
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
users:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/User"
|
||||
responses:
|
||||
"200":
|
||||
description: The upserted users
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
upserted_users:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/User"
|
||||
count:
|
||||
type: integer
|
||||
format: int32
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
|
||||
/admin/users/{token}:
|
||||
get:
|
||||
summary: Get a user by token
|
||||
operationId: getUser
|
||||
parameters:
|
||||
- name: token
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: A user
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/User"
|
||||
"404":
|
||||
description: Not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
put:
|
||||
summary: Update a user by token
|
||||
operationId: upsertUser
|
||||
parameters:
|
||||
- name: token
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/User"
|
||||
responses:
|
||||
"200":
|
||||
description: The updated user
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/User"
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
delete:
|
||||
summary: Disables the user with the given token
|
||||
description: Optionally accepts a `disabledReason` query parameter. Returns the disabled user.
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The token of the user to disable
|
||||
- in: query
|
||||
name: disabledReason
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
description: The reason for disabling the user
|
||||
responses:
|
||||
'200':
|
||||
description: The disabled user
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
'400':
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
'404':
|
||||
description: Not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
components:
|
||||
schemas:
|
||||
User:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
ip:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
enum: ["normal", "special"]
|
||||
promptCount:
|
||||
type: integer
|
||||
format: int32
|
||||
tokenCount:
|
||||
type: integer
|
||||
format: int32
|
||||
createdAt:
|
||||
type: integer
|
||||
format: int64
|
||||
lastUsedAt:
|
||||
type: integer
|
||||
format: int64
|
||||
disabledAt:
|
||||
type: integer
|
||||
format: int64
|
||||
disabledReason:
|
||||
type: string
|
||||
+69
-9
@@ -1,4 +1,5 @@
|
||||
import dotenv from "dotenv";
|
||||
import type firebase from "firebase-admin";
|
||||
dotenv.config();
|
||||
|
||||
const isDev = process.env.NODE_ENV !== "production";
|
||||
@@ -17,7 +18,7 @@ type Config = {
|
||||
**/
|
||||
proxyKey?: string;
|
||||
/**
|
||||
* The admin key to used for accessing the /admin API. Required if the user
|
||||
* The admin key used to access the /admin API. Required if the user
|
||||
* management mode is set to 'user_token'.
|
||||
**/
|
||||
adminKey?: string;
|
||||
@@ -35,6 +36,19 @@ type Config = {
|
||||
* Configure this function and add users via the /admin API.
|
||||
*/
|
||||
gatekeeper: "none" | "proxy_key" | "user_token";
|
||||
/**
|
||||
* Persistence layer to use for user management.
|
||||
*
|
||||
* `memory`: Users are stored in memory and are lost on restart (default)
|
||||
*
|
||||
* `firebase_rtdb`: Users are stored in a Firebase Realtime Database; requires
|
||||
* `firebaseKey` and `firebaseRtdbUrl` to be set.
|
||||
**/
|
||||
gatekeeperStore: "memory" | "firebase_rtdb";
|
||||
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
|
||||
firebaseRtdbUrl?: string;
|
||||
/** Base64-encoded Firebase service account key if using the Firebase RTDB store. */
|
||||
firebaseKey?: string;
|
||||
/** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
|
||||
modelRateLimit: number;
|
||||
/** Max number of tokens to generate. Requests which specify a higher value will be rewritten to use this value. */
|
||||
@@ -58,21 +72,21 @@ type Config = {
|
||||
/**
|
||||
* How to display quota information on the info page.
|
||||
*
|
||||
* `none` - Hide quota information
|
||||
* `none`: Hide quota information
|
||||
*
|
||||
* `partial` - Display quota information only as a percentage
|
||||
* `partial`: Display quota information only as a percentage
|
||||
*
|
||||
* `full` - Display quota information as usage against total capacity
|
||||
* `full`: Display quota information as usage against total capacity
|
||||
*/
|
||||
quotaDisplayMode: "none" | "partial" | "full";
|
||||
/**
|
||||
* Which request queueing strategy to use when keys are over their rate limit.
|
||||
*
|
||||
* `fair` - Requests are serviced in the order they were received (default)
|
||||
* `fair`: Requests are serviced in the order they were received (default)
|
||||
*
|
||||
* `random` - Requests are serviced randomly
|
||||
* `random`: Requests are serviced randomly
|
||||
*
|
||||
* `none` - Requests are not queued and users have to retry manually
|
||||
* `none`: Requests are not queued and users have to retry manually
|
||||
*/
|
||||
queueMode: DequeueMode;
|
||||
};
|
||||
@@ -85,6 +99,9 @@ export const config: Config = {
|
||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
||||
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
|
||||
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
|
||||
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
|
||||
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
|
||||
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
|
||||
maxOutputTokens: getEnvWithDefault("MAX_OUTPUT_TOKENS", 300),
|
||||
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
|
||||
@@ -106,7 +123,7 @@ export const config: Config = {
|
||||
} as const;
|
||||
|
||||
/** Prevents the server from starting if config state is invalid. */
|
||||
export function assertConfigIsValid(): void {
|
||||
export async function assertConfigIsValid() {
|
||||
// Ensure gatekeeper mode is valid.
|
||||
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
|
||||
throw new Error(
|
||||
@@ -134,12 +151,29 @@ export function assertConfigIsValid(): void {
|
||||
"`PROXY_KEY` is set, but gatekeeper mode is not `proxy_key`. Make sure to set `GATEKEEPER=proxy_key`."
|
||||
);
|
||||
}
|
||||
|
||||
// Require appropriate firebase config if using firebase store.
|
||||
if (
|
||||
config.gatekeeperStore === "firebase_rtdb" &&
|
||||
(!config.firebaseKey || !config.firebaseRtdbUrl)
|
||||
) {
|
||||
throw new Error(
|
||||
"Firebase RTDB store requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
|
||||
);
|
||||
}
|
||||
|
||||
await maybeInitializeFirebase();
|
||||
}
|
||||
|
||||
/** Masked, but not omitted as users may wish to see if they're set. */
|
||||
/**
|
||||
* Masked, but not omitted as users may wish to see if they're set due to their
|
||||
* implications on privacy.
|
||||
*/
|
||||
export const SENSITIVE_KEYS: (keyof Config)[] = [
|
||||
"googleSheetsKey",
|
||||
"googleSheetsSpreadsheetId",
|
||||
"firebaseRtdbUrl",
|
||||
"firebaseKey",
|
||||
];
|
||||
|
||||
/** Omitted as they're not useful to display, masked or not. */
|
||||
@@ -184,3 +218,29 @@ function getEnvWithDefault<T>(name: string, defaultValue: T): T {
|
||||
return value as unknown as T;
|
||||
}
|
||||
}
|
||||
|
||||
let firebaseApp: firebase.app.App | undefined;
|
||||
|
||||
/**
 * Decodes the base64-encoded service account key into the parsed JSON object
 * expected by `firebase.credential.cert`.
 *
 * @throws Error if the decoded text is not valid JSON. The parser's own message
 *   is deliberately not forwarded because it may quote part of the decoded key.
 */
function parseFirebaseKey(encodedKey: string) {
  const decodedKey = Buffer.from(encodedKey, "base64").toString();
  try {
    return JSON.parse(decodedKey);
  } catch {
    throw new Error(
      "`FIREBASE_KEY` could not be parsed. It must be the base64-encoded JSON of a Firebase service account key."
    );
  }
}

/**
 * Initializes the Firebase Admin app when the configured gatekeeper store
 * name starts with "firebase"; returns immediately for any other store.
 *
 * The app is created from the decoded `firebaseKey` and `firebaseRtdbUrl`,
 * then a timestamp is written to the `connection-test` ref. Only after that
 * write succeeds is the app stored in `firebaseApp`, which is what
 * `getFirebaseApp` returns.
 *
 * @throws Error if `firebaseKey` is unset or cannot be parsed. Errors raised
 *   while creating the app or performing the test write propagate unchanged.
 */
async function maybeInitializeFirebase() {
  if (!config.gatekeeperStore.startsWith("firebase")) {
    return;
  }

  // Guard explicitly instead of asserting non-null, so a missing key produces
  // an actionable message even if this is reached without prior validation.
  if (!config.firebaseKey) {
    throw new Error("`FIREBASE_KEY` must be set to use a Firebase store.");
  }
  const serviceAccount = parseFirebaseKey(config.firebaseKey);

  const firebase = await import("firebase-admin");
  const app = firebase.initializeApp({
    credential: firebase.credential.cert(serviceAccount),
    databaseURL: config.firebaseRtdbUrl,
  });

  // Test write: fails here if the credentials or database URL are unusable.
  await app.database().ref("connection-test").set(Date.now());

  firebaseApp = app;
}
|
||||
|
||||
/**
 * Returns the Firebase app stored in `firebaseApp`.
 *
 * @throws Error if no app has been stored yet.
 */
export function getFirebaseApp(): firebase.app.App {
  if (firebaseApp) {
    return firebaseApp;
  }
  throw new Error("Firebase app not initialized.");
}
|
||||
|
||||
@@ -41,7 +41,7 @@ export class KeyChecker {
|
||||
}
|
||||
|
||||
public start() {
|
||||
this.log.info("Starting key checker");
|
||||
this.log.info("Starting key checker...");
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
/**
|
||||
* Basic user management. Handles creation and tracking of proxy users, personal
|
||||
* access tokens, and quota management. No persistence is provided, users must
|
||||
* be re-created on each proxy start via the /admin API.
|
||||
* access tokens, and quota management. Supports in-memory and Firebase Realtime
|
||||
* Database persistence stores.
|
||||
*
|
||||
* Users are identified solely by their personal access token. The token is
|
||||
* used to authenticate the user for all proxied requests.
|
||||
*/
|
||||
|
||||
import admin from "firebase-admin";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
|
||||
export interface User {
|
||||
/** The user's personal access token. */
|
||||
@@ -41,6 +44,15 @@ export type UserType = "normal" | "special";
|
||||
type UserUpdate = Partial<User> & Pick<User, "token">;
|
||||
|
||||
const users: Map<string, User> = new Map();
|
||||
const usersToFlush = new Set<string>();
|
||||
|
||||
/**
 * Prepares the user store selected by `config.gatekeeperStore`.
 *
 * Only the `firebase_rtdb` store has an initialization step (`initFirebase`);
 * for any other store this just logs.
 */
export async function init() {
  const store = config.gatekeeperStore;
  logger.info({ store }, "Initializing user store...");

  const usesFirebase = store === "firebase_rtdb";
  if (usesFirebase) {
    await initFirebase();
  }

  logger.info("User store initialized.");
}
|
||||
|
||||
/** Creates a new user and returns their token. */
|
||||
export function createUser() {
|
||||
@@ -84,6 +96,13 @@ export function upsertUser(user: UserUpdate) {
|
||||
...existing,
|
||||
...user,
|
||||
});
|
||||
usersToFlush.add(user.token);
|
||||
|
||||
// Immediately schedule a flush to the database if we're using Firebase.
|
||||
if (config.gatekeeperStore === "firebase_rtdb") {
|
||||
setImmediate(flushUsers);
|
||||
}
|
||||
|
||||
return users.get(user.token);
|
||||
}
|
||||
|
||||
@@ -92,6 +111,7 @@ export function incrementPromptCount(token: string) {
|
||||
const user = users.get(token);
|
||||
if (!user) return;
|
||||
user.promptCount++;
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
/** Increments the token count for the given user by the given amount. */
|
||||
@@ -99,6 +119,7 @@ export function incrementTokenCount(token: string, amount = 1) {
|
||||
const user = users.get(token);
|
||||
if (!user) return;
|
||||
user.tokenCount += amount;
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -111,6 +132,7 @@ export function authenticate(token: string, ip: string) {
|
||||
if (!user || user.disabledAt) return;
|
||||
if (!user.ip.includes(ip)) user.ip.push(ip);
|
||||
user.lastUsedAt = Date.now();
|
||||
usersToFlush.add(token);
|
||||
return user;
|
||||
}
|
||||
|
||||
@@ -120,4 +142,58 @@ export function disableUser(token: string, reason?: string) {
|
||||
if (!user) return;
|
||||
user.disabledAt = Date.now();
|
||||
user.disabledReason = reason;
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
// TODO: Firebase persistence is pretend right now and just polls the in-memory
|
||||
// store to sync it with Firebase when it changes. Will refactor to abstract
|
||||
// persistence layer later so we can support multiple stores.
|
||||
let firebaseTimeout: NodeJS.Timeout | undefined;
|
||||
|
||||
/**
 * Reads everything under the `users` ref of the Firebase Realtime Database,
 * upserts each entry into the in-memory store, and starts the 20-second
 * periodic flush timer.
 */
async function initFirebase() {
  logger.info("Connecting to Firebase...");

  const database = admin.database(getFirebaseApp());
  const snapshot = await database.ref("users").once("value");
  const storedUsers: Record<string, User> | null = snapshot.val();

  // The flush timer is started whether or not any users were found.
  firebaseTimeout = setInterval(flushUsers, 20 * 1000);

  if (!storedUsers) {
    logger.info("No users found in Firebase.");
    return;
  }

  const tokens = Object.keys(storedUsers);
  for (const token of tokens) {
    upsertUser(storedUsers[token]);
  }

  // upsertUser queues every user it touches for flushing; these entries were
  // just read from Firebase, so the queue is emptied again.
  usersToFlush.clear();

  logger.info({ users: tokens.length }, "Loaded users from Firebase");
}
|
||||
|
||||
/**
 * Writes every user queued in `usersToFlush` to the `users` ref in Firebase.
 *
 * The queue is snapshotted and cleared before the write, so tokens queued
 * while the write is in flight are picked up by the next flush. If the write
 * fails, the error is logged and the affected tokens are put back on the queue
 * for the next flush to retry.
 *
 * The returned promise never rejects: this function is invoked from
 * `setImmediate` and `setInterval`, where nothing would handle a rejection.
 */
async function flushUsers() {
  const pendingTokens = Array.from(usersToFlush);
  usersToFlush.clear();

  const updates: Record<string, User> = {};
  for (const token of pendingTokens) {
    const user = users.get(token);
    // Queued tokens with no matching in-memory user are skipped.
    if (user) {
      updates[token] = user;
    }
  }

  const numUpdates = Object.keys(updates).length;
  if (numUpdates === 0) {
    return;
  }

  try {
    const usersRef = admin.database(getFirebaseApp()).ref("users");
    await usersRef.update(updates);
    logger.info({ users: numUpdates }, "Flushed users to Firebase");
  } catch (err) {
    // Re-queue so the pending changes are not lost; the next flush retries.
    for (const token of Object.keys(updates)) {
      usersToFlush.add(token);
    }
    logger.error(
      { err, users: numUpdates },
      "Failed to flush users to Firebase; will retry on next flush."
    );
  }
}
|
||||
|
||||
+1
-1
@@ -191,7 +191,7 @@ function cleanQueue() {
|
||||
(waitTime) => now - waitTime.end > 90 * 1000
|
||||
);
|
||||
const removed = waitTimes.splice(0, index + 1);
|
||||
log.info(
|
||||
log.debug(
|
||||
{ stalledRequests: oldRequests.length, prunedWaitTimes: removed.length },
|
||||
`Cleaning up request queue.`
|
||||
);
|
||||
|
||||
+52
-31
@@ -11,22 +11,10 @@ import { proxyRouter, rewriteTavernRequests } from "./proxy/routes";
|
||||
import { handleInfoPage } from "./info-page";
|
||||
import { logQueue } from "./prompt-logging";
|
||||
import { start as startRequestQueue } from "./proxy/queue";
|
||||
import { init as initUserStore } from "./proxy/auth/user-store";
|
||||
|
||||
const PORT = config.port;
|
||||
|
||||
process.on("uncaughtException", (err: any) => {
|
||||
logger.error(
|
||||
{ err, stack: err?.stack },
|
||||
"UNCAUGHT EXCEPTION. Please report this error trace."
|
||||
);
|
||||
});
|
||||
process.on("unhandledRejection", (err: any) => {
|
||||
logger.error(
|
||||
{ err, stack: err?.stack },
|
||||
"UNCAUGHT PROMISE REJECTION. Please report this error trace."
|
||||
);
|
||||
});
|
||||
|
||||
const app = express();
|
||||
// middleware
|
||||
app.use("/", rewriteTavernRequests);
|
||||
@@ -88,10 +76,56 @@ app.use((_req: unknown, res: express.Response) => {
|
||||
res.status(404).json({ error: "Not found" });
|
||||
});
|
||||
|
||||
// start server and load keys
|
||||
app.listen(PORT, async () => {
|
||||
assertConfigIsValid();
|
||||
/**
 * Boots the server: records the git SHA, validates configuration (awaited),
 * initializes the key pool, then conditionally starts the user store, prompt
 * logging and the request queue before opening the HTTP listener.
 *
 * The process-level error handlers are registered from the listen callback.
 * "Startup complete." is logged right after `app.listen` is called, without
 * waiting for that callback.
 */
async function start() {
  logger.info("Server starting up...");
  setGitSha();

  logger.info("Checking configs and external dependencies...");
  await assertConfigIsValid();

  keyPool.init();

  const needsUserStore = config.gatekeeper === "user_token";
  if (needsUserStore) {
    await initUserStore();
  }

  if (config.promptLogging) {
    logger.info("Starting prompt logging...");
    logQueue.start();
  }

  const queueEnabled = config.queueMode !== "none";
  if (queueEnabled) {
    logger.info("Starting request queue...");
    startRequestQueue();
  }

  const onListening = async () => {
    logger.info({ port: PORT }, "Now listening for connections.");
    registerUncaughtExceptionHandler();
  };
  app.listen(PORT, onListening);

  const startupInfo = {
    sha: process.env.COMMIT_SHA,
    nodeEnv: process.env.NODE_ENV,
  };
  logger.info(startupInfo, "Startup complete.");
}
|
||||
|
||||
/**
 * Installs process-level handlers that log uncaught exceptions and unhandled
 * promise rejections, together with their stack trace, through the logger.
 */
function registerUncaughtExceptionHandler() {
  // Builds a handler that logs the error under the given message.
  const reportWith = (message: string) => (err: any) => {
    logger.error({ err, stack: err?.stack }, message);
  };

  process.on(
    "uncaughtException",
    reportWith("UNCAUGHT EXCEPTION. Please report this error trace.")
  );
  process.on(
    "unhandledRejection",
    reportWith("UNCAUGHT PROMISE REJECTION. Please report this error trace.")
  );
}
|
||||
|
||||
function setGitSha() {
|
||||
try {
|
||||
// Huggingface seems to have changed something about how they deploy Spaces
|
||||
// and git commands fail because of some ownership issue with the .git
|
||||
@@ -132,19 +166,6 @@ app.listen(PORT, async () => {
|
||||
);
|
||||
process.env.COMMIT_SHA = "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
{ sha: process.env.COMMIT_SHA },
|
||||
`Server listening on port ${PORT}`
|
||||
);
|
||||
keyPool.init();
|
||||
|
||||
if (config.promptLogging) {
|
||||
logger.info("Starting prompt logging...");
|
||||
logQueue.start();
|
||||
}
|
||||
if (config.queueMode !== "none") {
|
||||
logger.info("Starting request queue...");
|
||||
startRequestQueue();
|
||||
}
|
||||
});
|
||||
start();
|
||||
|
||||
Reference in New Issue
Block a user