Implement user persistence via Firebase (khanon/oai-reverse-proxy!8)

This commit is contained in:
nai-degen
2023-05-14 04:26:08 +00:00
parent 7126fb6c6c
commit f1ac64fa12
10 changed files with 1776 additions and 88 deletions
-204
View File
@@ -1,204 +0,0 @@
# Shat out by GPT-4, I did not check for correctness beyond a cursory glance
openapi: 3.0.0
info:
version: 1.0.0
title: User Management API
paths:
/admin/users:
get:
summary: List all users
operationId: getUsers
responses:
"200":
description: A list of users
content:
application/json:
schema:
type: object
properties:
users:
type: array
items:
$ref: "#/components/schemas/User"
count:
type: integer
format: int32
post:
summary: Create a new user
operationId: createUser
responses:
"200":
description: The created user's token
content:
application/json:
schema:
type: object
properties:
token:
type: string
put:
summary: Bulk upsert users
operationId: bulkUpsertUsers
requestBody:
content:
application/json:
schema:
type: object
properties:
users:
type: array
items:
$ref: "#/components/schemas/User"
responses:
"200":
description: The upserted users
content:
application/json:
schema:
type: object
properties:
upserted_users:
type: array
items:
$ref: "#/components/schemas/User"
count:
type: integer
format: int32
"400":
description: Bad request
content:
application/json:
schema:
type: object
properties:
error:
type: string
/admin/users/{token}:
get:
summary: Get a user by token
operationId: getUser
parameters:
- name: token
in: path
required: true
schema:
type: string
responses:
"200":
description: A user
content:
application/json:
schema:
$ref: "#/components/schemas/User"
"404":
description: Not found
content:
application/json:
schema:
type: object
properties:
error:
type: string
put:
summary: Update a user by token
operationId: upsertUser
parameters:
- name: token
in: path
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: "#/components/schemas/User"
responses:
"200":
description: The updated user
content:
application/json:
schema:
$ref: "#/components/schemas/User"
"400":
description: Bad request
content:
application/json:
schema:
type: object
properties:
error:
type: string
delete:
summary: Disables the user with the given token
description: Optionally accepts a `disabledReason` query parameter. Returns the disabled user.
parameters:
- in: path
name: token
required: true
schema:
type: string
description: The token of the user to disable
- in: query
name: disabledReason
required: false
schema:
type: string
description: The reason for disabling the user
responses:
'200':
description: The disabled user
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
description: Bad request
content:
application/json:
schema:
type: object
properties:
error:
type: string
'404':
description: Not found
content:
application/json:
schema:
type: object
properties:
error:
type: string
components:
schemas:
User:
type: object
properties:
token:
type: string
ip:
type: array
items:
type: string
type:
type: string
enum: ["normal", "special"]
promptCount:
type: integer
format: int32
tokenCount:
type: integer
format: int32
createdAt:
type: integer
format: int64
lastUsedAt:
type: integer
format: int64
disabledAt:
type: integer
format: int64
disabledReason:
type: string
+69 -9
View File
@@ -1,4 +1,5 @@
import dotenv from "dotenv";
import type firebase from "firebase-admin";
dotenv.config();
const isDev = process.env.NODE_ENV !== "production";
@@ -17,7 +18,7 @@ type Config = {
**/
proxyKey?: string;
/**
* The admin key to used for accessing the /admin API. Required if the user
* The admin key used to access the /admin API. Required if the user
* management mode is set to 'user_token'.
**/
adminKey?: string;
@@ -35,6 +36,19 @@ type Config = {
* Configure this function and add users via the /admin API.
*/
gatekeeper: "none" | "proxy_key" | "user_token";
/**
* Persistence layer to use for user management.
*
* `memory`: Users are stored in memory and are lost on restart (default)
*
* `firebase_rtdb`: Users are stored in a Firebase Realtime Database; requires
* `firebaseKey` and `firebaseRtdbUrl` to be set.
**/
gatekeeperStore: "memory" | "firebase_rtdb";
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
firebaseRtdbUrl?: string;
/** Base64-encoded Firebase service account key if using the Firebase RTDB store. */
firebaseKey?: string;
/** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
modelRateLimit: number;
/** Max number of tokens to generate. Requests which specify a higher value will be rewritten to use this value. */
@@ -58,21 +72,21 @@ type Config = {
/**
* How to display quota information on the info page.
*
* `none` - Hide quota information
* `none`: Hide quota information
*
* `partial` - Display quota information only as a percentage
* `partial`: Display quota information only as a percentage
*
* `full` - Display quota information as usage against total capacity
* `full`: Display quota information as usage against total capacity
*/
quotaDisplayMode: "none" | "partial" | "full";
/**
* Which request queueing strategy to use when keys are over their rate limit.
*
* `fair` - Requests are serviced in the order they were received (default)
* `fair`: Requests are serviced in the order they were received (default)
*
* `random` - Requests are serviced randomly
* `random`: Requests are serviced randomly
*
* `none` - Requests are not queued and users have to retry manually
* `none`: Requests are not queued and users have to retry manually
*/
queueMode: DequeueMode;
};
@@ -85,6 +99,9 @@ export const config: Config = {
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
maxOutputTokens: getEnvWithDefault("MAX_OUTPUT_TOKENS", 300),
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
@@ -106,7 +123,7 @@ export const config: Config = {
} as const;
/** Prevents the server from starting if config state is invalid. */
export function assertConfigIsValid(): void {
export async function assertConfigIsValid() {
// Ensure gatekeeper mode is valid.
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
throw new Error(
@@ -134,12 +151,29 @@ export function assertConfigIsValid(): void {
"`PROXY_KEY` is set, but gatekeeper mode is not `proxy_key`. Make sure to set `GATEKEEPER=proxy_key`."
);
}
// Require appropriate firebase config if using firebase store.
if (
config.gatekeeperStore === "firebase_rtdb" &&
(!config.firebaseKey || !config.firebaseRtdbUrl)
) {
throw new Error(
"Firebase RTDB store requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
);
}
await maybeInitializeFirebase();
}
/** Masked, but not omitted as users may wish to see if they're set. */
/**
* Masked, but not omitted as users may wish to see if they're set due to their
* implications on privacy.
*/
export const SENSITIVE_KEYS: (keyof Config)[] = [
"googleSheetsKey",
"googleSheetsSpreadsheetId",
"firebaseRtdbUrl",
"firebaseKey",
];
/** Omitted as they're not useful to display, masked or not. */
@@ -184,3 +218,29 @@ function getEnvWithDefault<T>(name: string, defaultValue: T): T {
return value as unknown as T;
}
}
let firebaseApp: firebase.app.App | undefined;
async function maybeInitializeFirebase() {
if (!config.gatekeeperStore.startsWith("firebase")) {
return;
}
const firebase = await import("firebase-admin");
const firebaseKey = Buffer.from(config.firebaseKey!, "base64").toString();
const app = firebase.initializeApp({
credential: firebase.credential.cert(JSON.parse(firebaseKey)),
databaseURL: config.firebaseRtdbUrl,
});
await app.database().ref("connection-test").set(Date.now());
firebaseApp = app;
}
export function getFirebaseApp(): firebase.app.App {
if (!firebaseApp) {
throw new Error("Firebase app not initialized.");
}
return firebaseApp;
}
+1 -1
View File
@@ -41,7 +41,7 @@ export class KeyChecker {
}
public start() {
this.log.info("Starting key checker");
this.log.info("Starting key checker...");
this.scheduleNextCheck();
}
+78 -2
View File
@@ -1,13 +1,16 @@
/**
* Basic user management. Handles creation and tracking of proxy users, personal
* access tokens, and quota management. No persistence is provided, users must
* be re-created on each proxy start via the /admin API.
* access tokens, and quota management. Supports in-memory and Firebase Realtime
* Database persistence stores.
*
* Users are identified solely by their personal access token. The token is
* used to authenticate the user for all proxied requests.
*/
import admin from "firebase-admin";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import { logger } from "../../logger";
export interface User {
/** The user's personal access token. */
@@ -41,6 +44,15 @@ export type UserType = "normal" | "special";
type UserUpdate = Partial<User> & Pick<User, "token">;
const users: Map<string, User> = new Map();
const usersToFlush = new Set<string>();
export async function init() {
logger.info({ store: config.gatekeeperStore }, "Initializing user store...");
if (config.gatekeeperStore === "firebase_rtdb") {
await initFirebase();
}
logger.info("User store initialized.");
}
/** Creates a new user and returns their token. */
export function createUser() {
@@ -84,6 +96,13 @@ export function upsertUser(user: UserUpdate) {
...existing,
...user,
});
usersToFlush.add(user.token);
// Immediately schedule a flush to the database if we're using Firebase.
if (config.gatekeeperStore === "firebase_rtdb") {
setImmediate(flushUsers);
}
return users.get(user.token);
}
@@ -92,6 +111,7 @@ export function incrementPromptCount(token: string) {
const user = users.get(token);
if (!user) return;
user.promptCount++;
usersToFlush.add(token);
}
/** Increments the token count for the given user by the given amount. */
@@ -99,6 +119,7 @@ export function incrementTokenCount(token: string, amount = 1) {
const user = users.get(token);
if (!user) return;
user.tokenCount += amount;
usersToFlush.add(token);
}
/**
@@ -111,6 +132,7 @@ export function authenticate(token: string, ip: string) {
if (!user || user.disabledAt) return;
if (!user.ip.includes(ip)) user.ip.push(ip);
user.lastUsedAt = Date.now();
usersToFlush.add(token);
return user;
}
@@ -120,4 +142,58 @@ export function disableUser(token: string, reason?: string) {
if (!user) return;
user.disabledAt = Date.now();
user.disabledReason = reason;
usersToFlush.add(token);
}
// TODO: Firebase persistence is pretend right now and just polls the in-memory
// store to sync it with Firebase when it changes. Will refactor to abstract
// persistence layer later so we can support multiple stores.
let firebaseTimeout: NodeJS.Timeout | undefined;
async function initFirebase() {
logger.info("Connecting to Firebase...");
const app = getFirebaseApp();
const db = admin.database(app);
const usersRef = db.ref("users");
const snapshot = await usersRef.once("value");
const users: Record<string, User> | null = snapshot.val();
firebaseTimeout = setInterval(flushUsers, 20 * 1000);
if (!users) {
logger.info("No users found in Firebase.");
return;
}
for (const token in users) {
upsertUser(users[token]);
}
usersToFlush.clear();
const numUsers = Object.keys(users).length;
logger.info({ users: numUsers }, "Loaded users from Firebase");
}
async function flushUsers() {
const app = getFirebaseApp();
const db = admin.database(app);
const usersRef = db.ref("users");
const updates: Record<string, User> = {};
for (const token of usersToFlush) {
const user = users.get(token);
if (!user) {
continue;
}
updates[token] = user;
}
usersToFlush.clear();
const numUpdates = Object.keys(updates).length;
if (numUpdates === 0) {
return;
}
await usersRef.update(updates);
logger.info(
{ users: Object.keys(updates).length },
"Flushed users to Firebase"
);
}
+1 -1
View File
@@ -191,7 +191,7 @@ function cleanQueue() {
(waitTime) => now - waitTime.end > 90 * 1000
);
const removed = waitTimes.splice(0, index + 1);
log.info(
log.debug(
{ stalledRequests: oldRequests.length, prunedWaitTimes: removed.length },
`Cleaning up request queue.`
);
+52 -31
View File
@@ -11,22 +11,10 @@ import { proxyRouter, rewriteTavernRequests } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { logQueue } from "./prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./proxy/auth/user-store";
const PORT = config.port;
process.on("uncaughtException", (err: any) => {
logger.error(
{ err, stack: err?.stack },
"UNCAUGHT EXCEPTION. Please report this error trace."
);
});
process.on("unhandledRejection", (err: any) => {
logger.error(
{ err, stack: err?.stack },
"UNCAUGHT PROMISE REJECTION. Please report this error trace."
);
});
const app = express();
// middleware
app.use("/", rewriteTavernRequests);
@@ -88,10 +76,56 @@ app.use((_req: unknown, res: express.Response) => {
res.status(404).json({ error: "Not found" });
});
// start server and load keys
app.listen(PORT, async () => {
assertConfigIsValid();
async function start() {
logger.info("Server starting up...");
setGitSha();
logger.info("Checking configs and external dependencies...");
await assertConfigIsValid();
keyPool.init();
if (config.gatekeeper === "user_token") {
await initUserStore();
}
if (config.promptLogging) {
logger.info("Starting prompt logging...");
logQueue.start();
}
if (config.queueMode !== "none") {
logger.info("Starting request queue...");
startRequestQueue();
}
app.listen(PORT, async () => {
logger.info({ port: PORT }, "Now listening for connections.");
registerUncaughtExceptionHandler();
});
logger.info(
{ sha: process.env.COMMIT_SHA, nodeEnv: process.env.NODE_ENV },
"Startup complete."
);
}
function registerUncaughtExceptionHandler() {
process.on("uncaughtException", (err: any) => {
logger.error(
{ err, stack: err?.stack },
"UNCAUGHT EXCEPTION. Please report this error trace."
);
});
process.on("unhandledRejection", (err: any) => {
logger.error(
{ err, stack: err?.stack },
"UNCAUGHT PROMISE REJECTION. Please report this error trace."
);
});
}
function setGitSha() {
try {
// Huggingface seems to have changed something about how they deploy Spaces
// and git commands fail because of some ownership issue with the .git
@@ -132,19 +166,6 @@ app.listen(PORT, async () => {
);
process.env.COMMIT_SHA = "unknown";
}
}
logger.info(
{ sha: process.env.COMMIT_SHA },
`Server listening on port ${PORT}`
);
keyPool.init();
if (config.promptLogging) {
logger.info("Starting prompt logging...");
logQueue.start();
}
if (config.queueMode !== "none") {
logger.info("Starting request queue...");
startRequestQueue();
}
});
start();