diff --git a/src/server.ts b/src/server.ts index 3bcc2e3..9b280d8 100644 --- a/src/server.ts +++ b/src/server.ts @@ -5,16 +5,16 @@ import cors from "cors"; import path from "path"; import pinoHttp from "pino-http"; import childProcess from "child_process"; -import { logger } from "./logger"; -import { keyPool } from "./shared/key-management"; -import { adminRouter } from "./admin/routes"; -import { proxyRouter } from "./proxy/routes"; import { handleInfoPage } from "./info-page"; -import { logQueue } from "./shared/prompt-logging"; -import { start as startRequestQueue } from "./proxy/queue"; -import { init as initUserStore } from "./shared/users/user-store"; -import { init as initTokenizers } from "./shared/tokenization"; +import { logger } from "./logger"; +import { adminRouter } from "./admin/routes"; import { checkOrigin } from "./proxy/check-origin"; +import { start as startRequestQueue } from "./proxy/queue"; +import { proxyRouter } from "./proxy/routes"; +import { init as initKeyPool } from "./shared/key-management"; +import { logQueue } from "./shared/prompt-logging"; +import { init as initTokenizers } from "./shared/tokenization"; +import { init as initUserStore } from "./shared/users/user-store"; import { userRouter } from "./user/routes"; const PORT = config.port; @@ -93,7 +93,7 @@ async function start() { await assertConfigIsValid(); logger.info("Starting key pool..."); - await keyPool.init(); + await initKeyPool(); await initTokenizers(); diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts index 144bd30..266e533 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -1,5 +1,5 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; +import { Key, KeyProvider, KeyStore } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; import type { AnthropicModelFamily } from "../../models"; @@ -70,29 +70,35 @@ const RATE_LIMIT_LOCKOUT = 2000; const KEY_REUSE_DELAY = 500; export class AnthropicKeyProvider implements KeyProvider { - readonly service = "anthropic"; + readonly service = "anthropic" as const; - private keys: AnthropicKey[] = []; + private readonly keys: AnthropicKey[] = []; + private store: KeyStore; private checker?: AnthropicKeyChecker; private log = logger.child({ module: "key-provider", service: this.service }); - constructor() { - const keyConfig = config.anthropicKey?.trim(); - if (!keyConfig) { - this.log.warn( - "ANTHROPIC_KEY is not set. Anthropic API will not be available." - ); - return; - } - let bareKeys: string[]; - bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; - for (const key of bareKeys) { - this.keys.push(AnthropicKeyProvider.deserialize({ key })); - } - this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys."); + constructor(store: KeyStore) { + this.store = store; } public async init() { + const storeName = this.store.constructor.name; + const serializedKeys = await this.store.load(); + + if (serializedKeys.length === 0) { + this.log.warn( + { via: storeName }, + "No Anthropic keys found. Anthropic API will not be available." + ); + return; + } + + this.keys.push(...serializedKeys.map(AnthropicKeyProvider.deserialize)); + this.log.info( + { count: this.keys.length, via: storeName }, + "Loaded Anthropic keys." + ); + if (config.checkKeys) { this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this)); this.checker.start(); diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 985d017..cf31fe8 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -53,7 +53,7 @@ for service-agnostic functionality. export interface KeyProvider { readonly service: LLMService; - init(store: KeyStore): Promise; + init(): Promise; get(model: Model): T; list(): Omit[]; disable(key: T): void; @@ -71,7 +71,12 @@ export interface KeyStore> { update(key: T): void; } -export const keyPool = new KeyPool(); +export let keyPool: KeyPool; +export async function init() { + keyPool = new KeyPool(); + await keyPool.init(); +} + export const SUPPORTED_MODELS = [ ...OPENAI_SUPPORTED_MODELS, ...ANTHROPIC_SUPPORTED_MODELS, diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 4a7dd6d..04d7aac 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -5,11 +5,12 @@ import schedule from "node-schedule"; import { config } from "../../config"; import { logger } from "../../logger"; import { Key, Model, KeyProvider, LLMService } from "./index"; +import { GooglePalmKeyProvider } from "./palm/provider"; +import { FirebaseKeyStore, MemoryKeyStore } from "./stores"; import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { GooglePalmKeyProvider } from "./palm/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; -import { MemoryKeyStore } from "./stores/memory"; type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate; @@ -20,21 +21,22 @@ export class KeyPool { }; constructor() { - this.keyProviders.push(new OpenAIKeyProvider()); - this.keyProviders.push(new AnthropicKeyProvider()); - this.keyProviders.push(new GooglePalmKeyProvider()); - this.keyProviders.push(new AwsBedrockKeyProvider()); + this.keyProviders.push(new OpenAIKeyProvider(createKeyStore("openai"))); + this.keyProviders.push( + new AnthropicKeyProvider(createKeyStore("anthropic")) + ); + this.keyProviders.push( + new GooglePalmKeyProvider(createKeyStore("google-palm")) + ); + // this.keyProviders.push(new AwsBedrockKeyProvider()); } public async init() { - const KeyStore = MemoryKeyStore; // TODO: select based on config - await Promise.all(this.keyProviders.map((p) => p.init(new KeyStore()))); + await Promise.all(this.keyProviders.map((p) => p.init())); const availableKeys = this.available("all"); if (availableKeys === 0) { - throw new Error( - "No keys loaded. Ensure that at least one key is configured." - ); + throw new Error("No keys loaded, the application cannot start."); } this.scheduleRecheck(); } @@ -154,3 +156,14 @@ export class KeyPool { this.recheckJobs.openai = job; } } + +function createKeyStore(service: LLMService) { + switch (config.persistenceProvider) { + case "memory": + return new MemoryKeyStore(service); + case "firebase_rtdb": + return new FirebaseKeyStore(service); + default: + throw new Error(`Unknown store type: ${config.persistenceProvider}`); + } +} diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 72cb1cb..80f4fe9 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -3,7 +3,7 @@ round-robin access to keys. Keys are stored in the OPENAI_KEY environment variable as a comma-separated list of keys. */ import crypto from "crypto"; import http from "http"; -import { Key, KeyProvider, Model } from "../index"; +import { Key, KeyProvider, Model, KeyStore } from "../index"; import { config } from "../../../config"; import { logger } from "../../../logger"; import { OpenAIKeyChecker } from "./checker"; @@ -66,7 +66,7 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage { rateLimitTokensReset: number; } -const SERIALIZABLE_FIELDS = [ +const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [ "key", "service", "hash", @@ -74,7 +74,7 @@ const SERIALIZABLE_FIELDS = [ "gpt4Tokens", "gpt4-32kTokens", "turboTokens", -] as const; +]; type SerializableOpenAIKey = Partial< Pick > & @@ -96,25 +96,38 @@ export class OpenAIKeyProvider implements KeyProvider { readonly service = "openai" as const; private readonly keys: OpenAIKey[] = []; + private store: KeyStore; private checker?: OpenAIKeyChecker; private log = logger.child({ module: "key-provider", service: this.service }); - constructor() { - const keyString = config.openaiKey?.trim(); - if (!keyString) { - this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available."); - return; - } - let bareKeys: string[]; - bareKeys = keyString.split(",").map((k) => k.trim()); - bareKeys = [...new Set(bareKeys)]; - for (const k of bareKeys) { - this.keys.push(OpenAIKeyProvider.deserialize({ key: k })); - } - this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys."); + constructor(store: KeyStore) { + this.store = store; } public async init() { + const storeName = this.store.constructor.name; + const serializedKeys = await this.store.load(); + + // TODO: If keystore is unavailable or returns no keys, instantiate a + // MemoryKeyStore and use the keys from process.env. Migrate them to the + // keystore when it becomes available. + // TODO: after key management UI, keychecker should always be enabled + // because keys may be added after initialization. + + if (serializedKeys.length === 0) { + this.log.warn( + { via: storeName }, + "No OpenAI keys found. OpenAI API will not be available." + ); + return; + } + + this.keys.push(...serializedKeys.map(OpenAIKeyProvider.deserialize)); + this.log.info( + { count: this.keys.length, via: storeName }, + "Loaded OpenAI keys." + ); + if (config.checkKeys) { const cloneFn = this.clone.bind(this); const updateFn = this.update.bind(this); @@ -372,7 +385,7 @@ export class OpenAIKeyProvider implements KeyProvider { static deserialize({ key, ...rest }: SerializableOpenAIKey): OpenAIKey { return { key, - service: "openai" as const, + service: "openai", modelFamilies: ["turbo" as const, "gpt4" as const], isTrial: false, isDisabled: false, diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts index ce8c8a7..49eb72e 100644 --- a/src/shared/key-management/palm/provider.ts +++ b/src/shared/key-management/palm/provider.ts @@ -1,6 +1,5 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; -import { config } from "../../../config"; +import { Key, KeyProvider, KeyStore } from ".."; import { logger } from "../../../logger"; import type { GooglePalmModelFamily } from "../../models"; @@ -34,6 +33,17 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage { rateLimitedUntil: number; } +const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [ + "key", + "service", + "hash", + "bisonTokens", +]; +type SerializableGooglePalmKey = Partial< + Pick +> & + Pick; + /** * Upon being rate limited, a key will be locked out for this many milliseconds * while we wait for other concurrent requests to finish. @@ -50,43 +60,31 @@ export class GooglePalmKeyProvider implements KeyProvider { readonly service = "google-palm"; private keys: GooglePalmKey[] = []; + private store: KeyStore; private log = logger.child({ module: "key-provider", service: this.service }); - constructor() { - const keyConfig = config.googlePalmKey?.trim(); - if (!keyConfig) { + constructor(store: KeyStore) { + this.store = store; + } + + public async init() { + const storeName = this.store.constructor.name; + const serializedKeys = await this.store.load(); + + if (serializedKeys.length === 0) { this.log.warn( - "GOOGLE_PALM_KEY is not set. PaLM API will not be available." + { via: storeName }, + "No PaLM keys found. PaLM API will not be available." ); return; } - let bareKeys: string[]; - bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; - for (const key of bareKeys) { - const newKey: GooglePalmKey = { - key, - service: this.service, - modelFamilies: ["bison"], - isDisabled: false, - isRevoked: false, - promptCount: 0, - lastUsed: 0, - rateLimitedAt: 0, - rateLimitedUntil: 0, - hash: `plm-${crypto - .createHash("sha256") - .update(key) - .digest("hex") - .slice(0, 8)}`, - lastChecked: 0, - bisonTokens: 0, - }; - this.keys.push(newKey); - } - this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys."); - } - public init() {} + this.keys.push(...serializedKeys.map(GooglePalmKeyProvider.deserialize)); + this.log.info( + { keyCount: this.keys.length, via: storeName }, + "Loaded PaLM keys." + ); + } public list() { return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); @@ -186,4 +184,27 @@ export class GooglePalmKeyProvider implements KeyProvider { } public recheck() {} + + static deserialize(serializedKey: SerializableGooglePalmKey): GooglePalmKey { + const { key, ...rest } = serializedKey; + return { + key, + service: "google-palm", + modelFamilies: ["bison"], + isTrial: false, + isDisabled: false, + promptCount: 0, + lastUsed: 0, + rateLimitedAt: 0, + rateLimitedUntil: 0, + hash: `plm-${crypto + .createHash("sha256") + .update(key) + .digest("hex") + .slice(0, 8)}`, + lastChecked: 0, + bisonTokens: 0, + ...rest, + }; + } } diff --git a/src/shared/key-management/stores/firebase.ts b/src/shared/key-management/stores/firebase.ts index 65cf410..c97377f 100644 --- a/src/shared/key-management/stores/firebase.ts +++ b/src/shared/key-management/stores/firebase.ts @@ -1,5 +1,5 @@ import type firebase from "firebase-admin"; -import { Key, KeyStore } from ".."; +import { AIService, Key, KeyStore } from ".."; import { getFirebaseApp } from "../../../config"; export class FirebaseKeyStore> @@ -7,7 +7,7 @@ export class FirebaseKeyStore> { private db: firebase.database.Database; - constructor(app = getFirebaseApp()) { + constructor(service: AIService, app = getFirebaseApp()) { this.db = app.database(); } diff --git a/src/shared/key-management/stores/index.ts b/src/shared/key-management/stores/index.ts new file mode 100644 index 0000000..eee6299 --- /dev/null +++ b/src/shared/key-management/stores/index.ts @@ -0,0 +1,2 @@ +export { FirebaseKeyStore } from "./firebase"; +export { MemoryKeyStore } from "./memory"; diff --git a/src/shared/key-management/stores/memory.ts b/src/shared/key-management/stores/memory.ts index 18ee623..04d63bb 100644 --- a/src/shared/key-management/stores/memory.ts +++ b/src/shared/key-management/stores/memory.ts @@ -1,11 +1,32 @@ -import { Key, KeyStore } from ".."; +import { APIFormat, Key, KeyStore } from ".."; export class MemoryKeyStore> implements KeyStore { - constructor() {} + private env: string; + + constructor(service: APIFormat) { + switch (service) { + case "anthropic": + this.env = "ANTHROPIC_KEY"; + break; + case "openai": + case "openai-text": + this.env = "OPENAI_KEY"; + break; + case "google-palm": + this.env = "GOOGLE_PALM_KEY"; + break; + default: + const never: never = service; + throw new Error(`Unknown service: ${never}`); + } + } public async load() { - // TODO: load from process.env - return []; + let bareKeys: string[]; + bareKeys = [ + ...new Set(process.env[this.env]?.split(",").map((k) => k.trim())), + ]; + return bareKeys.map((key) => ({ key } as K)); } public add(_key: K) {}