implements generic key serialization/deserialization

This commit is contained in:
nai-degen
2023-09-20 20:47:51 -05:00
parent f53e328398
commit 05ab8c37eb
13 changed files with 263 additions and 210 deletions
+2 -2
View File
@@ -16,7 +16,7 @@
*/
import type { Handler, Request } from "express";
import { keyPool, SupportedModel } from "../shared/key-management";
import { keyPool } from "../shared/key-management";
import {
getClaudeModelFamily,
getGooglePalmModelFamily,
@@ -138,7 +138,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
const model = req.body.model ?? "gpt-3.5-turbo";
// Weird special case for AWS because they serve multiple models from
// different vendors, even if currently only one is supported.
+19 -44
View File
@@ -1,12 +1,23 @@
import crypto from "crypto";
import { BaseSerializableKey, Key, KeyProvider } from "..";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AnthropicModelFamily } from "../../models";
import { KeyStore } from "../stores";
import { KeyStore, SerializedKey } from "../stores";
import { AnthropicKeyChecker } from "./checker";
import { AnthropicKeySerializer } from "./serializer";
// https://docs.anthropic.com/claude/reference/selecting-a-model
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
/* https://docs.anthropic.com/claude/reference/selecting-a-model */
export const ANTHROPIC_SUPPORTED_MODELS = [
"claude-instant-v1",
"claude-instant-v1-100k",
@@ -21,7 +32,7 @@ type AnthropicKeyUsage = {
};
const SERIALIZABLE_FIELDS = ["key", "service", "hash", "claudeTokens"] as const;
type SerializableAnthropicKey = BaseSerializableKey &
export type SerializedAnthropicKey = SerializedKey &
Partial<Pick<AnthropicKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export type AnthropicKeyUpdate = Omit<
@@ -56,27 +67,15 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
isPozzed: boolean;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
readonly service = "anthropic" as const;
private readonly keys: AnthropicKey[] = [];
private store: KeyStore<SerializableAnthropicKey>;
private store: KeyStore<AnthropicKey>;
private checker?: AnthropicKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<SerializableAnthropicKey>) {
constructor(store: KeyStore<AnthropicKey>) {
this.store = store;
}
@@ -92,7 +91,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return;
}
this.keys.push(...serializedKeys.map(AnthropicKeyProvider.deserialize));
this.keys.push(...serializedKeys.map(AnthropicKeySerializer.deserialize));
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded Anthropic keys."
@@ -217,28 +216,4 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
});
this.checker?.scheduleNextCheck();
}
static deserialize({ key, ...rest }: SerializableAnthropicKey): AnthropicKey {
return {
key,
service: "anthropic" as const,
modelFamilies: ["claude" as const],
isTrial: false,
isDisabled: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
requiresPreamble: false,
hash: `ant-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
claudeTokens: 0,
...rest,
};
}
}
@@ -0,0 +1,33 @@
import crypto from "crypto";
import { AnthropicKey } from "..";
import { KeySerializer } from "../stores";
import { SerializedAnthropicKey } from "./provider";
export const AnthropicKeySerializer: KeySerializer<AnthropicKey> = {
serialize(key: AnthropicKey): SerializedAnthropicKey {
return { key: key.key }; // TODO: serialize other fields
},
deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey {
return {
key,
service: "anthropic" as const,
modelFamilies: ["claude" as const],
isTrial: false,
isDisabled: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
requiresPreamble: false,
hash: `ant-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
claudeTokens: 0,
...rest,
};
},
};
-9
View File
@@ -39,10 +39,6 @@ export interface Key {
hash: string;
}
export interface BaseSerializableKey {
key: string;
}
/*
KeyPool and KeyProvider's similarities are a relic of the old design where
there was only a single KeyPool for OpenAI keys. Now that there are multiple
@@ -75,11 +71,6 @@ export async function init() {
await keyPool.init();
}
export const SUPPORTED_MODELS = [
...OPENAI_SUPPORTED_MODELS,
...ANTHROPIC_SUPPORTED_MODELS,
] as const;
export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export {
OPENAI_SUPPORTED_MODELS,
ANTHROPIC_SUPPORTED_MODELS,
+9 -9
View File
@@ -5,8 +5,8 @@ import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import { GooglePalmKeyProvider } from "./palm/provider";
import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
import { getSerializer } from "./serializers";
import { FirebaseKeyStore, KeyStore, MemoryKeyStore } from "./stores";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
@@ -21,11 +21,9 @@ export class KeyPool {
};
constructor() {
this.keyProviders.push(new OpenAIKeyProvider(createKeyStore("openai")));
this.keyProviders.push(
new AnthropicKeyProvider(createKeyStore("anthropic"))
);
this.keyProviders.push(
new OpenAIKeyProvider(createKeyStore("openai")),
new AnthropicKeyProvider(createKeyStore("anthropic")),
new GooglePalmKeyProvider(createKeyStore("google-palm"))
);
// this.keyProviders.push(new AwsBedrockKeyProvider());
@@ -157,12 +155,14 @@ export class KeyPool {
}
}
function createKeyStore(service: LLMService) {
function createKeyStore(service: LLMService): KeyStore<Key> {
const serializer = getSerializer(service);
switch (config.persistenceProvider) {
case "memory":
return new MemoryKeyStore(service);
return new MemoryKeyStore(service, serializer);
case "firebase_rtdb":
return new FirebaseKeyStore(service);
return new FirebaseKeyStore(service, serializer);
default:
throw new Error(`Unknown store type: ${config.persistenceProvider}`);
}
+8 -36
View File
@@ -6,9 +6,10 @@ import http from "http";
import { Key, KeyProvider, Model } from "../index";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { OpenAIKeyChecker } from "./checker";
import { KeyStore } from "../stores";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyStore, SerializedKey } from "../stores";
import { OpenAIKeyChecker } from "./checker";
import { OpenAIKeySerializer } from "./serializer";
export type OpenAIModel =
| "gpt-3.5-turbo"
@@ -76,10 +77,8 @@ const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
"gpt4-32kTokens",
"turboTokens",
];
type SerializableOpenAIKey = Partial<
Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>
> &
Pick<OpenAIKey, "key">;
export type SerializedOpenAIKey = SerializedKey &
Partial<Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export type OpenAIKeyUpdate = Omit<
Partial<OpenAIKey>,
@@ -97,11 +96,11 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
readonly service = "openai" as const;
private readonly keys: OpenAIKey[] = [];
private store: KeyStore<SerializableOpenAIKey>;
private store: KeyStore<OpenAIKey>;
private checker?: OpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<SerializableOpenAIKey>) {
constructor(store: KeyStore<OpenAIKey>) {
this.store = store;
}
@@ -123,7 +122,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
return;
}
this.keys.push(...serializedKeys.map(OpenAIKeyProvider.deserialize));
this.keys.push(...serializedKeys.map(OpenAIKeySerializer.deserialize));
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded OpenAI keys."
@@ -382,33 +381,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
});
this.checker?.scheduleNextCheck();
}
static deserialize({ key, ...rest }: SerializableOpenAIKey): OpenAIKey {
return {
key,
service: "openai",
modelFamilies: ["turbo" as const, "gpt4" as const],
isTrial: false,
isDisabled: false,
isRevoked: false,
isOverQuota: false,
lastUsed: 0,
lastChecked: 0,
promptCount: 0,
hash: `oai-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
...rest,
};
}
}
/**
@@ -0,0 +1,36 @@
import crypto from "crypto";
import { OpenAIKey } from "..";
import { KeySerializer } from "../stores";
import { SerializedOpenAIKey } from "./provider";
export const OpenAIKeySerializer: KeySerializer<OpenAIKey> = {
serialize(key: OpenAIKey): SerializedOpenAIKey {
return { key: key.key };
},
deserialize({ key, ...rest }: SerializedOpenAIKey): OpenAIKey {
return {
key,
service: "openai",
modelFamilies: ["turbo" as const, "gpt4" as const],
isTrial: false,
isDisabled: false,
isRevoked: false,
isOverQuota: false,
lastUsed: 0,
lastChecked: 0,
promptCount: 0,
hash: `oai-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
...rest,
};
},
};
+8 -32
View File
@@ -1,7 +1,8 @@
import crypto from "crypto";
import { Key, KeyProvider, KeyStore } from "..";
import { Key, KeyProvider } from "..";
import { KeyStore, SerializedKey } from "../stores";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import { GooglePalmKeySerializer } from "./serializer";
// https://developers.generativeai.google.com/models/language
export const GOOGLE_PALM_SUPPORTED_MODELS = [
@@ -39,10 +40,8 @@ const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
"hash",
"bisonTokens",
];
type SerializableGooglePalmKey = Partial<
Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>
> &
Pick<GooglePalmKey, "key">;
export type SerializedGooglePalmKey = SerializedKey &
Partial<Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
@@ -60,10 +59,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
readonly service = "google-palm";
private keys: GooglePalmKey[] = [];
private store: KeyStore<SerializableGooglePalmKey>;
private store: KeyStore<GooglePalmKey>;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<SerializableGooglePalmKey>) {
constructor(store: KeyStore<GooglePalmKey>) {
this.store = store;
}
@@ -79,7 +78,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return;
}
this.keys.push(...serializedKeys.map(GooglePalmKeyProvider.deserialize));
this.keys.push(...serializedKeys.map(GooglePalmKeySerializer.deserialize));
this.log.info(
{ keyCount: this.keys.length, via: storeName },
"Loaded PaLM keys."
@@ -184,27 +183,4 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
}
public recheck() {}
static deserialize(serializedKey: SerializableGooglePalmKey): GooglePalmKey {
const { key, ...rest } = serializedKey;
return {
key,
service: "google-palm",
modelFamilies: ["bison"],
isTrial: false,
isDisabled: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
...rest,
};
}
}
@@ -0,0 +1,32 @@
import crypto from "crypto";
import { GooglePalmKey } from "..";
import { KeySerializer } from "../stores";
import { SerializedGooglePalmKey } from "./provider";
export const GooglePalmKeySerializer: KeySerializer<GooglePalmKey> = {
serialize(key: GooglePalmKey): SerializedGooglePalmKey {
return { key: key.key };
},
deserialize(serializedKey: SerializedGooglePalmKey): GooglePalmKey {
const { key, ...rest } = serializedKey;
return {
key,
service: "google-palm" as const,
modelFamilies: ["bison"],
isTrial: false,
isDisabled: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
...rest,
};
},
};
+20
View File
@@ -0,0 +1,20 @@
import { APIFormat, Key } from ".";
import { assertNever } from "../utils";
import { KeySerializer } from "./stores";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
export function getSerializer(service: APIFormat): KeySerializer<Key> {
switch (service) {
case "openai":
case "openai-text":
return OpenAIKeySerializer;
case "anthropic":
return AnthropicKeySerializer;
case "google-palm":
return GooglePalmKeySerializer;
default:
assertNever(service);
}
}
+56 -34
View File
@@ -1,54 +1,76 @@
import type firebase from "firebase-admin";
import firebase from "firebase-admin";
import { getFirebaseApp } from "../../../config";
import { logger } from "../../../logger";
import { KeyDeserializer, KeyStore, MemoryKeyStore, getDeserializer } from ".";
import { AIService, BaseSerializableKey } from "..";
import { APIFormat, Key } from "..";
import { KeyStore, assertSerializableKey } from ".";
import { KeySerializer } from ".";
export class FirebaseKeyStore<K extends BaseSerializableKey>
implements KeyStore<K>
{
private db: firebase.database.Database;
private service: AIService;
export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
private log: typeof logger;
private deserializer: KeyDeserializer;
private db: firebase.database.Database;
private keysRef: firebase.database.Reference | null = null;
private pendingUpdates: Map<string, Partial<K>> = new Map();
private flushInterval: NodeJS.Timeout | null = null;
constructor(service: AIService, app = getFirebaseApp()) {
this.db = app.database();
constructor(
private service: APIFormat,
private serializer: KeySerializer<K>,
app = getFirebaseApp()
) {
this.db = firebase.database(app);
this.service = service;
this.log = logger.child({ module: "key-store", service });
this.deserializer = getDeserializer(service);
this.log = logger.child({ module: "firebase-key-store", service });
this.schedulePeriodicFlush();
}
public async load() {
throw new Error("Method not implemented.");
return [];
const keysRef = this.db.ref(`keys/${this.service}`);
const snapshot = await keysRef.once("value");
const keys = snapshot.val();
if (!keys) {
this.log.warn("No keys found in Firebase. Migrating from environment.");
await this.migrate();
}
const values = Object.values(keys).map((k) => {
assertSerializableKey(k);
return this.serializer.deserialize(k);
});
this.keysRef = keysRef;
return values;
}
public add(key: K) {
throw new Error("Method not implemented.");
}
public update(key: K) {
throw new Error("Method not implemented.");
public update(id: string, update: Partial<K>, force = false) {
const existing = this.pendingUpdates.get(id) ?? {};
Object.assign(existing, update);
this.pendingUpdates.set(id, existing);
if (force) setTimeout(() => this.flush(), 0);
}
private schedulePeriodicFlush() {
if (this.flushInterval) clearInterval(this.flushInterval);
this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5);
}
private async flush() {
if (!this.keysRef) {
this.log.warn(
{ pendingUpdates: this.pendingUpdates.size },
"Database not loaded yet. Skipping flush."
);
return;
}
this.schedulePeriodicFlush();
}
private async migrate() {
this.log.info("Migrating keys from environment to Firebase.");
const envStore = new MemoryKeyStore(this.service);
const keysRef = this.db.ref(`keys/${this.service}`);
const updates: Record<string, K> = {};
const keys = await envStore.load();
keys.forEach((key) => {
updates[key.key] = this.deserializer(key);
});
// envStore.load().then((keys) => {
// keys.forEach((key) => {
// updates[key.key] = key;
// });
// keysRef.update(updates);
// });
// TODO: If firebase is empty, try instantiating a MemoryKeyStore and
// loading keys from the environment.
}
}
+26 -29
View File
@@ -1,32 +1,29 @@
import { AIService, Key } from "..";
import { AnthropicKeyProvider } from "../anthropic/provider";
import { OpenAIKeyProvider } from "../openai/provider";
import { Key } from "..";
export interface KeyStore<K extends Key> {
load(): Promise<SerializedKey[]>;
add(key: K): void;
update(id: string, update: Partial<K>, force?: boolean): void;
}
export interface KeySerializer<K> {
serialize(key: K): SerializedKey;
deserialize(key: SerializedKey): K;
}
export type SerializedKey = { key: string };
export function assertSerializableKey(
data: unknown
): asserts data is SerializedKey {
if (
typeof data !== "object" ||
!data ||
typeof (data as any).key !== "string"
) {
throw new Error("Invalid serialized key data");
}
}
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
export interface KeyStore<T extends Pick<Key, "key">> {
load(): Promise<T[]>;
add(key: T): void;
update(key: T): void;
}
interface BaseSerializableKey {
key: string;
}
export type KeyDeserializer =
| typeof AnthropicKeyProvider.deserialize
| typeof OpenAIKeyProvider.deserialize;
export function getDeserializer(service: AIService): KeyDeserializer {
switch (service) {
case "anthropic":
return AnthropicKeyProvider.deserialize;
case "openai":
return OpenAIKeyProvider.deserialize;
default:
const never: never = service;
throw new Error(`Unknown service: ${never}`);
}
}
+14 -15
View File
@@ -1,13 +1,12 @@
import { KeyDeserializer, KeyStore, getDeserializer } from ".";
import { APIFormat, BaseSerializableKey } from "..";
import { assertNever } from "../../utils";
import { APIFormat, Key } from "..";
import { KeySerializer } from ".";
import { KeyStore } from ".";
export class MemoryKeyStore<K extends BaseSerializableKey>
implements KeyStore<K>
{
export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
private env: string;
private deserializer: KeyDeserializer;
constructor(service: APIFormat) {
constructor(service: APIFormat, private serializer: KeySerializer<K>) {
switch (service) {
case "anthropic":
this.env = "ANTHROPIC_KEY";
@@ -20,21 +19,21 @@ export class MemoryKeyStore<K extends BaseSerializableKey>
this.env = "GOOGLE_PALM_KEY";
break;
default:
const never: never = service;
throw new Error(`Unknown service: ${never}`);
assertNever(service);
}
this.deserializer = getDeserializer(service);
}
public async load() {
let bareKeys: string[];
bareKeys = [
let envKeys: string[];
envKeys = [
...new Set(process.env[this.env]?.split(",").map((k) => k.trim())),
];
return bareKeys.map((key) => this.deserializer({ key }));
return envKeys
.filter((k) => k)
.map((k) => this.serializer.deserialize({ key: k }));
}
public add(_key: K) {}
public add() {}
public update(_key: K) {}
public update() {}
}