diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts index c9a13e3..02c5abe 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -1,7 +1,8 @@ import { config } from "../../../config"; import { logger } from "../../../logger"; import type { AnthropicModelFamily } from "../../models"; -import { Key, KeyProvider, KeyStore } from "../types"; +import { KeyProviderBase } from "../key-provider-base"; +import { Key } from "../types"; import { AnthropicKeyChecker } from "./checker"; const RATE_LIMIT_LOCKOUT = 2000; @@ -43,17 +44,12 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage { isPozzed: boolean; } -export class AnthropicKeyProvider implements KeyProvider { +export class AnthropicKeyProvider extends KeyProviderBase { readonly service = "anthropic" as const; - private readonly keys: AnthropicKey[] = []; - private store: KeyStore; + protected readonly keys: AnthropicKey[] = []; private checker?: AnthropicKeyChecker; - private log = logger.child({ module: "key-provider", service: this.service }); - - constructor(store: KeyStore) { - this.store = store; - } + protected log = logger.child({ module: "key-provider", service: this.service }); public async init() { const storeName = this.store.constructor.name; @@ -75,10 +71,6 @@ export class AnthropicKeyProvider implements KeyProvider { } } - public list() { - return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); - } - public get(_model: AnthropicModel) { // Currently, all Anthropic keys have access to all models. This will almost // certainly change when they move out of beta later this year. @@ -123,22 +115,6 @@ export class AnthropicKeyProvider implements KeyProvider { return { ...selectedKey }; } - public disable(key: AnthropicKey) { - const keyFromPool = this.keys.find((k) => k.hash === key.hash); - if (!keyFromPool || keyFromPool.isDisabled) return; - keyFromPool.isDisabled = true; - this.log.warn({ key: key.hash }, "Key disabled"); - } - - public update(hash: string, update: Partial) { - const keyFromPool = this.keys.find((k) => k.hash === hash)!; - Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); - } - - public available() { - return this.keys.filter((k) => !k.isDisabled).length; - } - public incrementUsage(hash: string, _model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index a88b69e..4b0ed7f 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -1,7 +1,8 @@ import { config } from "../../../config"; import { logger } from "../../../logger"; import type { AwsBedrockModelFamily } from "../../models"; -import { Key, KeyProvider, KeyStore } from "../types"; +import { KeyProviderBase } from "../key-provider-base"; +import { Key } from "../types"; import { AwsKeyChecker } from "./checker"; const RATE_LIMIT_LOCKOUT = 2000; @@ -35,17 +36,12 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { awsLoggingStatus: "unknown" | "disabled" | "enabled"; } -export class AwsBedrockKeyProvider implements KeyProvider { +export class AwsBedrockKeyProvider extends KeyProviderBase { readonly service = "aws" as const; - private readonly keys: AwsBedrockKey[] = []; - private store: KeyStore; + protected readonly keys: AwsBedrockKey[] = []; private checker?: AwsKeyChecker; - private log = logger.child({ module: "key-provider", service: this.service }); - - constructor(store: KeyStore) { - this.store = store; - } + protected log = logger.child({ module: "key-provider", service: this.service }); public async init() { const storeName = this.store.constructor.name; @@ -67,10 +63,6 @@ export class AwsBedrockKeyProvider implements KeyProvider { } } - public list() { - return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); - } - public get(_model: AwsBedrockModel) { const availableKeys = this.keys.filter((k) => { const isNotLogged = k.awsLoggingStatus === "disabled"; @@ -112,22 +104,6 @@ export class AwsBedrockKeyProvider implements KeyProvider { return { ...selectedKey }; } - public disable(key: AwsBedrockKey) { - const keyFromPool = this.keys.find((k) => k.hash === key.hash); - if (!keyFromPool || keyFromPool.isDisabled) return; - keyFromPool.isDisabled = true; - this.log.warn({ key: key.hash }, "Key disabled"); - } - - public update(hash: string, update: Partial) { - const keyFromPool = this.keys.find((k) => k.hash === hash)!; - Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); - } - - public available() { - return this.keys.filter((k) => !k.isDisabled).length; - } - public incrementUsage(hash: string, _model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; diff --git a/src/shared/key-management/key-checker-base.ts b/src/shared/key-management/key-checker-base.ts index 8f30a20..f9ec693 100644 --- a/src/shared/key-management/key-checker-base.ts +++ b/src/shared/key-management/key-checker-base.ts @@ -1,7 +1,6 @@ import { AxiosError } from "axios"; import pino from "pino"; import { logger } from "../../logger"; - import { Key } from "./types"; type KeyCheckerOptions = { diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 068214a..7931629 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -4,24 +4,17 @@ import os from "os"; import schedule from "node-schedule"; import { config } from "../../config"; import { logger } from "../../logger"; +import { KeyProviderBase } from "./key-provider-base"; import { getSerializer } from "./serializers"; import { FirebaseKeyStore, MemoryKeyStore } from "./stores"; import { AnthropicKeyProvider } from "./anthropic/provider"; import { OpenAIKeyProvider } from "./openai/provider"; import { GooglePalmKeyProvider } from "./palm/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; - -import { - Key, - KeyProvider, - KeyStore, - LLMService, - Model, - ServiceToKey, -} from "./types"; +import { Key, KeyStore, LLMService, Model, ServiceToKey } from "./types"; export class KeyPool { - private keyProviders: KeyProvider[] = []; + private keyProviders: KeyProviderBase[] = []; private recheckJobs: Partial> = { openai: null, }; @@ -131,7 +124,7 @@ export class KeyPool { throw new Error(`Unknown service for model '${model}'`); } - private getKeyProvider(service: LLMService): KeyProvider { + private getKeyProvider(service: LLMService): KeyProviderBase { return this.keyProviders.find((provider) => provider.service === service)!; } diff --git a/src/shared/key-management/key-provider-base.ts b/src/shared/key-management/key-provider-base.ts index 62db2fe..cefaf3c 100644 --- a/src/shared/key-management/key-provider-base.ts +++ b/src/shared/key-management/key-provider-base.ts @@ -1,25 +1,65 @@ -import { Key, KeyProvider, LLMService, Model } from "./types"; +import { logger } from "../../logger"; +import { Key, KeyStore, LLMService, Model } from "./types"; -export abstract class KeyProvierBase implements KeyProvider { - abstract readonly service: LLMService; +export abstract class KeyProviderBase { + public abstract readonly service: LLMService; - abstract init(): Promise; + protected abstract readonly keys: K[]; + protected abstract log: typeof logger; + protected readonly store: KeyStore; - abstract get(model: Model): K; + public constructor(keyStore: KeyStore) { + this.store = keyStore; + } - abstract list(): Omit[]; + public abstract init(): Promise; - abstract disable(key: K): void; + public addKey(key: K): void { + this.keys.push(key); + this.store.add(key); + } - abstract update(hash: string, update: Partial): void; + public abstract get(model: Model): K; - abstract available(): number; + /** + * Returns a list of all keys, with the actual key value removed. Don't + * mutate the returned objects; use `update` instead to ensure the changes + * are synced to the key store. + */ + public list(): Omit[] { + return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); + } - abstract incrementUsage(hash: string, model: string, tokens: number): void; + public disable(key: K): void { + const keyFromPool = this.keys.find((k) => k.hash === key.hash); + if (!keyFromPool || keyFromPool.isDisabled) return; + this.update(key.hash, { isDisabled: true } as Partial, true); + this.log.warn({ key: key.hash }, "Key disabled"); + } - abstract getLockoutPeriod(model: Model): number; + public update(hash: string, update: Partial, force = false): void { + const key = this.keys.find((k) => k.hash === hash); + if (!key) { + throw new Error(`No key with hash ${hash}`); + } - abstract markRateLimited(hash: string): void; + Object.assign(key, { lastChecked: Date.now(), ...update }); + this.store.update(hash, update, force); + } - abstract recheck(): void; + public available(): number { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public abstract incrementUsage( + hash: string, + model: string, + tokens: number + ): void; + + public abstract getLockoutPeriod(model: Model): number; + + public abstract markRateLimited(hash: string): void; + + public abstract recheck(): void; } diff --git a/src/shared/key-management/key-serializer-base.ts b/src/shared/key-management/key-serializer-base.ts index 83b6444..a681886 100644 --- a/src/shared/key-management/key-serializer-base.ts +++ b/src/shared/key-management/key-serializer-base.ts @@ -1,5 +1,4 @@ -import { KeySerializer, SerializedKey } from "./index"; -import { Key } from "./types"; +import { Key, KeySerializer, SerializedKey } from "./types"; export abstract class KeySerializerBase implements KeySerializer diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 5217a61..4437542 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -3,8 +3,9 @@ import { IncomingHttpHeaders } from "http"; import { config } from "../../../config"; import { logger } from "../../../logger"; import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; -import { Key, KeyProvider, KeyStore, Model } from "../types"; +import { Key, Model } from "../types"; import { OpenAIKeyChecker } from "./checker"; +import { KeyProviderBase } from "../key-provider-base"; const KEY_REUSE_DELAY = 1000; @@ -60,17 +61,12 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage { rateLimitTokensReset: number; } -export class OpenAIKeyProvider implements KeyProvider { +export class OpenAIKeyProvider extends KeyProviderBase { readonly service = "openai" as const; - private readonly keys: OpenAIKey[] = []; - private store: KeyStore; + protected readonly keys: OpenAIKey[] = []; private checker?: OpenAIKeyChecker; - private log = logger.child({ module: "key-provider", service: this.service }); - - constructor(store: KeyStore) { - this.store = store; - } + protected log = logger.child({ module: "key-provider", service: this.service }); public async init() { const storeName = this.store.constructor.name; @@ -97,14 +93,6 @@ export class OpenAIKeyProvider implements KeyProvider { } } - /** - * Returns a list of all keys, with the key field removed. - * Don't mutate returned keys, use a KeyPool method instead. - **/ - public list() { - return this.keys.map((key) => Object.freeze({ ...key, key: undefined })); - } - public get(model: Model) { const neededFamily = getOpenAIModelFamily(model); const excludeTrials = model === "text-embedding-ada-002"; @@ -193,12 +181,6 @@ export class OpenAIKeyProvider implements KeyProvider { return { ...selectedKey }; } - /** Called by the key checker to update key information. */ - public update(keyHash: string, update: Partial) { - const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; - Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); - } - /** Called by the key checker to create clones of keys for the given orgs. */ public clone(keyHash: string, newOrgIds: string[]) { const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; @@ -220,19 +202,7 @@ export class OpenAIKeyProvider implements KeyProvider { ); return clone; }); - this.keys.push(...clones); - } - - /** Disables a key, or does nothing if the key isn't in this pool. */ - public disable(key: Key) { - const keyFromPool = this.keys.find((k) => k.hash === key.hash); - if (!keyFromPool || keyFromPool.isDisabled) return; - this.update(key.hash, { isDisabled: true }); - this.log.warn({ key: key.hash }, "Key disabled"); - } - - public available() { - return this.keys.filter((k) => !k.isDisabled).length; + clones.forEach((clone) => this.addKey(clone)); } /** diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts index 902fe45..450ab3c 100644 --- a/src/shared/key-management/palm/provider.ts +++ b/src/shared/key-management/palm/provider.ts @@ -1,6 +1,7 @@ import { logger } from "../../../logger"; import type { GooglePalmModelFamily } from "../../models"; -import { Key, KeyProvider, KeyStore } from "../types"; +import { KeyProviderBase } from "../key-provider-base"; +import { Key } from "../types"; const RATE_LIMIT_LOCKOUT = 2000; const KEY_REUSE_DELAY = 500; @@ -22,16 +23,11 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage { rateLimitedUntil: number; } -export class GooglePalmKeyProvider implements KeyProvider { +export class GooglePalmKeyProvider extends KeyProviderBase { readonly service = "google-palm"; - private keys: GooglePalmKey[] = []; - private store: KeyStore; - private log = logger.child({ module: "key-provider", service: this.service }); - - constructor(store: KeyStore) { - this.store = store; - } + protected keys: GooglePalmKey[] = []; + protected log = logger.child({ module: "key-provider", service: this.service }); public async init() { const storeName = this.store.constructor.name; @@ -48,10 +44,6 @@ export class GooglePalmKeyProvider implements KeyProvider { ); } - public list() { - return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); - } - public get(_model: GooglePalmModel) { const availableKeys = this.keys.filter((k) => !k.isDisabled); if (availableKeys.length === 0) { @@ -90,22 +82,6 @@ export class GooglePalmKeyProvider implements KeyProvider { return { ...selectedKey }; } - public disable(key: GooglePalmKey) { - const keyFromPool = this.keys.find((k) => k.hash === key.hash); - if (!keyFromPool || keyFromPool.isDisabled) return; - keyFromPool.isDisabled = true; - this.log.warn({ key: key.hash }, "Key disabled"); - } - - public update(hash: string, update: Partial) { - const keyFromPool = this.keys.find((k) => k.hash === hash)!; - Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); - } - - public available() { - return this.keys.filter((k) => !k.isDisabled).length; - } - public incrementUsage(hash: string, _model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; diff --git a/src/shared/key-management/serializers.ts b/src/shared/key-management/serializers.ts index 3a331be..e3f8270 100644 --- a/src/shared/key-management/serializers.ts +++ b/src/shared/key-management/serializers.ts @@ -1,8 +1,4 @@ import { assertNever } from "../utils"; -import { OpenAIKeySerializer } from "./openai/serializer"; -import { AnthropicKeySerializer } from "./anthropic/serializer"; -import { GooglePalmKeySerializer } from "./palm/serializer"; -import { AwsBedrockKeySerializer } from "./aws/serializer"; import { Key, KeySerializer, @@ -10,6 +6,10 @@ import { SerializedKey, ServiceToKey, } from "./types"; +import { OpenAIKeySerializer } from "./openai/serializer"; +import { AnthropicKeySerializer } from "./anthropic/serializer"; +import { GooglePalmKeySerializer } from "./palm/serializer"; +import { AwsBedrockKeySerializer } from "./aws/serializer"; export function assertSerializedKey(k: any): asserts k is SerializedKey { if (typeof k !== "object" || !k || typeof (k as any).key !== "string") { diff --git a/src/shared/key-management/stores/firebase.ts b/src/shared/key-management/stores/firebase.ts index 91298f0..262c0f1 100644 --- a/src/shared/key-management/stores/firebase.ts +++ b/src/shared/key-management/stores/firebase.ts @@ -32,7 +32,7 @@ export class FirebaseKeyStore implements KeyStore { this.serializer = serializer; this.service = service; this.pendingUpdates = new Map(); - this.schedulePeriodicFlush(); + this.scheduleFlush(); } public async load(isMigrating = false): Promise { @@ -55,17 +55,24 @@ export class FirebaseKeyStore implements KeyStore { } public add(key: K) { - throw new Error("Method not implemented."); + const serialized = this.serializer.serialize(key); + this.pendingUpdates.set(key.hash, serialized); + this.forceFlush(); } public update(id: string, update: Partial, force = false) { const existing = this.pendingUpdates.get(id) ?? {}; Object.assign(existing, this.serializer.partialSerialize(id, update)); this.pendingUpdates.set(id, existing); - if (force) setTimeout(() => this.flush(), 0); + if (force) this.forceFlush(); } - private schedulePeriodicFlush() { + private forceFlush() { + if (this.flushInterval) clearInterval(this.flushInterval); + this.flushInterval = setTimeout(() => this.flush(), 0); + } + + private scheduleFlush() { if (this.flushInterval) clearInterval(this.flushInterval); this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5); } @@ -76,7 +83,7 @@ export class FirebaseKeyStore implements KeyStore { { pendingUpdates: this.pendingUpdates.size }, "Database not loaded yet. Skipping flush." ); - return; + return this.scheduleFlush(); } const updates: Record> = {}; @@ -85,11 +92,11 @@ export class FirebaseKeyStore implements KeyStore { await this.keysRef.update(updates); - this.log.info( + this.log.debug( { count: Object.keys(updates).length }, "Flushed pending key updates." ); - this.schedulePeriodicFlush(); + this.scheduleFlush(); } private async migrate(): Promise { diff --git a/src/shared/key-management/types.ts b/src/shared/key-management/types.ts index aae01b0..a4a9e42 100644 --- a/src/shared/key-management/types.ts +++ b/src/shared/key-management/types.ts @@ -32,9 +32,14 @@ export interface Key { service: LLMService; /** The model families that this key has access to. */ modelFamilies: ModelFamily[]; - /** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */ + /** Whether this key is currently disabled for some reason. */ isDisabled: boolean; - /** Whether this key specifically has been revoked. */ + /** + * Whether this key specifically has been revoked. This is different from + * `isDisabled` because a key can be disabled for other reasons, such as + * exceeding its quota. A revoked key is assumed to be permanently disabled, + * and KeyStore implementations should not return it when loading keys. + */ isRevoked: boolean; /** The number of prompts that have been sent with this key. */ promptCount: number; @@ -46,42 +51,14 @@ export interface Key { hash: string; } -export interface KeyProvider { - readonly service: LLMService; - - init(): Promise; - - get(model: Model): T; - - list(): Omit[]; - - disable(key: T): void; - - update(hash: string, update: Partial): void; - - available(): number; - - incrementUsage(hash: string, model: string, tokens: number): void; - - getLockoutPeriod(model: Model): number; - - markRateLimited(hash: string): void; - - recheck(): void; -} - export interface KeySerializer { serialize(keyObj: K): SerializedKey; - deserialize(serializedKey: SerializedKey): K; - partialSerialize(key: string, update: Partial): Partial; } export interface KeyStore { load(): Promise; - add(key: K): void; - update(id: string, update: Partial, force?: boolean): void; }