diff --git a/src/admin/common.ts b/src/admin/common.ts index d9b30a6..4f6a0ce 100644 --- a/src/admin/common.ts +++ b/src/admin/common.ts @@ -50,12 +50,26 @@ export const UserSchema = z promptCount: z.number().optional(), tokenCount: z.any().optional(), // never used, but remains for compatibility tokenCounts: z - .object({ turbo: z.number(), gpt4: z.number(), claude: z.number() }) - .strict() + .object({ + turbo: z.number().optional(), + gpt4: z.number().optional(), + "gpt4-32k": z.number().optional().default(0), + claude: z.number().optional(), + }) + .refine(zodModelFamilyRefinement, { + message: "If provided, tokenCounts must include all model families", + }) .optional(), tokenLimits: z - .object({ turbo: z.number(), gpt4: z.number(), claude: z.number() }) - .strict() + .object({ + turbo: z.number().optional(), + gpt4: z.number().optional(), + "gpt4-32k": z.number().optional().default(0), + claude: z.number().optional(), + }) + .refine(zodModelFamilyRefinement, { + message: "If provided, tokenLimits must include all model families", + }) .optional(), createdAt: z.number().optional(), lastUsedAt: z.number().optional(), @@ -64,6 +78,19 @@ export const UserSchema = z }) .strict(); +// gpt4-32k was added after the initial release, so this tries to allow for +// data imported from older versions of the app which may be missing the +// new model family. +// Otherwise, all model families must be present. +function zodModelFamilyRefinement(data: Record<string, number>) { + const keys = Object.keys(data).sort(); + const validSets = [ + ["claude", "gpt4", "turbo"], + ["claude", "gpt4", "gpt4-32k", "turbo"], + ]; + return validSets.some((set) => keys.join(",") === set.join(",")); +} + export const UserSchemaWithToken = UserSchema.extend({ token: z.string(), }).strict(); diff --git a/src/config.ts b/src/config.ts index df80865..2b13742 100644 --- a/src/config.ts +++ b/src/config.ts @@ -129,6 +129,8 @@ type Config = { turbo: number; /** Token allowance for GPT-4 models. 
*/ gpt4: number; + /** Token allowance for GPT-4 32k models. */ + "gpt4-32k": number; /** Token allowance for Claude models. */ claude: number; }; @@ -197,6 +199,7 @@ export const config: Config = { tokenQuota: { turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0), gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0), + "gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0), claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0), }, quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined), diff --git a/src/proxy/auth/user-store.ts b/src/proxy/auth/user-store.ts index a3c5e47..0fdde78 100644 --- a/src/proxy/auth/user-store.ts +++ b/src/proxy/auth/user-store.ts @@ -11,12 +11,13 @@ import admin from "firebase-admin"; import schedule from "node-schedule"; import { v4 as uuid } from "uuid"; import { config, getFirebaseApp } from "../../config"; +import { ModelFamily } from "../../key-management"; import { logger } from "../../logger"; const log = logger.child({ module: "users" }); // TODO: Consolidate model families with QueuePartition and KeyProvider. -type QuotaModel = "claude" | "turbo" | "gpt4"; +type QuotaModel = ModelFamily; export interface User { /** The user's personal access token. 
*/ @@ -96,7 +97,7 @@ export function createUser() { ip: [], type: "normal", promptCount: 0, - tokenCounts: { turbo: 0, gpt4: 0, claude: 0 }, + tokenCounts: { turbo: 0, gpt4: 0, "gpt4-32k": 0, claude: 0 }, tokenLimits: { ...config.tokenQuota }, createdAt: Date.now(), }); @@ -125,7 +126,7 @@ export function upsertUser(user: UserUpdate) { ip: [], type: "normal", promptCount: 0, - tokenCounts: { turbo: 0, gpt4: 0, claude: 0 }, + tokenCounts: { turbo: 0, gpt4: 0, "gpt4-32k": 0, claude: 0 }, tokenLimits: { ...config.tokenQuota }, createdAt: Date.now(), }; @@ -281,8 +282,11 @@ async function flushUsers() { log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase"); } -// TODO: add gpt-4-32k models; use key-management/models.ts for family mapping +// TODO: use key-management/models.ts for family mapping function getModelFamilyForQuotaUsage(model: string): QuotaModel { + if (model.includes("32k")) { + return "gpt4-32k"; + } if (model.startsWith("gpt-4")) { return "gpt4"; }