From ea2bfb9eef057ded84c22f769565cf8c5e50c366 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sun, 8 Oct 2023 04:21:49 -0500 Subject: [PATCH] implements most of firebasekeystore --- .../key-management/anthropic/provider.ts | 11 ++--- .../key-management/anthropic/serializer.ts | 24 ++++++---- src/shared/key-management/aws/provider.ts | 12 ++--- src/shared/key-management/aws/serializer.ts | 17 +++---- src/shared/key-management/index.ts | 5 +- src/shared/key-management/openai/provider.ts | 18 ++----- .../key-management/openai/serializer.ts | 17 +++---- src/shared/key-management/palm/provider.ts | 14 ++---- src/shared/key-management/palm/serializer.ts | 18 +++---- src/shared/key-management/serializers.ts | 47 ++++++++++++++++--- src/shared/key-management/stores/firebase.ts | 45 +++++++++++++++--- src/shared/key-management/stores/index.ts | 21 +-------- src/shared/key-management/stores/memory.ts | 4 +- 13 files changed, 146 insertions(+), 107 deletions(-) diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts index c807b1e..eaa6c23 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -59,16 +59,13 @@ export class AnthropicKeyProvider implements KeyProvider { public async init() { const storeName = this.store.constructor.name; - const serializedKeys = await this.store.load(); + const loadedKeys = await this.store.load(); - if (serializedKeys.length === 0) { - return this.log.warn( - { via: storeName }, - "No Anthropic keys found. Anthropic API will not be available." - ); + if (loadedKeys.length === 0) { + return this.log.warn({ via: storeName }, "No Anthropic keys found."); } - this.keys.push(...serializedKeys.map(AnthropicKeySerializer.deserialize)); + this.keys.push(...loadedKeys); this.log.info( { count: this.keys.length, via: storeName }, "Loaded Anthropic keys." diff --git a/src/shared/key-management/anthropic/serializer.ts b/src/shared/key-management/anthropic/serializer.ts index 5c3b8f7..55c639e 100644 --- a/src/shared/key-management/anthropic/serializer.ts +++ b/src/shared/key-management/anthropic/serializer.ts @@ -1,15 +1,21 @@ import crypto from "crypto"; -import type { AnthropicKey } from "../index"; -import type { KeySerializer, SerializedKey } from "../stores"; +import type { AnthropicKey, SerializedKey } from "../index"; +import { KeySerializerBase } from "../serializers"; -const SERIALIZABLE_FIELDS = ["key", "service", "hash", "claudeTokens"] as const; +const SERIALIZABLE_FIELDS: (keyof AnthropicKey)[] = [ + "key", + "service", + "hash", + "claudeTokens", +]; export type SerializedAnthropicKey = SerializedKey & Partial>; -export const AnthropicKeySerializer: KeySerializer = { - serialize(key: AnthropicKey): SerializedAnthropicKey { - return { key: key.key }; // TODO: serialize other fields - }, +export class AnthropicKeySerializer extends KeySerializerBase { + constructor() { + super(SERIALIZABLE_FIELDS); + } + deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey { return { key, @@ -32,5 +38,5 @@ export const AnthropicKeySerializer: KeySerializer = { claudeTokens: 0, ...rest, }; - }, -}; + } +} diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 88a3323..fc2d3a1 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -4,7 +4,6 @@ import type { AwsBedrockModelFamily } from "../../models"; import { Key, KeyProvider } from "../index"; import { KeyStore } from "../stores"; import { AwsKeyChecker } from "./checker"; -import { AwsBedrockKeySerializer } from "./serializer"; const RATE_LIMIT_LOCKOUT = 2000; const KEY_REUSE_DELAY = 500; @@ -51,16 +50,13 @@ export class AwsBedrockKeyProvider implements KeyProvider { public async init() { const storeName = this.store.constructor.name; - const serializedKeys = await this.store.load(); + const loadedKeys = await this.store.load(); - if (serializedKeys.length === 0) { - return this.log.warn( - { via: storeName }, - "No AWS credentials found. AWS Bedrock API will not be available." - ); + if (loadedKeys.length === 0) { + return this.log.warn({ via: storeName }, "No AWS credentials found."); } - this.keys.push(...serializedKeys.map(AwsBedrockKeySerializer.deserialize)); + this.keys.push(...loadedKeys); this.log.info( { count: this.keys.length, via: storeName }, "Loaded AWS Bedrock keys." diff --git a/src/shared/key-management/aws/serializer.ts b/src/shared/key-management/aws/serializer.ts index 091010a..86c7c74 100644 --- a/src/shared/key-management/aws/serializer.ts +++ b/src/shared/key-management/aws/serializer.ts @@ -1,6 +1,6 @@ import crypto from "crypto"; -import type { AwsBedrockKey } from "../index"; -import type { KeySerializer, SerializedKey } from "../stores"; +import type { AwsBedrockKey, SerializedKey } from "../index"; +import { KeySerializerBase } from "../serializers"; const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [ "key", @@ -11,10 +11,11 @@ const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [ export type SerializedAwsBedrockKey = SerializedKey & Partial>; -export const AwsBedrockKeySerializer: KeySerializer = { - serialize(key: AwsBedrockKey): SerializedAwsBedrockKey { - return { key: key.key }; - }, +export class AwsBedrockKeySerializer extends KeySerializerBase { + constructor() { + super(SERIALIZABLE_FIELDS); + } + deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey { const { key, ...rest } = serializedKey; return { @@ -37,5 +38,5 @@ export const AwsBedrockKeySerializer: KeySerializer = { ["aws-claudeTokens"]: 0, ...rest, }; - }, -}; + } +} diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index b83622d..95c6530 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -80,4 +80,7 @@ export { export { AnthropicKey } from "./anthropic/provider"; export { OpenAIKey } from "./openai/provider"; export { GooglePalmKey } from "./palm/provider"; -export { AwsBedrockKey } from "./aws/provider"; \ No newline at end of file +export { AwsBedrockKey } from "./aws/provider"; +export { assertSerializedKey } from "./serializers"; +export { SerializedKey } from "./serializers"; +export { KeySerializer } from "./serializers"; \ No newline at end of file diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 2f5421e..6d877d0 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -6,7 +6,6 @@ import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; import { Key, KeyProvider, Model } from "../index"; import { KeyStore } from "../stores"; import { OpenAIKeyChecker } from "./checker"; -import { OpenAIKeySerializer } from "./serializer"; const KEY_REUSE_DELAY = 1000; @@ -17,7 +16,7 @@ export const OPENAI_SUPPORTED_MODELS = [ "gpt-4-32k", "text-embedding-ada-002", ] as const; -export type OpenAIModel = typeof OPENAI_SUPPORTED_MODELS[number]; +export type OpenAIModel = (typeof OPENAI_SUPPORTED_MODELS)[number]; type OpenAIKeyUsage = { [K in OpenAIModelFamily as `${K}Tokens`]: number; @@ -76,23 +75,16 @@ export class OpenAIKeyProvider implements KeyProvider { public async init() { const storeName = this.store.constructor.name; - const serializedKeys = await this.store.load(); + const loadedKeys = await this.store.load(); - // TODO: If keystore is unavailable or returns no keys, instantiate a - // MemoryKeyStore and use the keys from process.env. Migrate them to the - // keystore when it becomes available. // TODO: after key management UI, keychecker should always be enabled // because keys may be added after initialization. - if (serializedKeys.length === 0) { - this.log.warn( - { via: storeName }, - "No OpenAI keys found. OpenAI API will not be available." - ); - return; + if (loadedKeys.length === 0) { + return this.log.warn({ via: storeName }, "No OpenAI keys found."); } - this.keys.push(...serializedKeys.map(OpenAIKeySerializer.deserialize)); + this.keys.push(...loadedKeys); this.log.info( { count: this.keys.length, via: storeName }, "Loaded OpenAI keys." diff --git a/src/shared/key-management/openai/serializer.ts b/src/shared/key-management/openai/serializer.ts index b7a21f6..02307e6 100644 --- a/src/shared/key-management/openai/serializer.ts +++ b/src/shared/key-management/openai/serializer.ts @@ -1,6 +1,6 @@ import crypto from "crypto"; -import type { OpenAIKey } from "../index"; -import type { KeySerializer, SerializedKey } from "../stores"; +import type { OpenAIKey, SerializedKey } from "../index"; +import { KeySerializerBase } from "../serializers"; const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [ "key", @@ -14,10 +14,11 @@ const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [ export type SerializedOpenAIKey = SerializedKey & Partial>; -export const OpenAIKeySerializer: KeySerializer = { - serialize(key: OpenAIKey): SerializedOpenAIKey { - return { key: key.key }; - }, +export class OpenAIKeySerializer extends KeySerializerBase { + constructor() { + super(SERIALIZABLE_FIELDS); + } + deserialize({ key, ...rest }: SerializedOpenAIKey): OpenAIKey { return { key, @@ -43,5 +44,5 @@ export const OpenAIKeySerializer: KeySerializer = { "gpt4-32kTokens": 0, ...rest, }; - }, -}; + } +} diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts index 16c0c33..ec9178a 100644 --- a/src/shared/key-management/palm/provider.ts +++ b/src/shared/key-management/palm/provider.ts @@ -37,19 +37,15 @@ export class GooglePalmKeyProvider implements KeyProvider { public async init() { const storeName = this.store.constructor.name; - const serializedKeys = await this.store.load(); + const loadedKeys = await this.store.load(); - if (serializedKeys.length === 0) { - this.log.warn( - { via: storeName }, - "No PaLM keys found. PaLM API will not be available." - ); - return; + if (loadedKeys.length === 0) { + return this.log.warn({ via: storeName }, "No Google PaLM keys found."); } - this.keys.push(...serializedKeys.map(GooglePalmKeySerializer.deserialize)); + this.keys.push(...loadedKeys); this.log.info( - { keyCount: this.keys.length, via: storeName }, + { count: this.keys.length, via: storeName }, "Loaded PaLM keys." ); } diff --git a/src/shared/key-management/palm/serializer.ts b/src/shared/key-management/palm/serializer.ts index c8f390e..741eee2 100644 --- a/src/shared/key-management/palm/serializer.ts +++ b/src/shared/key-management/palm/serializer.ts @@ -1,7 +1,6 @@ import crypto from "crypto"; -import type { GooglePalmKey } from "../index"; -import type { KeySerializer } from "../stores"; -import { SerializedKey } from "../stores"; +import type { GooglePalmKey, SerializedKey } from "../index"; +import { KeySerializerBase } from "../serializers"; const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [ "key", @@ -12,10 +11,11 @@ const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [ export type SerializedGooglePalmKey = SerializedKey & Partial>; -export const GooglePalmKeySerializer: KeySerializer = { - serialize(key: GooglePalmKey): SerializedGooglePalmKey { - return { key: key.key }; - }, +export class GooglePalmKeySerializer extends KeySerializerBase { + constructor() { + super(SERIALIZABLE_FIELDS); + } + deserialize(serializedKey: SerializedGooglePalmKey): GooglePalmKey { const { key, ...rest } = serializedKey; return { @@ -37,5 +37,5 @@ export const GooglePalmKeySerializer: KeySerializer = { bisonTokens: 0, ...rest, }; - }, -}; + } +} diff --git a/src/shared/key-management/serializers.ts b/src/shared/key-management/serializers.ts index 81e73e9..076c6fb 100644 --- a/src/shared/key-management/serializers.ts +++ b/src/shared/key-management/serializers.ts @@ -1,21 +1,56 @@ -import { LLMService, Key } from "."; +import { Key, LLMService } from "."; import { assertNever } from "../utils"; -import { KeySerializer } from "./stores"; import { OpenAIKeySerializer } from "./openai/serializer"; import { AnthropicKeySerializer } from "./anthropic/serializer"; import { GooglePalmKeySerializer } from "./palm/serializer"; import { AwsBedrockKeySerializer } from "./aws/serializer"; +export type SerializedKey = { key: string }; + +export interface KeySerializer { + serialize(keyObj: K): SerializedKey; + deserialize(serializedKey: SerializedKey): K; + partialSerialize(key: string, update: Partial): Partial; +} + +export abstract class KeySerializerBase + implements KeySerializer +{ + protected constructor(protected serializableFields: (keyof K)[]) {} + + serialize(keyObj: K): SerializedKey { + return { + ...Object.fromEntries(this.serializableFields.map((f) => [f, keyObj[f]])), + key: keyObj.key, + }; + } + + partialSerialize(key: string, update: Partial): Partial { + return { + ...Object.fromEntries(this.serializableFields.map((f) => [f, update[f]])), + key, + }; + } + + abstract deserialize(serializedKey: SerializedKey): K; +} + +export function assertSerializedKey(k: any): asserts k is SerializedKey { + if (typeof k !== "object" || !k || typeof (k as any).key !== "string") { + throw new Error("Invalid serialized key data"); + } +} + export function getSerializer(service: LLMService): KeySerializer { switch (service) { case "openai": - return OpenAIKeySerializer; + return new OpenAIKeySerializer(); case "anthropic": - return AnthropicKeySerializer; + return new AnthropicKeySerializer(); case "google-palm": - return GooglePalmKeySerializer; + return new GooglePalmKeySerializer(); case "aws": - return AwsBedrockKeySerializer; + return new AwsBedrockKeySerializer(); default: assertNever(service); } diff --git a/src/shared/key-management/stores/firebase.ts b/src/shared/key-management/stores/firebase.ts index 73b2c6b..0710b42 100644 --- a/src/shared/key-management/stores/firebase.ts +++ b/src/shared/key-management/stores/firebase.ts @@ -1,15 +1,22 @@ import firebase from "firebase-admin"; import { config, getFirebaseApp } from "../../../config"; import { logger } from "../../../logger"; -import { Key, LLMService } from ".."; -import { assertSerializableKey, KeySerializer, KeyStore } from "."; +import { + assertSerializedKey, + Key, + KeySerializer, + LLMService, + SerializedKey, +} from "../index"; +import { KeyStore, MemoryKeyStore } from "./index"; export class FirebaseKeyStore implements KeyStore { private readonly db: firebase.database.Database; private readonly log: typeof logger; - private readonly pendingUpdates: Map> = new Map(); + private readonly pendingUpdates: Map>; private readonly root: string; private readonly serializer: KeySerializer; + private readonly service: LLMService; private flushInterval: NodeJS.Timeout | null = null; private keysRef: firebase.database.Reference | null = null; @@ -22,6 +29,8 @@ export class FirebaseKeyStore implements KeyStore { this.log = logger.child({ module: "firebase-key-store", service }); this.root = `keys/${config.firebaseRtdbRoot}/${service}`; this.serializer = serializer; + this.service = service; + this.pendingUpdates = new Map(); this.schedulePeriodicFlush(); } @@ -36,7 +45,7 @@ export class FirebaseKeyStore implements KeyStore { } const values = Object.values(keys).map((k) => { - assertSerializableKey(k); + assertSerializedKey(k); return this.serializer.deserialize(k); }); @@ -50,7 +59,7 @@ export class FirebaseKeyStore implements KeyStore { public update(id: string, update: Partial, force = false) { const existing = this.pendingUpdates.get(id) ?? {}; - Object.assign(existing, update); + Object.assign(existing, this.serializer.partialSerialize(id, update)); this.pendingUpdates.set(id, existing); if (force) setTimeout(() => this.flush(), 0); } @@ -68,11 +77,33 @@ export class FirebaseKeyStore implements KeyStore { ); return; } + + const updates: Record> = {}; + this.pendingUpdates.forEach((v, k) => (updates[k] = v)); + this.pendingUpdates.clear(); + + await this.keysRef.update(updates); + + this.log.info( + { count: Object.keys(updates).length }, + "Flushed pending key updates." + ); this.schedulePeriodicFlush(); } private async migrate() { - // TODO: If firebase is empty, try instantiating a MemoryKeyStore and - // loading keys from the environment. + const envStore = new MemoryKeyStore(this.service, this.serializer); + const keys = await envStore.load(); + + if (keys.length === 0) { + this.log.warn("No keys found in environment or Firebase."); + return; + } + + const updates: Record = {}; + keys.forEach((k) => (updates[k.hash] = this.serializer.serialize(k))); + await this.db.ref(this.root).update(updates); + + this.log.info({ count: keys.length }, "Migrated keys from environment."); } } diff --git a/src/shared/key-management/stores/index.ts b/src/shared/key-management/stores/index.ts index 8ba437f..e1c047a 100644 --- a/src/shared/key-management/stores/index.ts +++ b/src/shared/key-management/stores/index.ts @@ -1,29 +1,10 @@ import { Key } from ".."; export interface KeyStore { - load(): Promise; + load(): Promise; add(key: K): void; update(id: string, update: Partial, force?: boolean): void; } -export interface KeySerializer { - serialize(key: K): SerializedKey; - deserialize(key: SerializedKey): K; -} - -export type SerializedKey = { key: string }; - -export function assertSerializableKey( - data: unknown -): asserts data is SerializedKey { - if ( - typeof data !== "object" || - !data || - typeof (data as any).key !== "string" - ) { - throw new Error("Invalid serialized key data"); - } -} - export { FirebaseKeyStore } from "./firebase"; export { MemoryKeyStore } from "./memory"; diff --git a/src/shared/key-management/stores/memory.ts b/src/shared/key-management/stores/memory.ts index 971defd..2a257ed 100644 --- a/src/shared/key-management/stores/memory.ts +++ b/src/shared/key-management/stores/memory.ts @@ -1,6 +1,6 @@ import { assertNever } from "../../utils"; -import { LLMService, Key } from "../index"; -import { KeySerializer, KeyStore } from "."; +import { LLMService, Key, KeySerializer } from "../index"; +import { KeyStore } from "."; export class MemoryKeyStore implements KeyStore { private readonly env: string;