diff --git a/src/server.ts b/src/server.ts index 920ccd7..3bcc2e3 100644 --- a/src/server.ts +++ b/src/server.ts @@ -92,7 +92,8 @@ async function start() { logger.info("Checking configs and external dependencies..."); await assertConfigIsValid(); - keyPool.init(); + logger.info("Starting key pool..."); + await keyPool.init(); await initTokenizers(); diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts index 0f5b30a..a56a5f2 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -106,7 +106,7 @@ export class AnthropicKeyProvider implements KeyProvider { this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys."); } - public init() { + public async init() { if (config.checkKeys) { this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this)); this.checker.start(); diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 78459d3..985d017 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -53,7 +53,7 @@ for service-agnostic functionality. export interface KeyProvider { readonly service: LLMService; - init(): void; + init(store: KeyStore): Promise; get(model: Model): T; list(): Omit[]; disable(key: T): void; @@ -65,6 +65,12 @@ export interface KeyProvider { recheck(): void; } +export interface KeyStore> { + load(): Promise; + add(key: T): void; + update(key: T): void; +} + export const keyPool = new KeyPool(); export const SUPPORTED_MODELS = [ ...OPENAI_SUPPORTED_MODELS, diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 07a7b4c..4a7dd6d 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -9,6 +9,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { GooglePalmKeyProvider } from "./palm/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; +import { MemoryKeyStore } from "./stores/memory"; type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate; @@ -25,8 +26,10 @@ export class KeyPool { this.keyProviders.push(new AwsBedrockKeyProvider()); } - public init() { - this.keyProviders.forEach((provider) => provider.init()); + public async init() { + const KeyStore = MemoryKeyStore; // TODO: select based on config + await Promise.all(this.keyProviders.map((p) => p.init(new KeyStore()))); + const availableKeys = this.available("all"); if (availableKeys === 0) { throw new Error( diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 364648f..e55bbbf 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -66,6 +66,20 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage { rateLimitTokensReset: number; } +const SERIALIZABLE_FIELDS = [ + "key", + "service", + "hash", + "organizationId", + "gpt4Tokens", + "gpt4-32kTokens", + "turboTokens", +] as const; +type SerializableOpenAIKey = Partial< + Pick +> & + Pick; + export type OpenAIKeyUpdate = Omit< Partial, "key" | "hash" | "promptCount" @@ -81,7 +95,7 @@ const KEY_REUSE_DELAY = 1000; export class OpenAIKeyProvider implements KeyProvider { readonly service = "openai" as const; - private keys: OpenAIKey[] = []; + private readonly keys: OpenAIKey[] = []; private checker?: OpenAIKeyChecker; private log = logger.child({ module: "key-provider", service: this.service }); @@ -95,35 +109,13 @@ export class OpenAIKeyProvider implements KeyProvider { bareKeys = keyString.split(",").map((k) => k.trim()); bareKeys = [...new Set(bareKeys)]; for (const k of bareKeys) { - const newKey: OpenAIKey = { - key: k, - service: "openai" as const, - modelFamilies: ["turbo" as const, "gpt4" as const], - isTrial: false, - isDisabled: false, - isRevoked: false, - isOverQuota: false, - lastUsed: 0, - lastChecked: 0, - promptCount: 0, - hash: `oai-${crypto - .createHash("sha256") - .update(k) - .digest("hex") - .slice(0, 8)}`, - rateLimitedAt: 0, - rateLimitRequestsReset: 0, - rateLimitTokensReset: 0, - turboTokens: 0, - gpt4Tokens: 0, - "gpt4-32kTokens": 0, - }; + const newKey = OpenAIKeyProvider.deserialize({ key: k }); this.keys.push(newKey); } this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys."); } - public init() { + public async init() { if (config.checkKeys) { const cloneFn = this.clone.bind(this); const updateFn = this.update.bind(this); @@ -137,12 +129,7 @@ export class OpenAIKeyProvider implements KeyProvider { * Don't mutate returned keys, use a KeyPool method instead. **/ public list() { - return this.keys.map((key) => { - return Object.freeze({ - ...key, - key: undefined, - }); - }); + return this.keys.map((key) => Object.freeze({ ...key, key: undefined })); } public get(model: Model) { @@ -383,20 +370,32 @@ export class OpenAIKeyProvider implements KeyProvider { this.checker?.scheduleNextCheck(); } - /** Writes key status to disk. */ - // public writeKeyStatus() { - // const keys = this.keys.map((key) => ({ - // key: key.key, - // isGpt4: key.isGpt4, - // usage: key.usage, - // hardLimit: key.hardLimit, - // isDisabled: key.isDisabled, - // })); - // fs.writeFileSync( - // path.join(__dirname, "..", "keys.json"), - // JSON.stringify(keys, null, 2) - // ); - // } + static deserialize({ key, ...rest }: SerializableOpenAIKey): OpenAIKey { + return { + key, + service: "openai" as const, + modelFamilies: ["turbo" as const, "gpt4" as const], + isTrial: false, + isDisabled: false, + isRevoked: false, + isOverQuota: false, + lastUsed: 0, + lastChecked: 0, + promptCount: 0, + hash: `oai-${crypto + .createHash("sha256") + .update(key) + .digest("hex") + .slice(0, 8)}`, + rateLimitedAt: 0, + rateLimitRequestsReset: 0, + rateLimitTokensReset: 0, + turboTokens: 0, + gpt4Tokens: 0, + "gpt4-32kTokens": 0, + ...rest, + }; + } } /** diff --git a/src/shared/key-management/stores/firebase.ts b/src/shared/key-management/stores/firebase.ts new file mode 100644 index 0000000..65cf410 --- /dev/null +++ b/src/shared/key-management/stores/firebase.ts @@ -0,0 +1,26 @@ +import type firebase from "firebase-admin"; +import { Key, KeyStore } from ".."; +import { getFirebaseApp } from "../../../config"; + +export class FirebaseKeyStore> + implements KeyStore +{ + private db: firebase.database.Database; + + constructor(app = getFirebaseApp()) { + this.db = app.database(); + } + + public async load() { + throw new Error("Method not implemented."); + return []; + } + + public add(key: K) { + throw new Error("Method not implemented."); + } + + public update(key: K) { + throw new Error("Method not implemented."); + } +} diff --git a/src/shared/key-management/stores/memory.ts b/src/shared/key-management/stores/memory.ts new file mode 100644 index 0000000..18ee623 --- /dev/null +++ b/src/shared/key-management/stores/memory.ts @@ -0,0 +1,14 @@ +import { Key, KeyStore } from ".."; + +export class MemoryKeyStore> implements KeyStore { + constructor() {} + + public async load() { + // TODO: load from process.env + return []; + } + + public add(_key: K) {} + + public update(_key: K) {} +}