implements MemoryKeyStore; inject store when instantiating providers

This commit is contained in:
nai-degen
2023-09-10 18:38:13 -05:00
parent 4114dba4f5
commit 5d3433268f
9 changed files with 174 additions and 93 deletions
+9 -9
View File
@@ -5,16 +5,16 @@ import cors from "cors";
import path from "path";
import pinoHttp from "pino-http";
import childProcess from "child_process";
import { logger } from "./logger";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization";
import { logger } from "./logger";
import { adminRouter } from "./admin/routes";
import { checkOrigin } from "./proxy/check-origin";
import { start as startRequestQueue } from "./proxy/queue";
import { proxyRouter } from "./proxy/routes";
import { init as initKeyPool } from "./shared/key-management";
import { logQueue } from "./shared/prompt-logging";
import { init as initTokenizers } from "./shared/tokenization";
import { init as initUserStore } from "./shared/users/user-store";
import { userRouter } from "./user/routes";
const PORT = config.port;
@@ -93,7 +93,7 @@ async function start() {
await assertConfigIsValid();
logger.info("Starting key pool...");
await keyPool.init();
await initKeyPool();
await initTokenizers();
+23 -17
View File
@@ -1,5 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { Key, KeyProvider, KeyStore } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AnthropicModelFamily } from "../../models";
@@ -70,29 +70,35 @@ const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
readonly service = "anthropic";
readonly service = "anthropic" as const;
private keys: AnthropicKey[] = [];
private readonly keys: AnthropicKey[] = [];
private store: KeyStore<SerializableAnthropicKey>;
private checker?: AnthropicKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.anthropicKey?.trim();
if (!keyConfig) {
this.log.warn(
"ANTHROPIC_KEY is not set. Anthropic API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
this.keys.push(AnthropicKeyProvider.deserialize({ key }));
}
this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
constructor(store: KeyStore<SerializableAnthropicKey>) {
this.store = store;
}
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
if (serializedKeys.length === 0) {
this.log.warn(
{ via: storeName },
"No Anthropic keys found. Anthropic API will not be available."
);
return;
}
this.keys.push(...serializedKeys.map(AnthropicKeyProvider.deserialize));
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded Anthropic keys."
);
if (config.checkKeys) {
this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
+7 -2
View File
@@ -53,7 +53,7 @@ for service-agnostic functionality.
export interface KeyProvider<T extends Key = Key> {
readonly service: LLMService;
init(store: KeyStore<T>): Promise<void>;
init(): Promise<void>;
get(model: Model): T;
list(): Omit<T, "key">[];
disable(key: T): void;
@@ -71,7 +71,12 @@ export interface KeyStore<T extends Pick<Key, "key">> {
update(key: T): void;
}
export const keyPool = new KeyPool();
export let keyPool: KeyPool;
export async function init() {
keyPool = new KeyPool();
await keyPool.init();
}
export const SUPPORTED_MODELS = [
...OPENAI_SUPPORTED_MODELS,
...ANTHROPIC_SUPPORTED_MODELS,
+23 -10
View File
@@ -5,11 +5,12 @@ import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import { GooglePalmKeyProvider } from "./palm/provider";
import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { MemoryKeyStore } from "./stores/memory";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -20,21 +21,22 @@ export class KeyPool {
};
constructor() {
this.keyProviders.push(new OpenAIKeyProvider());
this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new OpenAIKeyProvider(createKeyStore("openai")));
this.keyProviders.push(
new AnthropicKeyProvider(createKeyStore("anthropic"))
);
this.keyProviders.push(
new GooglePalmKeyProvider(createKeyStore("google-palm"))
);
// this.keyProviders.push(new AwsBedrockKeyProvider());
}
public async init() {
const KeyStore = MemoryKeyStore; // TODO: select based on config
await Promise.all(this.keyProviders.map((p) => p.init(new KeyStore())));
await Promise.all(this.keyProviders.map((p) => p.init()));
const availableKeys = this.available("all");
if (availableKeys === 0) {
throw new Error(
"No keys loaded. Ensure that at least one key is configured."
);
throw new Error("No keys loaded, the application cannot start.");
}
this.scheduleRecheck();
}
@@ -154,3 +156,14 @@ export class KeyPool {
this.recheckJobs.openai = job;
}
}
function createKeyStore(service: LLMService) {
switch (config.persistenceProvider) {
case "memory":
return new MemoryKeyStore(service);
case "firebase_rtdb":
return new FirebaseKeyStore(service);
default:
throw new Error(`Unknown store type: ${config.persistenceProvider}`);
}
}
+30 -17
View File
@@ -3,7 +3,7 @@ round-robin access to keys. Keys are stored in the OPENAI_KEY environment
variable as a comma-separated list of keys. */
import crypto from "crypto";
import http from "http";
import { Key, KeyProvider, Model } from "../index";
import { Key, KeyProvider, Model, KeyStore } from "../index";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { OpenAIKeyChecker } from "./checker";
@@ -66,7 +66,7 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
rateLimitTokensReset: number;
}
const SERIALIZABLE_FIELDS = [
const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
"key",
"service",
"hash",
@@ -74,7 +74,7 @@ const SERIALIZABLE_FIELDS = [
"gpt4Tokens",
"gpt4-32kTokens",
"turboTokens",
] as const;
];
type SerializableOpenAIKey = Partial<
Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>
> &
@@ -96,25 +96,38 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
readonly service = "openai" as const;
private readonly keys: OpenAIKey[] = [];
private store: KeyStore<SerializableOpenAIKey>;
private checker?: OpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyString = config.openaiKey?.trim();
if (!keyString) {
this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
return;
}
let bareKeys: string[];
bareKeys = keyString.split(",").map((k) => k.trim());
bareKeys = [...new Set(bareKeys)];
for (const k of bareKeys) {
this.keys.push(OpenAIKeyProvider.deserialize({ key: k }));
}
this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
constructor(store: KeyStore<SerializableOpenAIKey>) {
this.store = store;
}
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
// TODO: If keystore is unavailable or returns no keys, instantiate a
// MemoryKeyStore and use the keys from process.env. Migrate them to the
// keystore when it becomes available.
// TODO: after key management UI, keychecker should always be enabled
// because keys may be added after initialization.
if (serializedKeys.length === 0) {
this.log.warn(
{ via: storeName },
"No OpenAI keys found. OpenAI API will not be available."
);
return;
}
this.keys.push(...serializedKeys.map(OpenAIKeyProvider.deserialize));
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded OpenAI keys."
);
if (config.checkKeys) {
const cloneFn = this.clone.bind(this);
const updateFn = this.update.bind(this);
@@ -372,7 +385,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
static deserialize({ key, ...rest }: SerializableOpenAIKey): OpenAIKey {
return {
key,
service: "openai" as const,
service: "openai",
modelFamilies: ["turbo" as const, "gpt4" as const],
isTrial: false,
isDisabled: false,
+53 -32
View File
@@ -1,6 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { Key, KeyProvider, KeyStore } from "..";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
@@ -34,6 +33,17 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
rateLimitedUntil: number;
}
const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
"key",
"service",
"hash",
"bisonTokens",
];
type SerializableGooglePalmKey = Partial<
Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>
> &
Pick<GooglePalmKey, "key">;
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
@@ -50,43 +60,31 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
readonly service = "google-palm";
private keys: GooglePalmKey[] = [];
private store: KeyStore<SerializableGooglePalmKey>;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.googlePalmKey?.trim();
if (!keyConfig) {
constructor(store: KeyStore<SerializableGooglePalmKey>) {
this.store = store;
}
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
if (serializedKeys.length === 0) {
this.log.warn(
"GOOGLE_PALM_KEY is not set. PaLM API will not be available."
{ via: storeName },
"No PaLM keys found. PaLM API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GooglePalmKey = {
key,
service: this.service,
modelFamilies: ["bison"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
}
public init() {}
this.keys.push(...serializedKeys.map(GooglePalmKeyProvider.deserialize));
this.log.info(
{ keyCount: this.keys.length, via: storeName },
"Loaded PaLM keys."
);
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
@@ -186,4 +184,27 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
}
public recheck() {}
static deserialize(serializedKey: SerializableGooglePalmKey): GooglePalmKey {
const { key, ...rest } = serializedKey;
return {
key,
service: "google-palm",
modelFamilies: ["bison"],
isTrial: false,
isDisabled: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
...rest,
};
}
}
+2 -2
View File
@@ -1,5 +1,5 @@
import type firebase from "firebase-admin";
import { Key, KeyStore } from "..";
import { AIService, Key, KeyStore } from "..";
import { getFirebaseApp } from "../../../config";
export class FirebaseKeyStore<K extends Pick<Key, "key">>
@@ -7,7 +7,7 @@ export class FirebaseKeyStore<K extends Pick<Key, "key">>
{
private db: firebase.database.Database;
constructor(app = getFirebaseApp()) {
constructor(service: AIService, app = getFirebaseApp()) {
this.db = app.database();
}
@@ -0,0 +1,2 @@
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
+25 -4
View File
@@ -1,11 +1,32 @@
import { Key, KeyStore } from "..";
import { APIFormat, Key, KeyStore } from "..";
export class MemoryKeyStore<K extends Pick<Key, "key">> implements KeyStore<K> {
constructor() {}
private env: string;
constructor(service: APIFormat) {
switch (service) {
case "anthropic":
this.env = "ANTHROPIC_KEY";
break;
case "openai":
case "openai-text":
this.env = "OPENAI_KEY";
break;
case "google-palm":
this.env = "GOOGLE_PALM_KEY";
break;
default:
const never: never = service;
throw new Error(`Unknown service: ${never}`);
}
}
public async load() {
// TODO: load from process.env
return [];
let bareKeys: string[];
bareKeys = [
...new Set(process.env[this.env]?.split(",").map((k) => k.trim())),
];
return bareKeys.map((key) => ({ key } as K));
}
public add(_key: K) {}