consolidates some duplicated keyprovider stuff

This commit is contained in:
nai-degen
2023-10-09 00:03:46 -05:00
parent df2e986366
commit 00402c8310
11 changed files with 104 additions and 191 deletions
@@ -1,7 +1,8 @@
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AnthropicModelFamily } from "../../models";
import { Key, KeyProvider, KeyStore } from "../types";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AnthropicKeyChecker } from "./checker";
const RATE_LIMIT_LOCKOUT = 2000;
@@ -43,17 +44,12 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
isPozzed: boolean;
}
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
export class AnthropicKeyProvider extends KeyProviderBase<AnthropicKey> {
readonly service = "anthropic" as const;
private readonly keys: AnthropicKey[] = [];
private store: KeyStore<AnthropicKey>;
protected readonly keys: AnthropicKey[] = [];
private checker?: AnthropicKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<AnthropicKey>) {
this.store = store;
}
protected log = logger.child({ module: "key-provider", service: this.service });
public async init() {
const storeName = this.store.constructor.name;
@@ -75,10 +71,6 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AnthropicModel) {
// Currently, all Anthropic keys have access to all models. This will almost
// certainly change when they move out of beta later this year.
@@ -123,22 +115,6 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return { ...selectedKey };
}
public disable(key: AnthropicKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AnthropicKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
+5 -29
View File
@@ -1,7 +1,8 @@
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models";
import { Key, KeyProvider, KeyStore } from "../types";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AwsKeyChecker } from "./checker";
const RATE_LIMIT_LOCKOUT = 2000;
@@ -35,17 +36,12 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
awsLoggingStatus: "unknown" | "disabled" | "enabled";
}
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
export class AwsBedrockKeyProvider extends KeyProviderBase<AwsBedrockKey> {
readonly service = "aws" as const;
private readonly keys: AwsBedrockKey[] = [];
private store: KeyStore<AwsBedrockKey>;
protected readonly keys: AwsBedrockKey[] = [];
private checker?: AwsKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<AwsBedrockKey>) {
this.store = store;
}
protected log = logger.child({ module: "key-provider", service: this.service });
public async init() {
const storeName = this.store.constructor.name;
@@ -67,10 +63,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AwsBedrockModel) {
const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus === "disabled";
@@ -112,22 +104,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return { ...selectedKey };
}
public disable(key: AwsBedrockKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AwsBedrockKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
@@ -1,7 +1,6 @@
import { AxiosError } from "axios";
import pino from "pino";
import { logger } from "../../logger";
import { Key } from "./types";
type KeyCheckerOptions = {
+4 -11
View File
@@ -4,24 +4,17 @@ import os from "os";
import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { KeyProviderBase } from "./key-provider-base";
import { getSerializer } from "./serializers";
import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
import { AnthropicKeyProvider } from "./anthropic/provider";
import { OpenAIKeyProvider } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import {
Key,
KeyProvider,
KeyStore,
LLMService,
Model,
ServiceToKey,
} from "./types";
import { Key, KeyStore, LLMService, Model, ServiceToKey } from "./types";
export class KeyPool {
private keyProviders: KeyProvider[] = [];
private keyProviders: KeyProviderBase[] = [];
private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
openai: null,
};
@@ -131,7 +124,7 @@ export class KeyPool {
throw new Error(`Unknown service for model '${model}'`);
}
private getKeyProvider(service: LLMService): KeyProvider {
private getKeyProvider(service: LLMService): KeyProviderBase {
return this.keyProviders.find((provider) => provider.service === service)!;
}
+53 -13
View File
@@ -1,25 +1,65 @@
import { Key, KeyProvider, LLMService, Model } from "./types";
import { logger } from "../../logger";
import { Key, KeyStore, LLMService, Model } from "./types";
export abstract class KeyProvierBase<K extends Key> implements KeyProvider<K> {
abstract readonly service: LLMService;
export abstract class KeyProviderBase<K extends Key = Key> {
public abstract readonly service: LLMService;
abstract init(): Promise<void>;
protected abstract readonly keys: K[];
protected abstract log: typeof logger;
protected readonly store: KeyStore<K>;
abstract get(model: Model): K;
public constructor(keyStore: KeyStore<K>) {
this.store = keyStore;
}
abstract list(): Omit<K, "key">[];
public abstract init(): Promise<void>;
abstract disable(key: K): void;
public addKey(key: K): void {
this.keys.push(key);
this.store.add(key);
}
abstract update(hash: string, update: Partial<K>): void;
public abstract get(model: Model): K;
abstract available(): number;
/**
* Returns a list of all keys, with the actual key value removed. Don't
* mutate the returned objects; use `update` instead to ensure the changes
* are synced to the key store.
*/
public list(): Omit<K, "key">[] {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
abstract incrementUsage(hash: string, model: string, tokens: number): void;
public disable(key: K): void {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
this.update(key.hash, { isDisabled: true } as Partial<K>, true);
this.log.warn({ key: key.hash }, "Key disabled");
}
abstract getLockoutPeriod(model: Model): number;
public update(hash: string, update: Partial<K>, force = false): void {
const key = this.keys.find((k) => k.hash === hash);
if (!key) {
throw new Error(`No key with hash ${hash}`);
}
abstract markRateLimited(hash: string): void;
Object.assign(key, { lastChecked: Date.now(), ...update });
this.store.update(hash, update, force);
}
abstract recheck(): void;
public available(): number {
return this.keys.filter((k) => !k.isDisabled).length;
}
public abstract incrementUsage(
hash: string,
model: string,
tokens: number
): void;
public abstract getLockoutPeriod(model: Model): number;
public abstract markRateLimited(hash: string): void;
public abstract recheck(): void;
}
@@ -1,5 +1,4 @@
import { KeySerializer, SerializedKey } from "./index";
import { Key } from "./types";
import { Key, KeySerializer, SerializedKey } from "./types";
export abstract class KeySerializerBase<K extends Key>
implements KeySerializer<K>
+6 -36
View File
@@ -3,8 +3,9 @@ import { IncomingHttpHeaders } from "http";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { Key, KeyProvider, KeyStore, Model } from "../types";
import { Key, Model } from "../types";
import { OpenAIKeyChecker } from "./checker";
import { KeyProviderBase } from "../key-provider-base";
const KEY_REUSE_DELAY = 1000;
@@ -60,17 +61,12 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
rateLimitTokensReset: number;
}
export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
export class OpenAIKeyProvider extends KeyProviderBase<OpenAIKey> {
readonly service = "openai" as const;
private readonly keys: OpenAIKey[] = [];
private store: KeyStore<OpenAIKey>;
protected readonly keys: OpenAIKey[] = [];
private checker?: OpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<OpenAIKey>) {
this.store = store;
}
protected log = logger.child({ module: "key-provider", service: this.service });
public async init() {
const storeName = this.store.constructor.name;
@@ -97,14 +93,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
}
}
/**
* Returns a list of all keys, with the key field removed.
* Don't mutate returned keys, use a KeyPool method instead.
**/
public list() {
return this.keys.map((key) => Object.freeze({ ...key, key: undefined }));
}
public get(model: Model) {
const neededFamily = getOpenAIModelFamily(model);
const excludeTrials = model === "text-embedding-ada-002";
@@ -193,12 +181,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
return { ...selectedKey };
}
/** Called by the key checker to update key information. */
public update(keyHash: string, update: Partial<OpenAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
/** Called by the key checker to create clones of keys for the given orgs. */
public clone(keyHash: string, newOrgIds: string[]) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
@@ -220,19 +202,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
);
return clone;
});
this.keys.push(...clones);
}
/** Disables a key, or does nothing if the key isn't in this pool. */
public disable(key: Key) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
this.update(key.hash, { isDisabled: true });
this.log.warn({ key: key.hash }, "Key disabled");
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
clones.forEach((clone) => this.addKey(clone));
}
/**
+5 -29
View File
@@ -1,6 +1,7 @@
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import { Key, KeyProvider, KeyStore } from "../types";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
@@ -22,16 +23,11 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
rateLimitedUntil: number;
}
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
export class GooglePalmKeyProvider extends KeyProviderBase<GooglePalmKey> {
readonly service = "google-palm";
private keys: GooglePalmKey[] = [];
private store: KeyStore<GooglePalmKey>;
private log = logger.child({ module: "key-provider", service: this.service });
constructor(store: KeyStore<GooglePalmKey>) {
this.store = store;
}
protected keys: GooglePalmKey[] = [];
protected log = logger.child({ module: "key-provider", service: this.service });
public async init() {
const storeName = this.store.constructor.name;
@@ -48,10 +44,6 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
);
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: GooglePalmModel) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
@@ -90,22 +82,6 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return { ...selectedKey };
}
public disable(key: GooglePalmKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GooglePalmKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
+4 -4
View File
@@ -1,8 +1,4 @@
import { assertNever } from "../utils";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";
import {
Key,
KeySerializer,
@@ -10,6 +6,10 @@ import {
SerializedKey,
ServiceToKey,
} from "./types";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";
export function assertSerializedKey(k: any): asserts k is SerializedKey {
if (typeof k !== "object" || !k || typeof (k as any).key !== "string") {
+14 -7
View File
@@ -32,7 +32,7 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
this.serializer = serializer;
this.service = service;
this.pendingUpdates = new Map();
this.schedulePeriodicFlush();
this.scheduleFlush();
}
public async load(isMigrating = false): Promise<K[]> {
@@ -55,17 +55,24 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
}
public add(key: K) {
throw new Error("Method not implemented.");
const serialized = this.serializer.serialize(key);
this.pendingUpdates.set(key.hash, serialized);
this.forceFlush();
}
public update(id: string, update: Partial<K>, force = false) {
const existing = this.pendingUpdates.get(id) ?? {};
Object.assign(existing, this.serializer.partialSerialize(id, update));
this.pendingUpdates.set(id, existing);
if (force) setTimeout(() => this.flush(), 0);
if (force) this.forceFlush();
}
private schedulePeriodicFlush() {
private forceFlush() {
if (this.flushInterval) clearInterval(this.flushInterval);
this.flushInterval = setTimeout(() => this.flush(), 0);
}
private scheduleFlush() {
if (this.flushInterval) clearInterval(this.flushInterval);
this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5);
}
@@ -76,7 +83,7 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
{ pendingUpdates: this.pendingUpdates.size },
"Database not loaded yet. Skipping flush."
);
return;
return this.scheduleFlush();
}
const updates: Record<string, Partial<SerializedKey>> = {};
@@ -85,11 +92,11 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
await this.keysRef.update(updates);
this.log.info(
this.log.debug(
{ count: Object.keys(updates).length },
"Flushed pending key updates."
);
this.schedulePeriodicFlush();
this.scheduleFlush();
}
private async migrate(): Promise<SerializedKey[]> {
+7 -30
View File
@@ -32,9 +32,14 @@ export interface Key {
service: LLMService;
/** The model families that this key has access to. */
modelFamilies: ModelFamily[];
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
/** Whether this key is currently disabled for some reason. */
isDisabled: boolean;
/** Whether this key specifically has been revoked. */
/**
* Whether this key specifically has been revoked. This is different from
* `isDisabled` because a key can be disabled for other reasons, such as
* exceeding its quota. A revoked key is assumed to be permanently disabled,
* and KeyStore implementations should not return it when loading keys.
*/
isRevoked: boolean;
/** The number of prompts that have been sent with this key. */
promptCount: number;
@@ -46,42 +51,14 @@ export interface Key {
hash: string;
}
export interface KeyProvider<T extends Key = Key> {
readonly service: LLMService;
init(): Promise<void>;
get(model: Model): T;
list(): Omit<T, "key">[];
disable(key: T): void;
update(hash: string, update: Partial<T>): void;
available(): number;
incrementUsage(hash: string, model: string, tokens: number): void;
getLockoutPeriod(model: Model): number;
markRateLimited(hash: string): void;
recheck(): void;
}
export interface KeySerializer<K> {
serialize(keyObj: K): SerializedKey;
deserialize(serializedKey: SerializedKey): K;
partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey>;
}
export interface KeyStore<K extends Key> {
load(): Promise<K[]>;
add(key: K): void;
update(id: string, update: Partial<K>, force?: boolean): void;
}