implements most of firebasekeystore

This commit is contained in:
nai-degen
2023-10-08 04:21:49 -05:00
parent 39436e7492
commit ea2bfb9eef
13 changed files with 146 additions and 107 deletions
@@ -59,16 +59,13 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
const loadedKeys = await this.store.load();
if (serializedKeys.length === 0) {
return this.log.warn(
{ via: storeName },
"No Anthropic keys found. Anthropic API will not be available."
);
if (loadedKeys.length === 0) {
return this.log.warn({ via: storeName }, "No Anthropic keys found.");
}
this.keys.push(...serializedKeys.map(AnthropicKeySerializer.deserialize));
this.keys.push(...loadedKeys);
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded Anthropic keys."
@@ -1,15 +1,21 @@
import crypto from "crypto";
import type { AnthropicKey } from "../index";
import type { KeySerializer, SerializedKey } from "../stores";
import type { AnthropicKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../serializers";
const SERIALIZABLE_FIELDS = ["key", "service", "hash", "claudeTokens"] as const;
const SERIALIZABLE_FIELDS: (keyof AnthropicKey)[] = [
"key",
"service",
"hash",
"claudeTokens",
];
export type SerializedAnthropicKey = SerializedKey &
Partial<Pick<AnthropicKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export const AnthropicKeySerializer: KeySerializer<AnthropicKey> = {
serialize(key: AnthropicKey): SerializedAnthropicKey {
return { key: key.key }; // TODO: serialize other fields
},
export class AnthropicKeySerializer extends KeySerializerBase<AnthropicKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey {
return {
key,
@@ -32,5 +38,5 @@ export const AnthropicKeySerializer: KeySerializer<AnthropicKey> = {
claudeTokens: 0,
...rest,
};
},
};
}
}
+4 -8
View File
@@ -4,7 +4,6 @@ import type { AwsBedrockModelFamily } from "../../models";
import { Key, KeyProvider } from "../index";
import { KeyStore } from "../stores";
import { AwsKeyChecker } from "./checker";
import { AwsBedrockKeySerializer } from "./serializer";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
@@ -51,16 +50,13 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
const loadedKeys = await this.store.load();
if (serializedKeys.length === 0) {
return this.log.warn(
{ via: storeName },
"No AWS credentials found. AWS Bedrock API will not be available."
);
if (loadedKeys.length === 0) {
return this.log.warn({ via: storeName }, "No AWS credentials found.");
}
this.keys.push(...serializedKeys.map(AwsBedrockKeySerializer.deserialize));
this.keys.push(...loadedKeys);
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded AWS Bedrock keys."
+9 -8
View File
@@ -1,6 +1,6 @@
import crypto from "crypto";
import type { AwsBedrockKey } from "../index";
import type { KeySerializer, SerializedKey } from "../stores";
import type { AwsBedrockKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../serializers";
const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
"key",
@@ -11,10 +11,11 @@ const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
export type SerializedAwsBedrockKey = SerializedKey &
Partial<Pick<AwsBedrockKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export const AwsBedrockKeySerializer: KeySerializer<AwsBedrockKey> = {
serialize(key: AwsBedrockKey): SerializedAwsBedrockKey {
return { key: key.key };
},
export class AwsBedrockKeySerializer extends KeySerializerBase<AwsBedrockKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey {
const { key, ...rest } = serializedKey;
return {
@@ -37,5 +38,5 @@ export const AwsBedrockKeySerializer: KeySerializer<AwsBedrockKey> = {
["aws-claudeTokens"]: 0,
...rest,
};
},
};
}
}
+4 -1
View File
@@ -80,4 +80,7 @@ export {
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AwsBedrockKey } from "./aws/provider";
export { assertSerializedKey } from "./serializers";
export { SerializedKey } from "./serializers";
export { KeySerializer } from "./serializers";
+5 -13
View File
@@ -6,7 +6,6 @@ import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { Key, KeyProvider, Model } from "../index";
import { KeyStore } from "../stores";
import { OpenAIKeyChecker } from "./checker";
import { OpenAIKeySerializer } from "./serializer";
const KEY_REUSE_DELAY = 1000;
@@ -17,7 +16,7 @@ export const OPENAI_SUPPORTED_MODELS = [
"gpt-4-32k",
"text-embedding-ada-002",
] as const;
export type OpenAIModel = typeof OPENAI_SUPPORTED_MODELS[number];
export type OpenAIModel = (typeof OPENAI_SUPPORTED_MODELS)[number];
type OpenAIKeyUsage = {
[K in OpenAIModelFamily as `${K}Tokens`]: number;
@@ -76,23 +75,16 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
const loadedKeys = await this.store.load();
// TODO: If keystore is unavailable or returns no keys, instantiate a
// MemoryKeyStore and use the keys from process.env. Migrate them to the
// keystore when it becomes available.
// TODO: after key management UI, keychecker should always be enabled
// because keys may be added after initialization.
if (serializedKeys.length === 0) {
this.log.warn(
{ via: storeName },
"No OpenAI keys found. OpenAI API will not be available."
);
return;
if (loadedKeys.length === 0) {
return this.log.warn({ via: storeName }, "No OpenAI keys found.");
}
this.keys.push(...serializedKeys.map(OpenAIKeySerializer.deserialize));
this.keys.push(...loadedKeys);
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded OpenAI keys."
@@ -1,6 +1,6 @@
import crypto from "crypto";
import type { OpenAIKey } from "../index";
import type { KeySerializer, SerializedKey } from "../stores";
import type { OpenAIKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../serializers";
const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
"key",
@@ -14,10 +14,11 @@ const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
export type SerializedOpenAIKey = SerializedKey &
Partial<Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export const OpenAIKeySerializer: KeySerializer<OpenAIKey> = {
serialize(key: OpenAIKey): SerializedOpenAIKey {
return { key: key.key };
},
export class OpenAIKeySerializer extends KeySerializerBase<OpenAIKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize({ key, ...rest }: SerializedOpenAIKey): OpenAIKey {
return {
key,
@@ -43,5 +44,5 @@ export const OpenAIKeySerializer: KeySerializer<OpenAIKey> = {
"gpt4-32kTokens": 0,
...rest,
};
},
};
}
}
+5 -9
View File
@@ -37,19 +37,15 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
const loadedKeys = await this.store.load();
if (serializedKeys.length === 0) {
this.log.warn(
{ via: storeName },
"No PaLM keys found. PaLM API will not be available."
);
return;
if (loadedKeys.length === 0) {
return this.log.warn({ via: storeName }, "No Google PaLM keys found.");
}
this.keys.push(...serializedKeys.map(GooglePalmKeySerializer.deserialize));
this.keys.push(...loadedKeys);
this.log.info(
{ keyCount: this.keys.length, via: storeName },
{ count: this.keys.length, via: storeName },
"Loaded PaLM keys."
);
}
+9 -9
View File
@@ -1,7 +1,6 @@
import crypto from "crypto";
import type { GooglePalmKey } from "../index";
import type { KeySerializer } from "../stores";
import { SerializedKey } from "../stores";
import type { GooglePalmKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../serializers";
const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
"key",
@@ -12,10 +11,11 @@ const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
export type SerializedGooglePalmKey = SerializedKey &
Partial<Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export const GooglePalmKeySerializer: KeySerializer<GooglePalmKey> = {
serialize(key: GooglePalmKey): SerializedGooglePalmKey {
return { key: key.key };
},
export class GooglePalmKeySerializer extends KeySerializerBase<GooglePalmKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize(serializedKey: SerializedGooglePalmKey): GooglePalmKey {
const { key, ...rest } = serializedKey;
return {
@@ -37,5 +37,5 @@ export const GooglePalmKeySerializer: KeySerializer<GooglePalmKey> = {
bisonTokens: 0,
...rest,
};
},
};
}
}
+41 -6
View File
@@ -1,21 +1,56 @@
import { LLMService, Key } from ".";
import { Key, LLMService } from ".";
import { assertNever } from "../utils";
import { KeySerializer } from "./stores";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";
export type SerializedKey = { key: string };
export interface KeySerializer<K> {
serialize(keyObj: K): SerializedKey;
deserialize(serializedKey: SerializedKey): K;
partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey>;
}
export abstract class KeySerializerBase<K extends Key>
implements KeySerializer<K>
{
protected constructor(protected serializableFields: (keyof K)[]) {}
serialize(keyObj: K): SerializedKey {
return {
...Object.fromEntries(this.serializableFields.map((f) => [f, keyObj[f]])),
key: keyObj.key,
};
}
partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey> {
return {
...Object.fromEntries(this.serializableFields.map((f) => [f, update[f]])),
key,
};
}
abstract deserialize(serializedKey: SerializedKey): K;
}
export function assertSerializedKey(k: any): asserts k is SerializedKey {
if (typeof k !== "object" || !k || typeof (k as any).key !== "string") {
throw new Error("Invalid serialized key data");
}
}
export function getSerializer(service: LLMService): KeySerializer<Key> {
switch (service) {
case "openai":
return OpenAIKeySerializer;
return new OpenAIKeySerializer();
case "anthropic":
return AnthropicKeySerializer;
return new AnthropicKeySerializer();
case "google-palm":
return GooglePalmKeySerializer;
return new GooglePalmKeySerializer();
case "aws":
return AwsBedrockKeySerializer;
return new AwsBedrockKeySerializer();
default:
assertNever(service);
}
+38 -7
View File
@@ -1,15 +1,22 @@
import firebase from "firebase-admin";
import { config, getFirebaseApp } from "../../../config";
import { logger } from "../../../logger";
import { Key, LLMService } from "..";
import { assertSerializableKey, KeySerializer, KeyStore } from ".";
import {
assertSerializedKey,
Key,
KeySerializer,
LLMService,
SerializedKey,
} from "../index";
import { KeyStore, MemoryKeyStore } from "./index";
export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
private readonly db: firebase.database.Database;
private readonly log: typeof logger;
private readonly pendingUpdates: Map<string, Partial<K>> = new Map();
private readonly pendingUpdates: Map<string, Partial<SerializedKey>>;
private readonly root: string;
private readonly serializer: KeySerializer<K>;
private readonly service: LLMService;
private flushInterval: NodeJS.Timeout | null = null;
private keysRef: firebase.database.Reference | null = null;
@@ -22,6 +29,8 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
this.log = logger.child({ module: "firebase-key-store", service });
this.root = `keys/${config.firebaseRtdbRoot}/${service}`;
this.serializer = serializer;
this.service = service;
this.pendingUpdates = new Map();
this.schedulePeriodicFlush();
}
@@ -36,7 +45,7 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
}
const values = Object.values(keys).map((k) => {
assertSerializableKey(k);
assertSerializedKey(k);
return this.serializer.deserialize(k);
});
@@ -50,7 +59,7 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
public update(id: string, update: Partial<K>, force = false) {
const existing = this.pendingUpdates.get(id) ?? {};
Object.assign(existing, update);
Object.assign(existing, this.serializer.partialSerialize(id, update));
this.pendingUpdates.set(id, existing);
if (force) setTimeout(() => this.flush(), 0);
}
@@ -68,11 +77,33 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
);
return;
}
const updates: Record<string, Partial<SerializedKey>> = {};
this.pendingUpdates.forEach((v, k) => (updates[k] = v));
this.pendingUpdates.clear();
await this.keysRef.update(updates);
this.log.info(
{ count: Object.keys(updates).length },
"Flushed pending key updates."
);
this.schedulePeriodicFlush();
}
private async migrate() {
// TODO: If firebase is empty, try instantiating a MemoryKeyStore and
// loading keys from the environment.
const envStore = new MemoryKeyStore<K>(this.service, this.serializer);
const keys = await envStore.load();
if (keys.length === 0) {
this.log.warn("No keys found in environment or Firebase.");
return;
}
const updates: Record<string, SerializedKey> = {};
keys.forEach((k) => (updates[k.hash] = this.serializer.serialize(k)));
await this.db.ref(this.root).update(updates);
this.log.info({ count: keys.length }, "Migrated keys from environment.");
}
}
+1 -20
View File
@@ -1,29 +1,10 @@
import { Key } from "..";
export interface KeyStore<K extends Key> {
load(): Promise<SerializedKey[]>;
load(): Promise<K[]>;
add(key: K): void;
update(id: string, update: Partial<K>, force?: boolean): void;
}
export interface KeySerializer<K> {
serialize(key: K): SerializedKey;
deserialize(key: SerializedKey): K;
}
export type SerializedKey = { key: string };
export function assertSerializableKey(
data: unknown
): asserts data is SerializedKey {
if (
typeof data !== "object" ||
!data ||
typeof (data as any).key !== "string"
) {
throw new Error("Invalid serialized key data");
}
}
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
+2 -2
View File
@@ -1,6 +1,6 @@
import { assertNever } from "../../utils";
import { LLMService, Key } from "../index";
import { KeySerializer, KeyStore } from ".";
import { LLMService, Key, KeySerializer } from "../index";
import { KeyStore } from ".";
export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
private readonly env: string;