Per-user token quotas and automatic quota refreshing (khanon/oai-reverse-proxy!37)

khanon
2023-08-28 19:33:14 +00:00
parent 785b1f69f3
commit cb780e85da
31 changed files with 544 additions and 145 deletions
+9 -5
@@ -8,6 +8,7 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
addAnthropicPreamble,
blockZoomerOrigins,
createPreprocessorMiddleware,
@@ -72,6 +73,7 @@ const rewriteAnthropicRequest = (
res: http.ServerResponse
) => {
const rewriterPipeline = [
applyQuotaLimits,
addKey,
addAnthropicPreamble,
languageFilter,
@@ -108,7 +110,7 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
if (req.inboundApi === "openai") {
req.log.info("Transforming Anthropic response to OpenAI format");
body = transformAnthropicResponse(body);
body = transformAnthropicResponse(body, req);
}
// TODO: Remove once tokenization is stable
@@ -126,17 +128,19 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
* on-the-fly.
*/
function transformAnthropicResponse(
anthropicBody: Record<string, any>
anthropicBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "ant-" + anthropicBody.log_id,
object: "chat.completion",
created: Math.floor(Date.now() / 1000), // OpenAI's `created` is in seconds, not ms
model: anthropicBody.model,
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
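The handler above now threads `req` into `transformAnthropicResponse`, so the OpenAI-compatible `usage` block reports the tokens actually counted for the request instead of hardcoded zeros. A minimal standalone sketch of the mapping, assuming earlier middleware has attached `promptTokens` and `outputTokens` to the request (both names appear in this diff):

// Sketch only; in the proxy these counts come from the tokenizer middleware.
function toOpenAiUsage(counts: { promptTokens?: number; outputTokens?: number }) {
  const prompt_tokens = counts.promptTokens ?? 0;
  const completion_tokens = counts.outputTokens ?? 0;
  return { prompt_tokens, completion_tokens, total_tokens: prompt_tokens + completion_tokens };
}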
+1 -1
@@ -1,6 +1,6 @@
import type { Request, RequestHandler } from "express";
import { config } from "../../config";
import { authenticate, getUser } from "./user-store";
import { authenticate, getUser, hasAvailableQuota } from "./user-store";
const GATEKEEPER = config.gatekeeper;
const PROXY_KEY = config.proxyKey;
+108 -17
@@ -8,10 +8,16 @@
*/
import admin from "firebase-admin";
import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import { logger } from "../../logger";
const log = logger.child({ module: "users" });
// TODO: Consolidate model families with QueuePartition and KeyProvider.
type QuotaModel = "claude" | "turbo" | "gpt4";
export interface User {
/** The user's personal access token. */
token: string;
@@ -21,8 +27,12 @@ export interface User {
type: UserType;
/** The number of prompts the user has made. */
promptCount: number;
/** The number of tokens the user has consumed. Not yet implemented. */
tokenCount: number;
/** @deprecated Use `tokenCounts` instead. */
tokenCount?: never;
/** The number of tokens the user has consumed, by model family. */
tokenCounts: Record<QuotaModel, number>;
/** The maximum number of tokens the user can consume, by model family. */
tokenLimits: Record<QuotaModel, number>;
/** The time at which the user was created. */
createdAt: number;
/** The time at which the user last connected. */
@@ -37,7 +47,6 @@ export interface User {
* Possible privilege levels for a user.
* - `normal`: Default role. Subject to usual rate limits and quotas.
* - `special`: Special role. Higher quotas and exempt from auto-ban/lockout.
* TODO: implement auto-ban/lockout for normal users when they do naughty shit
*/
export type UserType = "normal" | "special";
@@ -49,11 +58,32 @@ const users: Map<string, User> = new Map();
const usersToFlush = new Set<string>();
export async function init() {
logger.info({ store: config.gatekeeperStore }, "Initializing user store...");
log.info({ store: config.gatekeeperStore }, "Initializing user store...");
if (config.gatekeeperStore === "firebase_rtdb") {
await initFirebase();
}
logger.info("User store initialized.");
if (config.quotaRefreshPeriod) {
const quotaRefreshJob = schedule.scheduleJob(getRefreshCrontab(), () => {
for (const user of users.values()) {
refreshQuota(user.token);
}
log.info(
{ users: users.size, nextRefresh: quotaRefreshJob.nextInvocation() },
"Token quotas refreshed."
);
});
if (!quotaRefreshJob) {
throw new Error(
"Unable to schedule quota refresh. Is QUOTA_REFRESH_PERIOD set correctly?"
);
}
log.debug(
{ nextRefresh: quotaRefreshJob.nextInvocation() },
"Scheduled token quota refresh."
);
}
log.info("User store initialized.");
}
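The refresh job leans on node-schedule's crontab support: `scheduleJob` returns `null` for an unparseable rule, which `init()` promotes to a startup error. The pattern in isolation:

import schedule from "node-schedule";

// "0 * * * *" fires at the top of every hour. scheduleJob returns null
// instead of throwing when the crontab expression cannot be parsed.
const job = schedule.scheduleJob("0 * * * *", () => {
  console.log("refreshing quotas");
});
if (!job) throw new Error("Invalid crontab expression");
console.log("next run:", job.nextInvocation().toString());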
/** Creates a new user and returns their token. */
@@ -64,7 +94,8 @@ export function createUser() {
ip: [],
type: "normal",
promptCount: 0,
tokenCount: 0,
tokenCounts: { turbo: 0, gpt4: 0, claude: 0 },
tokenLimits: { ...config.tokenQuota },
createdAt: Date.now(),
});
usersToFlush.add(token);
@@ -86,12 +117,14 @@ export function getUsers() {
* user information via JSON. Use other functions for more specific operations.
*/
export function upsertUser(user: UserUpdate) {
// TODO: May need better merging for nested objects
const existing: User = users.get(user.token) ?? {
token: user.token,
ip: [],
type: "normal",
promptCount: 0,
tokenCount: 0,
tokenCounts: { turbo: 0, gpt4: 0, claude: 0 },
tokenLimits: { ...config.tokenQuota },
createdAt: Date.now(),
};
@@ -117,11 +150,16 @@ export function incrementPromptCount(token: string) {
usersToFlush.add(token);
}
/** Increments the token count for the given user by the given amount. */
export function incrementTokenCount(token: string, amount = 1) {
/** Increments token consumption for the given user and model. */
export function incrementTokenCount(
token: string,
model: string,
consumption: number
) {
const user = users.get(token);
if (!user) return;
user.tokenCount += amount;
const modelFamily = getModelFamily(model);
user.tokenCounts[modelFamily] += consumption;
usersToFlush.add(token);
}
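Token consumption is now bucketed per model family instead of accumulating in a single lifetime counter. A hypothetical call site, assuming `user` is an authenticated `User`:

// A gpt-3.5 completion that consumed 1,250 prompt + output tokens:
incrementTokenCount(user.token, "gpt-3.5-turbo", 1250);
// -> user.tokenCounts.turbo += 1250, and the user is queued for the next flush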
@@ -148,6 +186,40 @@ export function authenticate(token: string, ip: string) {
return user;
}
export function hasAvailableQuota(
token: string,
model: string,
requested: number
) {
const user = users.get(token);
if (!user) return false;
if (user.type === "special") return true;
const modelFamily = getModelFamily(model);
const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily];
if (!tokenLimit) return true;
const tokensConsumed = tokenCounts[modelFamily] + requested;
return tokensConsumed < tokenLimit;
}
export function refreshQuota(token: string) {
const user = users.get(token);
if (!user) return;
const { tokenCounts, tokenLimits } = user;
const quotas = Object.entries(config.tokenQuota) as [QuotaModel, number][];
quotas
// If a quota is not configured, don't touch any existing limits a user may
// already have been assigned manually.
.filter(([, quota]) => quota > 0)
.forEach(
([model, quota]) => (tokenLimits[model] = tokenCounts[model] + quota)
);
usersToFlush.add(token);
}
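Note the refresh semantics: rather than zeroing `tokenCounts` (which would erase lifetime usage stats), `refreshQuota` advances each limit to the current consumption plus the configured quota, granting a fresh allowance each period. The rule distilled into a hypothetical helper:

// A quota of 0 means "not configured" and leaves any manual limit untouched.
function nextLimit(consumed: number, quota: number, currentLimit: number): number {
  return quota > 0 ? consumed + quota : currentLimit;
}
nextLimit(90_000, 100_000, 100_000); // => 190_000 after the hourly/daily refresh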
/** Disables the given user, optionally providing a reason. */
export function disableUser(token: string, reason?: string) {
const user = users.get(token);
@@ -163,7 +235,7 @@ export function disableUser(token: string, reason?: string) {
let firebaseTimeout: NodeJS.Timeout | undefined;
async function initFirebase() {
logger.info("Connecting to Firebase...");
log.info("Connecting to Firebase...");
const app = getFirebaseApp();
const db = admin.database(app);
const usersRef = db.ref("users");
@@ -171,7 +243,7 @@ async function initFirebase() {
const users: Record<string, User> | null = snapshot.val();
firebaseTimeout = setInterval(flushUsers, 20 * 1000);
if (!users) {
logger.info("No users found in Firebase.");
log.info("No users found in Firebase.");
return;
}
for (const token in users) {
@@ -179,7 +251,7 @@ async function initFirebase() {
}
usersToFlush.clear();
const numUsers = Object.keys(users).length;
logger.info({ users: numUsers }, "Loaded users from Firebase");
log.info({ users: numUsers }, "Loaded users from Firebase");
}
async function flushUsers() {
@@ -204,8 +276,27 @@ async function flushUsers() {
}
await usersRef.update(updates);
logger.info(
{ users: Object.keys(updates).length },
"Flushed users to Firebase"
);
log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase");
}
function getModelFamily(model: string): QuotaModel {
if (model.startsWith("gpt-4")) {
// TODO: add 32k models
return "gpt4";
}
if (model.startsWith("gpt-3.5")) {
return "turbo";
}
return "claude";
}
function getRefreshCrontab() {
switch (config.quotaRefreshPeriod!) {
case "hourly":
return "0 * * * *";
case "daily":
return "0 0 * * *";
default:
return config.quotaRefreshPeriod ?? "0 0 * * *";
}
}
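`getModelFamily` is a simple prefix match with Claude as the fallthrough, and `getRefreshCrontab` treats anything other than the two presets as a raw crontab. For example:

getModelFamily("gpt-4-0613"); // "gpt4"
getModelFamily("gpt-3.5-turbo-16k"); // "turbo"
getModelFamily("claude-instant-v1"); // "claude" (the fallthrough)
// QUOTA_REFRESH_PERIOD="hourly" maps to "0 * * * *"; a custom value like
// "30 */6 * * *" is passed through to node-schedule unchanged.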
+28 -4
@@ -1,6 +1,8 @@
import { Request, Response } from "express";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
import { AIService } from "../../key-management";
import { QuotaExceededError } from "./request/apply-quota-limits";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
@@ -63,9 +65,7 @@ export const handleInternalError = (
res: Response
) => {
try {
const isZod = err instanceof ZodError;
const isForbidden = err.name === "ForbiddenError";
if (isZod) {
if (err instanceof ZodError) {
writeErrorResponse(req, res, 400, {
error: {
type: "proxy_validation_error",
@@ -75,7 +75,7 @@ export const handleInternalError = (
message: err.message,
},
});
} else if (isForbidden) {
} else if (err.name === "ForbiddenError") {
// Spoofs a vaguely threatening OpenAI error message. Only invoked by the
// block-zoomers rewriter to scare off tiktokers.
writeErrorResponse(req, res, 403, {
@@ -86,6 +86,16 @@ export const handleInternalError = (
message: err.message,
},
});
} else if (err instanceof QuotaExceededError) {
writeErrorResponse(req, res, 429, {
error: {
type: "proxy_quota_exceeded",
code: "quota_exceeded",
message: `You've exceeded your token quota for this model type.`,
info: err.quotaInfo,
stack: err.stack,
},
});
} else {
writeErrorResponse(req, res, 500, {
error: {
@@ -141,3 +151,17 @@ export function buildFakeSseMessage(
}
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}
export function getCompletionForService({
service,
body,
}: {
service: AIService;
body: Record<string, any>;
}): { completion: string; model: string } {
if (service === "anthropic") {
return { completion: body.completion.trim(), model: body.model };
} else {
return { completion: body.choices[0].message.content, model: body.model };
}
}
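With the new `QuotaExceededError` branch, an over-quota client receives a structured 429 rather than a generic 500. Roughly what the response body looks like (values illustrative; `info` mirrors `err.quotaInfo`):

const exampleQuotaResponse = {
  error: {
    type: "proxy_quota_exceeded",
    code: "quota_exceeded",
    message: "You've exceeded your token quota for this model type.",
    info: {
      quota: { turbo: 100_000, gpt4: 0, claude: 0 },
      used: { turbo: 99_200, gpt4: 0, claude: 0 },
      requested: 2_048,
    },
  },
};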
@@ -0,0 +1,30 @@
import { hasAvailableQuota } from "../../auth/user-store";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
export class QuotaExceededError extends Error {
public quotaInfo: any;
constructor(message: string, quotaInfo: any) {
super(message);
this.name = "QuotaExceededError";
this.quotaInfo = quotaInfo;
}
}
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!isCompletionRequest(req) || !req.user) {
return;
}
const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
throw new QuotaExceededError(
"You have exceeded your proxy token quota for this model.",
{
quota: req.user.tokenLimits,
used: req.user.tokenCounts,
requested: requestedTokens,
}
);
}
};
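The check reserves an upper bound: `promptTokens` is the measured prompt size and `outputTokens` is the caller's requested maximum, so a request is only refused if it could exceed the remaining allowance. The same reservation as a hypothetical helper (e.g. for tests):

import { hasAvailableQuota } from "../../auth/user-store";

function wouldExceedQuota(token: string, model: string, promptTokens: number, maxOutputTokens: number): boolean {
  return !hasAvailableQuota(token, model, promptTokens + maxOutputTokens);
}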
@@ -1,7 +1,7 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { countTokens } from "../../../tokenization";
import { OpenAIPromptMessage, countTokens } from "../../../tokenization";
import { RequestPreprocessor } from ".";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
@@ -15,22 +15,26 @@ const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
* request body.
*/
export const checkContextSize: RequestPreprocessor = async (req) => {
let prompt;
const service = req.outboundApi;
let result;
switch (req.outboundApi) {
case "openai":
switch (service) {
case "openai": {
req.outputTokens = req.body.max_tokens;
prompt = req.body.messages;
const prompt: OpenAIPromptMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service });
break;
case "anthropic":
}
case "anthropic": {
req.outputTokens = req.body.max_tokens_to_sample;
prompt = req.body.prompt;
const prompt: string = req.body.prompt;
result = await countTokens({ req, prompt, service });
break;
}
default:
throw new Error(`Unknown outbound API: ${req.outboundApi}`);
}
const result = await countTokens({ req, prompt, service: req.outboundApi });
req.promptTokens = result.token_count;
// TODO: Remove once token counting is stable
@@ -89,6 +93,7 @@ function validateContextSize(req: Request) {
);
req.debug.prompt_tokens = promptTokens;
req.debug.completion_tokens = outputTokens;
req.debug.max_model_tokens = modelMax;
req.debug.max_proxy_tokens = proxyMax;
}
@@ -101,7 +106,7 @@ function assertRequestHasTokenCounts(
outputTokens: z.number().int().min(1),
})
.nonstrict()
.parse(req);
.parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
}
/**
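The switch was flattened into per-service blocks so each branch hands `countTokens` a correctly-typed prompt. Presumably the tokenizer takes a service-discriminated union along these lines (an assumed sketch, not the actual signature):

import type { OpenAIPromptMessage } from "../../../tokenization";

// Assumed shape only; the real types live in ../../../tokenization.
type CountTokensArgs =
  | { service: "openai"; prompt: OpenAIPromptMessage[] }
  | { service: "anthropic"; prompt: string };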
+1
@@ -3,6 +3,7 @@ import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
// Express middleware (runs before http-proxy-middleware, can be async)
export { applyQuotaLimits } from "./apply-quota-limits";
export { createPreprocessorMiddleware } from "./preprocess";
export { checkContextSize } from "./check-context-size";
export { setApiFormat } from "./set-api-format";
+60 -6
View File
@@ -3,14 +3,21 @@ import { Request, Response } from "express";
import * as http from "http";
import util from "util";
import zlib from "zlib";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { enqueue, trackWaitTime } from "../../queue";
import { incrementPromptCount } from "../../auth/user-store";
import { isCompletionRequest, writeErrorResponse } from "../common";
import {
incrementPromptCount,
incrementTokenCount,
} from "../../auth/user-store";
import {
getCompletionForService,
isCompletionRequest,
writeErrorResponse,
} from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { countTokens } from "../../../tokenization";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
@@ -84,12 +91,18 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
if (req.isStreaming) {
// `handleStreamedResponse` writes to the response and ends it, so
// we can only execute middleware that doesn't write to the response.
middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt);
middlewareStack.push(
trackRateLimit,
countResponseTokens,
incrementUsage,
logPrompt
);
} else {
middlewareStack.push(
trackRateLimit,
handleUpstreamErrors,
incrementKeyUsage,
countResponseTokens,
incrementUsage,
copyHttpHeaders,
logPrompt,
...apiMiddleware
@@ -394,15 +407,56 @@ function handleOpenAIRateLimitError(
return errorPayload;
}
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isCompletionRequest(req)) {
keyPool.incrementPrompt(req.key!);
if (req.user) {
incrementPromptCount(req.user.token);
const model = req.body.model;
const tokensUsed = req.promptTokens! + req.outputTokens!;
incrementTokenCount(req.user.token, model, tokensUsed);
}
}
};
const countResponseTokens: ProxyResHandlerWithBody = async (
_proxyRes,
req,
_res,
body
) => {
// This function is prone to breaking if the upstream API makes even minor
// changes to the response format, especially for SSE responses. If you're
// seeing errors in this function, check the reassembled response body from
// handleStreamedResponse to see if the upstream API has changed.
try {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
const service = req.outboundApi;
const { completion } = getCompletionForService({ service, body });
const tokens = await countTokens({ req, completion, service });
req.log.debug(
{ service, tokens, prevOutputTokens: req.outputTokens },
`Counted tokens for completion`
);
if (req.debug) {
req.debug.completion_tokens = tokens;
}
req.outputTokens = tokens.token_count;
} catch (error) {
req.log.error(
error,
"Error while counting completion tokens; assuming `max_output_tokens`"
);
// req.outputTokens will already be set to `max_output_tokens` from the
// prompt counting middleware, so we don't need to do anything here.
}
};
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!, proxyRes.headers);
};
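Both the streaming and blocking paths now recount completion tokens from the actual response before usage is recorded, replacing the pessimistic `max_tokens` reservation; if counting fails, the reservation stands so usage is never undercounted. The fallback contract as a hypothetical helper:

// Returns the measured count when counting succeeds, else the reserved estimate.
async function safeCountOutput(estimate: number, count: () => Promise<number>): Promise<number> {
  try {
    return await count();
  } catch {
    return estimate;
  }
}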
+2 -18
@@ -1,10 +1,8 @@
import { Request } from "express";
import { config } from "../../../config";
import { AIService } from "../../../key-management";
import { logQueue } from "../../../prompt-logging";
import { isCompletionRequest } from "../common";
import { getCompletionForService, isCompletionRequest } from "../common";
import { ProxyResHandlerWithBody } from ".";
import { logger } from "../../../logger";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -26,7 +24,7 @@ export const logPrompt: ProxyResHandlerWithBody = async (
const promptPayload = getPromptForRequest(req);
const promptFlattened = flattenMessages(promptPayload);
const response = getResponseForService({
const response = getCompletionForService({
service: req.outboundApi,
body: responseBody,
});
@@ -62,17 +60,3 @@ const flattenMessages = (messages: string | OaiMessage[]): string => {
}
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
};
const getResponseForService = ({
service,
body,
}: {
service: AIService;
body: Record<string, any>;
}): { completion: string; model: string } => {
if (service === "anthropic") {
return { completion: body.completion.trim(), model: body.model };
} else {
return { completion: body.choices[0].message.content, model: body.model };
}
};
+2
@@ -9,6 +9,7 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
@@ -90,6 +91,7 @@ const rewriteRequest = (
res: http.ServerResponse
) => {
const rewriterPipeline = [
applyQuotaLimits,
addKey,
languageFilter,
limitCompletions,