16 Commits

Author SHA1 Message Date
nai-degen b8cc5e563e wip, broke something with serializer 2023-10-12 15:13:55 -05:00
nai-degen 00402c8310 consolidates some duplicated keyprovider stuff 2023-10-09 00:03:46 -05:00
nai-degen df2e986366 adds .editorconfig for line endings 2023-10-08 18:44:35 -05:00
nai-degen f9620991e7 reorganizes imports and types 2023-10-08 18:44:14 -05:00
nai-degen dd511fe60d made it out of generic hell 2023-10-08 11:08:47 -05:00
nai-degen ea2bfb9eef implements most of firebasekeystore 2023-10-08 04:21:49 -05:00
nai-degen 39436e7492 adds root firebase field name configuration 2023-10-08 02:26:03 -05:00
nai-degen 3b9013cd1e minor keyprovider cleanup 2023-10-08 02:09:05 -05:00
nai-degen 8884544b05 fixes rebase issues and adds aws key serializer 2023-10-08 01:50:23 -05:00
nai-degen 05ab8c37eb implements generic key serialization/deserialization 2023-10-08 01:32:34 -05:00
nai-degen f53e328398 wip broken shit 2023-10-08 01:27:58 -05:00
nai-degen 21af866fd9 moves keystore interface 2023-10-08 01:27:56 -05:00
nai-degen 5d3433268f implements MemoryKeyStore; inject store when instantiating providers 2023-10-08 01:27:27 -05:00
nai-degen 4114dba4f5 adds anthropic provider deserialize method 2023-10-08 01:24:25 -05:00
nai-degen e44d24a3af migrates GATEKEEPER_STORE config to PERSISTENCE_PROVIDER 2023-10-08 01:23:12 -05:00
nai-degen d611aeee18 adds wip keystore interface 2023-10-08 01:23:09 -05:00
28 changed files with 753 additions and 497 deletions
+4
View File
@@ -0,0 +1,4 @@
root = true
[*]
end_of_line = crlf
+32 -10
View File
@@ -1,5 +1,6 @@
import dotenv from "dotenv"; import dotenv from "dotenv";
import type firebase from "firebase-admin"; import type firebase from "firebase-admin";
import { hostname } from "os";
import pino from "pino"; import pino from "pino";
import type { ModelFamily } from "./shared/models"; import type { ModelFamily } from "./shared/models";
dotenv.config(); dotenv.config();
@@ -50,12 +51,12 @@ type Config = {
*/ */
gatekeeper: "none" | "proxy_key" | "user_token"; gatekeeper: "none" | "proxy_key" | "user_token";
/** /**
* Persistence layer to use for user management. * Persistence layer to use for user and key management.
* - `memory`: Users are stored in memory and are lost on restart (default) * - `memory`: Data is stored in memory and lost on restart (default)
* - `firebase_rtdb`: Users are stored in a Firebase Realtime Database; * - `firebase_rtdb`: Data is stored in Firebase Realtime Database; requires
* requires `firebaseKey` and `firebaseRtdbUrl` to be set. * `firebaseKey` and `firebaseRtdbUrl` to be set.
*/ */
gatekeeperStore: "memory" | "firebase_rtdb"; persistenceProvider: "memory" | "firebase_rtdb";
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */ /** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
firebaseRtdbUrl?: string; firebaseRtdbUrl?: string;
/** /**
@@ -64,6 +65,19 @@ type Config = {
* `private_key` field inside it. * `private_key` field inside it.
*/ */
firebaseKey?: string; firebaseKey?: string;
/**
* The root key under which data will be stored in the Firebase RTDB. This
* allows multiple instances of the proxy to share the same database while
* keeping their data separate.
*
* If you want multiple proxies to share the same data, set all of their
* `firebaseRtdbRoot` to the same value. Beware that there will likely
* be conflicts because concurrent writes are not yet supported and proxies
* currently assume they have exclusive access to the database.
*
* Defaults to the system hostname so that data is kept separate.
*/
firebaseRtdbRoot: string;
/** /**
* Maximum number of IPs per user, after which their token is disabled. * Maximum number of IPs per user, after which their token is disabled.
* Users with the manually-assigned `special` role are exempt from this limit. * Users with the manually-assigned `special` role are exempt from this limit.
@@ -165,10 +179,11 @@ export const config: Config = {
proxyKey: getEnvWithDefault("PROXY_KEY", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"), persistenceProvider: getEnvWithDefault("PERSISTENCE_PROVIDER", "memory"),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0), maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined), firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined), firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
firebaseRtdbRoot: getEnvWithDefault("FIREBASE_RTDB_ROOT", hostname()),
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4), modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0), maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0),
maxContextTokensAnthropic: getEnvWithDefault( maxContextTokensAnthropic: getEnvWithDefault(
@@ -247,6 +262,13 @@ export async function assertConfigIsValid() {
); );
} }
if (!!process.env.GATEKEEPER_STORE) {
startupLogger.warn(
"GATEKEEPER_STORE is deprecated. Use PERSISTENCE_PROVIDER instead. Configuration will be migrated."
);
config.persistenceProvider = process.env.GATEKEEPER_STORE as any;
}
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) { if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
throw new Error( throw new Error(
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.` `Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
@@ -272,11 +294,11 @@ export async function assertConfigIsValid() {
} }
if ( if (
config.gatekeeperStore === "firebase_rtdb" && config.persistenceProvider === "firebase_rtdb" &&
(!config.firebaseKey || !config.firebaseRtdbUrl) (!config.firebaseKey || !config.firebaseRtdbUrl)
) { ) {
throw new Error( throw new Error(
"Firebase RTDB store requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set." "Firebase RTDB persistence requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
); );
} }
@@ -318,9 +340,9 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"checkKeys", "checkKeys",
"showTokenCosts", "showTokenCosts",
"googleSheetsKey", "googleSheetsKey",
"persistenceProvider",
"firebaseKey", "firebaseKey",
"firebaseRtdbUrl", "firebaseRtdbUrl",
"gatekeeperStore",
"maxIpsPerUser", "maxIpsPerUser",
"blockedOrigins", "blockedOrigins",
"blockMessage", "blockMessage",
@@ -393,7 +415,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
let firebaseApp: firebase.app.App | undefined; let firebaseApp: firebase.app.App | undefined;
async function maybeInitializeFirebase() { async function maybeInitializeFirebase() {
if (!config.gatekeeperStore.startsWith("firebase")) { if (!config.persistenceProvider.startsWith("firebase")) {
return; return;
} }
+1 -1
View File
@@ -4,9 +4,9 @@ import showdown from "showdown";
import { config, listConfig } from "./config"; import { config, listConfig } from "./config";
import { import {
AnthropicKey, AnthropicKey,
AwsBedrockKey,
GooglePalmKey, GooglePalmKey,
OpenAIKey, OpenAIKey,
AwsBedrockKey,
keyPool, keyPool,
} from "./shared/key-management"; } from "./shared/key-management";
import { ModelFamily, OpenAIModelFamily } from "./shared/models"; import { ModelFamily, OpenAIModelFamily } from "./shared/models";
@@ -5,9 +5,9 @@ import { RequestPreprocessor } from ".";
export const setApiFormat = (api: { export const setApiFormat = (api: {
inApi: Request["inboundApi"]; inApi: Request["inboundApi"];
outApi: APIFormat; outApi: APIFormat;
service: LLMService, service: LLMService;
}): RequestPreprocessor => { }): RequestPreprocessor => {
return function configureRequestApiFormat (req) { return function configureRequestApiFormat(req) {
req.inboundApi = api.inApi; req.inboundApi = api.inApi;
req.outboundApi = api.outApi; req.outboundApi = api.outApi;
req.service = api.service; req.service = api.service;
+2 -2
View File
@@ -6,7 +6,7 @@ import zlib from "zlib";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { enqueue, trackWaitTime } from "../../queue"; import { enqueue, trackWaitTime } from "../../queue";
import { HttpError } from "../../../shared/errors"; import { HttpError } from "../../../shared/errors";
import { keyPool } from "../../../shared/key-management"; import { AnthropicKey, keyPool } from "../../../shared/key-management";
import { getOpenAIModelFamily } from "../../../shared/models"; import { getOpenAIModelFamily } from "../../../shared/models";
import { countTokens } from "../../../shared/tokenization"; import { countTokens } from "../../../shared/tokenization";
import { import {
@@ -407,7 +407,7 @@ function maybeHandleMissingPreambleError(
{ key: req.key?.hash }, { key: req.key?.hash },
"Request failed due to missing preamble. Key will be marked as such for subsequent requests." "Request failed due to missing preamble. Key will be marked as such for subsequent requests."
); );
keyPool.update(req.key!, { requiresPreamble: true }); keyPool.update(req.key as AnthropicKey, { requiresPreamble: true });
reenqueueRequest(req); reenqueueRequest(req);
throw new RetryableError("Claude request re-enqueued to add preamble."); throw new RetryableError("Claude request re-enqueued to add preamble.");
} else { } else {
@@ -4,7 +4,7 @@ import {
mergeEventsForAnthropic, mergeEventsForAnthropic,
mergeEventsForOpenAIChat, mergeEventsForOpenAIChat,
mergeEventsForOpenAIText, mergeEventsForOpenAIText,
OpenAIChatCompletionStreamEvent OpenAIChatCompletionStreamEvent,
} from "./index"; } from "./index";
/** /**
+2 -2
View File
@@ -16,7 +16,7 @@
*/ */
import type { Handler, Request } from "express"; import type { Handler, Request } from "express";
import { keyPool, SupportedModel } from "../shared/key-management"; import { keyPool } from "../shared/key-management";
import { import {
getClaudeModelFamily, getClaudeModelFamily,
getGooglePalmModelFamily, getGooglePalmModelFamily,
@@ -138,7 +138,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family. // There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so // Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues. // they should be treated as separate queues.
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo"; const model = req.body.model ?? "gpt-3.5-turbo";
// Weird special case for AWS because they serve multiple models from // Weird special case for AWS because they serve multiple models from
// different vendors, even if currently only one is supported. // different vendors, even if currently only one is supported.
+10 -9
View File
@@ -5,16 +5,16 @@ import cors from "cors";
import path from "path"; import path from "path";
import pinoHttp from "pino-http"; import pinoHttp from "pino-http";
import childProcess from "child_process"; import childProcess from "child_process";
import { logger } from "./logger";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page"; import { handleInfoPage } from "./info-page";
import { logQueue } from "./shared/prompt-logging"; import { logger } from "./logger";
import { start as startRequestQueue } from "./proxy/queue"; import { adminRouter } from "./admin/routes";
import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization";
import { checkOrigin } from "./proxy/check-origin"; import { checkOrigin } from "./proxy/check-origin";
import { start as startRequestQueue } from "./proxy/queue";
import { proxyRouter } from "./proxy/routes";
import { init as initKeyPool } from "./shared/key-management/key-pool";
import { logQueue } from "./shared/prompt-logging";
import { init as initTokenizers } from "./shared/tokenization";
import { init as initUserStore } from "./shared/users/user-store";
import { userRouter } from "./user/routes"; import { userRouter } from "./user/routes";
const PORT = config.port; const PORT = config.port;
@@ -92,7 +92,8 @@ async function start() {
logger.info("Checking configs and external dependencies..."); logger.info("Checking configs and external dependencies...");
await assertConfigIsValid(); await assertConfigIsValid();
keyPool.init(); logger.info("Starting key pool...");
await initKeyPool();
await initTokenizers(); await initTokenizers();
+1 -1
View File
@@ -11,7 +11,7 @@ export const injectLocals: RequestHandler = (req, res, next) => {
quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0; quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0;
res.locals.quota = quota; res.locals.quota = quota;
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh(); res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
res.locals.persistenceEnabled = config.gatekeeperStore !== "memory"; res.locals.persistenceEnabled = config.persistenceProvider !== "memory";
res.locals.showTokenCosts = config.showTokenCosts; res.locals.showTokenCosts = config.showTokenCosts;
res.locals.maxIps = config.maxIpsPerUser; res.locals.maxIps = config.maxIpsPerUser;
+22 -84
View File
@@ -1,10 +1,13 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import type { AnthropicModelFamily } from "../../models"; import type { AnthropicModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AnthropicKeyChecker } from "./checker"; import { AnthropicKeyChecker } from "./checker";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
// https://docs.anthropic.com/claude/reference/selecting-a-model // https://docs.anthropic.com/claude/reference/selecting-a-model
export const ANTHROPIC_SUPPORTED_MODELS = [ export const ANTHROPIC_SUPPORTED_MODELS = [
"claude-instant-v1", "claude-instant-v1",
@@ -15,16 +18,6 @@ export const ANTHROPIC_SUPPORTED_MODELS = [
] as const; ] as const;
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number]; export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,
| "key"
| "hash"
| "lastUsed"
| "promptCount"
| "rateLimitedAt"
| "rateLimitedUntil"
>;
type AnthropicKeyUsage = { type AnthropicKeyUsage = {
[K in AnthropicModelFamily as `${K}Tokens`]: number; [K in AnthropicModelFamily as `${K}Tokens`]: number;
}; };
@@ -51,72 +44,33 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
isPozzed: boolean; isPozzed: boolean;
} }
/** export class AnthropicKeyProvider extends KeyProviderBase<AnthropicKey> {
* Upon being rate limited, a key will be locked out for this many milliseconds readonly service = "anthropic" as const;
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> { protected readonly keys: AnthropicKey[] = [];
readonly service = "anthropic";
private keys: AnthropicKey[] = [];
private checker?: AnthropicKeyChecker; private checker?: AnthropicKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service }); protected log = logger.child({ module: "key-provider", service: this.service });
constructor() { public async init() {
const keyConfig = config.anthropicKey?.trim(); const storeName = this.store.constructor.name;
if (!keyConfig) { const loadedKeys = await this.store.load();
this.log.warn(
"ANTHROPIC_KEY is not set. Anthropic API will not be available." if (loadedKeys.length === 0) {
); return this.log.warn({ via: storeName }, "No Anthropic keys found.");
return; }
}
let bareKeys: string[]; this.keys.push(...loadedKeys);
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; this.log.info(
for (const key of bareKeys) { { count: this.keys.length, via: storeName },
const newKey: AnthropicKey = { "Loaded Anthropic keys."
key, );
service: this.service,
modelFamilies: ["claude"],
isDisabled: false,
isRevoked: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
requiresPreamble: false,
hash: `ant-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
claudeTokens: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
}
public init() {
if (config.checkKeys) { if (config.checkKeys) {
this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this)); this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this));
this.checker.start(); this.checker.start();
} }
} }
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AnthropicModel) { public get(_model: AnthropicModel) {
// Currently, all Anthropic keys have access to all models. This will almost // Currently, all Anthropic keys have access to all models. This will almost
// certainly change when they move out of beta later this year. // certainly change when they move out of beta later this year.
@@ -161,22 +115,6 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return { ...selectedKey }; return { ...selectedKey };
} }
public disable(key: AnthropicKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AnthropicKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) { public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === hash);
if (!key) return; if (!key) return;
@@ -0,0 +1,43 @@
import crypto from "crypto";
import type { AnthropicKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";
const SERIALIZABLE_FIELDS: (keyof AnthropicKey)[] = [
"key",
"service",
"hash",
"promptCount",
"claudeTokens",
];
export type SerializedAnthropicKey = SerializedKey &
Partial<Pick<AnthropicKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export class AnthropicKeySerializer extends KeySerializerBase<AnthropicKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey {
return {
key,
service: "anthropic" as const,
modelFamilies: ["claude" as const],
isDisabled: false,
isRevoked: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
requiresPreamble: false,
hash: `ant-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
claudeTokens: 0,
...rest,
};
}
}
+22 -73
View File
@@ -1,10 +1,13 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models"; import type { AwsBedrockModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AwsKeyChecker } from "./checker"; import { AwsKeyChecker } from "./checker";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
export const AWS_BEDROCK_SUPPORTED_MODELS = [ export const AWS_BEDROCK_SUPPORTED_MODELS = [
"anthropic.claude-v1", "anthropic.claude-v1",
@@ -33,71 +36,33 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
awsLoggingStatus: "unknown" | "disabled" | "enabled"; awsLoggingStatus: "unknown" | "disabled" | "enabled";
} }
/** export class AwsBedrockKeyProvider extends KeyProviderBase<AwsBedrockKey> {
* Upon being rate limited, a key will be locked out for this many milliseconds readonly service = "aws" as const;
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 300;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> { protected readonly keys: AwsBedrockKey[] = [];
readonly service = "aws";
private keys: AwsBedrockKey[] = [];
private checker?: AwsKeyChecker; private checker?: AwsKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service }); protected log = logger.child({ module: "key-provider", service: this.service });
constructor() { public async init() {
const keyConfig = config.awsCredentials?.trim(); const storeName = this.store.constructor.name;
if (!keyConfig) { const loadedKeys = await this.store.load();
this.log.warn(
"AWS_CREDENTIALS is not set. AWS Bedrock API will not be available." if (loadedKeys.length === 0) {
); return this.log.warn({ via: storeName }, "No AWS credentials found.");
return; }
}
let bareKeys: string[]; this.keys.push(...loadedKeys);
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; this.log.info(
for (const key of bareKeys) { { count: this.keys.length, via: storeName },
const newKey: AwsBedrockKey = { "Loaded AWS Bedrock keys."
key, );
service: this.service,
modelFamilies: ["aws-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
awsLoggingStatus: "unknown",
hash: `aws-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
["aws-claudeTokens"]: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
}
public init() {
if (config.checkKeys) { if (config.checkKeys) {
this.checker = new AwsKeyChecker(this.keys, this.update.bind(this)); this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
this.checker.start(); this.checker.start();
} }
} }
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AwsBedrockModel) { public get(_model: AwsBedrockModel) {
const availableKeys = this.keys.filter((k) => { const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus === "disabled"; const isNotLogged = k.awsLoggingStatus === "disabled";
@@ -139,22 +104,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return { ...selectedKey }; return { ...selectedKey };
} }
public disable(key: AwsBedrockKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AwsBedrockKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) { public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === hash);
if (!key) return; if (!key) return;
@@ -0,0 +1,43 @@
import crypto from "crypto";
import type { AwsBedrockKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";
const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
"key",
"service",
"hash",
"promptCount",
"aws-claudeTokens",
];
export type SerializedAwsBedrockKey = SerializedKey &
Partial<Pick<AwsBedrockKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
export class AwsBedrockKeySerializer extends KeySerializerBase<AwsBedrockKey> {
constructor() {
super(SERIALIZABLE_FIELDS);
}
deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey {
const { key, ...rest } = serializedKey;
return {
key,
service: "aws",
modelFamilies: ["aws-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
awsLoggingStatus: "unknown",
hash: `aws-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
["aws-claudeTokens"]: 0,
...rest,
};
}
}
+10 -83
View File
@@ -1,83 +1,10 @@
import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider"; export { keyPool } from "./key-pool";
import { export { OPENAI_SUPPORTED_MODELS } from "./openai/provider";
ANTHROPIC_SUPPORTED_MODELS, export { ANTHROPIC_SUPPORTED_MODELS } from "./anthropic/provider";
AnthropicModel, export { GOOGLE_PALM_SUPPORTED_MODELS } from "./palm/provider";
} from "./anthropic/provider"; export { AWS_BEDROCK_SUPPORTED_MODELS } from "./aws/provider";
import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider"; export type { AnthropicKey } from "./anthropic/provider";
import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider"; export type { OpenAIKey } from "./openai/provider";
import { KeyPool } from "./key-pool"; export type { GooglePalmKey } from "./palm/provider";
import type { ModelFamily } from "../models"; export type { AwsBedrockKey } from "./aws/provider";
export * from "./types";
/** The request and response format used by a model's API. */
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
export type Model =
| OpenAIModel
| AnthropicModel
| GooglePalmModel
| AwsBedrockModel;
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
readonly key: string;
/** The service that this key is for. */
service: LLMService;
/** The model families that this key has access to. */
modelFamilies: ModelFamily[];
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
isDisabled: boolean;
/** Whether this key specifically has been revoked. */
isRevoked: boolean;
/** The number of prompts that have been sent with this key. */
promptCount: number;
/** The time at which this key was last used. */
lastUsed: number;
/** The time at which this key was last checked. */
lastChecked: number;
/** Hash of the key, for logging and to find the key in the pool. */
hash: string;
}
/*
KeyPool and KeyProvider's similarities are a relic of the old design where
there was only a single KeyPool for OpenAI keys. Now that there are multiple
supported services, the service-specific functionality has been moved to
KeyProvider and KeyPool is just a wrapper around multiple KeyProviders,
delegating to the appropriate one based on the model requested.
Existing code will continue to call methods on KeyPool, which routes them to
the appropriate KeyProvider or returns data aggregated across all KeyProviders
for service-agnostic functionality.
*/
export interface KeyProvider<T extends Key = Key> {
readonly service: LLMService;
init(): void;
get(model: Model): T;
list(): Omit<T, "key">[];
disable(key: T): void;
update(hash: string, update: Partial<T>): void;
available(): number;
incrementUsage(hash: string, model: string, tokens: number): void;
getLockoutPeriod(model: Model): number;
markRateLimited(hash: string): void;
recheck(): void;
}
export const keyPool = new KeyPool();
export const SUPPORTED_MODELS = [
...OPENAI_SUPPORTED_MODELS,
...ANTHROPIC_SUPPORTED_MODELS,
] as const;
export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export {
OPENAI_SUPPORTED_MODELS,
ANTHROPIC_SUPPORTED_MODELS,
GOOGLE_PALM_SUPPORTED_MODELS,
AWS_BEDROCK_SUPPORTED_MODELS,
};
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { AwsBedrockKey } from "./aws/provider";
@@ -1,13 +1,13 @@
import { AxiosError } from "axios";
import pino from "pino"; import pino from "pino";
import { logger } from "../../logger"; import { logger } from "../../logger";
import { Key } from "./index"; import { Key } from "./types";
import { AxiosError } from "axios";
type KeyCheckerOptions = { type KeyCheckerOptions = {
service: string; service: string;
keyCheckPeriod: number; keyCheckPeriod: number;
minCheckInterval: number; minCheckInterval: number;
} };
export abstract class KeyCheckerBase<TKey extends Key> { export abstract class KeyCheckerBase<TKey extends Key> {
protected readonly service: string; protected readonly service: string;
@@ -116,4 +116,4 @@ export abstract class KeyCheckerBase<TKey extends Key> {
protected abstract checkKey(key: TKey): Promise<void>; protected abstract checkKey(key: TKey): Promise<void>;
protected abstract handleAxiosError(key: TKey, error: AxiosError): void; protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
} }
+41 -17
View File
@@ -4,34 +4,36 @@ import os from "os";
import schedule from "node-schedule"; import schedule from "node-schedule";
import { config } from "../../config"; import { config } from "../../config";
import { logger } from "../../logger"; import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index"; import { KeyProviderBase } from "./key-provider-base";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { getSerializer } from "./serializers";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
import { AnthropicKeyProvider } from "./anthropic/provider";
import { OpenAIKeyProvider } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider"; import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider"; import { AwsBedrockKeyProvider } from "./aws/provider";
import { Key, KeyStore, LLMService, Model, ServiceToKey } from "./types";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
export class KeyPool { export class KeyPool {
private keyProviders: KeyProvider[] = []; private keyProviders: KeyProviderBase[] = [];
private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = { private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
openai: null, openai: null,
}; };
constructor() { constructor() {
this.keyProviders.push(new OpenAIKeyProvider()); this.keyProviders.push(
this.keyProviders.push(new AnthropicKeyProvider()); new OpenAIKeyProvider(createKeyStore("openai")),
this.keyProviders.push(new GooglePalmKeyProvider()); new AnthropicKeyProvider(createKeyStore("anthropic")),
this.keyProviders.push(new AwsBedrockKeyProvider()); new GooglePalmKeyProvider(createKeyStore("google-palm")),
new AwsBedrockKeyProvider(createKeyStore("aws"))
);
} }
public init() { public async init() {
this.keyProviders.forEach((provider) => provider.init()); await Promise.all(this.keyProviders.map((p) => p.init()));
const availableKeys = this.available("all"); const availableKeys = this.available("all");
if (availableKeys === 0) { if (availableKeys === 0) {
throw new Error( throw new Error("No keys loaded, the application cannot start.");
"No keys loaded. Ensure that at least one key is configured."
);
} }
this.scheduleRecheck(); this.scheduleRecheck();
} }
@@ -59,7 +61,7 @@ export class KeyPool {
} }
} }
public update(key: Key, props: AllowedPartial): void { public update<T extends Key>(key: T, props: Partial<T>): void {
const service = this.getKeyProvider(key.service); const service = this.getKeyProvider(key.service);
service.update(key.hash, props); service.update(key.hash, props);
} }
@@ -122,7 +124,7 @@ export class KeyPool {
throw new Error(`Unknown service for model '${model}'`); throw new Error(`Unknown service for model '${model}'`);
} }
private getKeyProvider(service: LLMService): KeyProvider { private getKeyProvider(service: LLMService): KeyProviderBase {
return this.keyProviders.find((provider) => provider.service === service)!; return this.keyProviders.find((provider) => provider.service === service)!;
} }
@@ -151,3 +153,25 @@ export class KeyPool {
this.recheckJobs.openai = job; this.recheckJobs.openai = job;
} }
} }
/**
 * Instantiates the key store backing a provider for the given service, based
 * on the configured persistence provider.
 * @param service - The LLM service whose keys the store will hold.
 * @returns A key store wired up with that service's serializer.
 */
function createKeyStore<S extends LLMService>(
  service: S
): KeyStore<ServiceToKey[S]> {
  const serializer = getSerializer(service);
  const storeType = config.persistenceProvider;
  if (storeType === "memory") {
    return new MemoryKeyStore(service, serializer);
  }
  if (storeType === "firebase_rtdb") {
    return new FirebaseKeyStore(service, serializer);
  }
  throw new Error(`Unknown store type: ${storeType}`);
}
/** Singleton key pool, assigned by `init`. Not usable until `init` resolves. */
export let keyPool: KeyPool;

/**
 * Creates the global key pool and asynchronously loads keys from the
 * configured key stores. Must be awaited during application startup before
 * `keyPool` is used.
 */
export async function init() {
  keyPool = new KeyPool();
  await keyPool.init();
}
@@ -0,0 +1,65 @@
import { logger } from "../../logger";
import { Key, KeyStore, LLMService, Model } from "./types";
/**
 * Common behavior shared by all per-service key providers: holds the
 * in-memory key list, mirrors mutations to the backing key store, and exposes
 * the service-agnostic operations used by the key pool.
 */
export abstract class KeyProviderBase<K extends Key = Key> {
  /** Which LLM service this provider's keys belong to. */
  public abstract readonly service: LLMService;
  /** In-memory working set of keys; populated by the concrete `init`. */
  protected abstract readonly keys: K[];
  /** Provider-scoped logger supplied by the concrete subclass. */
  protected abstract log: typeof logger;

  /** @param store - Backing store that key additions/updates are mirrored to. */
  public constructor(protected readonly store: KeyStore<K>) {}

  /** Loads keys from the store and prepares the provider for use. */
  public abstract init(): Promise<void>;

  /** Adds a key to the pool and persists it to the backing store. */
  public addKey(key: K): void {
    this.keys.push(key);
    this.store.add(key);
  }

  /** Selects a usable key for the given model. */
  public abstract get(model: Model): K;

  /**
   * Returns a list of all keys, with the actual key value removed. Don't
   * mutate the returned objects; use `update` instead to ensure the changes
   * are synced to the key store.
   */
  public list(): Omit<K, "key">[] {
    return this.keys.map((entry) => Object.freeze({ ...entry, key: undefined }));
  }

  /** Disables a key; no-op if the key is unknown or already disabled. */
  public disable(key: K): void {
    const tracked = this.keys.find((candidate) => candidate.hash === key.hash);
    if (!tracked) return;
    if (tracked.isDisabled) return;
    this.update(key.hash, { isDisabled: true } as Partial<K>, true);
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  /**
   * Applies a partial update to the key with the given hash, stamps
   * `lastChecked`, and forwards the update to the backing store.
   * @throws If no key with the given hash exists in this provider.
   */
  public update(hash: string, update: Partial<K>, force = false): void {
    const target = this.keys.find((candidate) => candidate.hash === hash);
    if (!target) {
      throw new Error(`No key with hash ${hash}`);
    }
    Object.assign(target, { lastChecked: Date.now(), ...update });
    this.store.update(hash, update, force);
  }

  /** Number of keys not currently disabled. */
  public available(): number {
    let usable = 0;
    for (const candidate of this.keys) {
      if (!candidate.isDisabled) usable += 1;
    }
    return usable;
  }

  public abstract incrementUsage(
    hash: string,
    model: string,
    tokens: number
  ): void;

  public abstract getLockoutPeriod(model: Model): number;
  public abstract markRateLimited(hash: string): void;
  public abstract recheck(): void;
}
@@ -0,0 +1,31 @@
import { Key, KeySerializer, SerializedKey } from "./types";
/**
 * Base implementation of `KeySerializer`. Only the allow-listed fields
 * supplied by the subclass are serialized, so transient runtime state
 * (rate-limit timers, etc.) is never persisted to a key store.
 */
export abstract class KeySerializerBase<K extends Key>
  implements KeySerializer<K>
{
  /** @param serializableFields - Fields persisted when (de)serializing keys. */
  protected constructor(protected serializableFields: (keyof K)[]) {}

  /** Serializes a full key for storage, keeping only defined allow-listed fields. */
  serialize(keyObj: K): SerializedKey {
    return {
      // Drop undefined values so stores don't persist empty fields.
      ...Object.fromEntries(
        this.serializableFields
          .map((f) => [f, keyObj[f]])
          .filter(([, v]) => v !== undefined)
      ),
      key: keyObj.key,
    };
  }

  /**
   * Serializes a partial update to an already-stored key.
   *
   * The caller (`FirebaseKeyStore.update`) passes the key's *hash* as the
   * first argument, not the secret key. The previous implementation wrote
   * that value into the serialized `key` field, which would overwrite the
   * stored secret with its hash on the next flush. A partial update never
   * needs to rewrite `key` -- the stored record already has it -- so `key`
   * is intentionally omitted here.
   */
  partialSerialize(_id: string, update: Partial<K>): Partial<SerializedKey> {
    return Object.fromEntries(
      this.serializableFields
        .filter((f) => f !== "key") // never clobber the stored secret
        .map((f) => [f, update[f]])
        .filter(([, v]) => v !== undefined)
    );
  }

  /** Reconstructs a full key object from its serialized form. */
  abstract deserialize(serializedKey: SerializedKey): K;
}
+35 -121
View File
@@ -1,28 +1,23 @@
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
variable as a comma-separated list of keys. */
import crypto from "crypto"; import crypto from "crypto";
import http from "http"; import { IncomingHttpHeaders } from "http";
import { Key, KeyProvider, Model } from "../index";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import { OpenAIKeyChecker } from "./checker";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { Key, Model } from "../types";
import { OpenAIKeyChecker } from "./checker";
import { KeyProviderBase } from "../key-provider-base";
export type OpenAIModel = const KEY_REUSE_DELAY = 1000;
| "gpt-3.5-turbo"
| "gpt-3.5-turbo-instruct" export const OPENAI_SUPPORTED_MODELS = [
| "gpt-4"
| "gpt-4-32k"
| "text-embedding-ada-002";
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
"gpt-3.5-turbo", "gpt-3.5-turbo",
"gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4",
"gpt-4-32k",
"text-embedding-ada-002",
] as const; ] as const;
export type OpenAIModel = (typeof OPENAI_SUPPORTED_MODELS)[number];
// Flattening model families instead of using a nested object for easier
// cloning.
type OpenAIKeyUsage = { type OpenAIKeyUsage = {
[K in OpenAIModelFamily as `${K}Tokens`]: number; [K in OpenAIModelFamily as `${K}Tokens`]: number;
}; };
@@ -66,64 +61,30 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
rateLimitTokensReset: number; rateLimitTokensReset: number;
} }
export type OpenAIKeyUpdate = Omit< export class OpenAIKeyProvider extends KeyProviderBase<OpenAIKey> {
Partial<OpenAIKey>,
"key" | "hash" | "promptCount"
>;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 1000;
export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
readonly service = "openai" as const; readonly service = "openai" as const;
private keys: OpenAIKey[] = []; protected readonly keys: OpenAIKey[] = [];
private checker?: OpenAIKeyChecker; private checker?: OpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service }); protected log = logger.child({ module: "key-provider", service: this.service });
constructor() { public async init() {
const keyString = config.openaiKey?.trim(); const storeName = this.store.constructor.name;
if (!keyString) { const loadedKeys = await this.store.load();
this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
return; // TODO: after key management UI, keychecker should always be enabled
} // because keys may be added after initialization.
let bareKeys: string[];
bareKeys = keyString.split(",").map((k) => k.trim()); if (loadedKeys.length === 0) {
bareKeys = [...new Set(bareKeys)]; return this.log.warn({ via: storeName }, "No OpenAI keys found.");
for (const k of bareKeys) { }
const newKey: OpenAIKey = {
key: k, this.keys.push(...loadedKeys);
service: "openai" as const, this.log.info(
modelFamilies: ["turbo" as const, "gpt4" as const], { count: this.keys.length, via: storeName },
isTrial: false, "Loaded OpenAI keys."
isDisabled: false, );
isRevoked: false,
isOverQuota: false,
lastUsed: 0,
lastChecked: 0,
promptCount: 0,
hash: `oai-${crypto
.createHash("sha256")
.update(k)
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
}
public init() {
if (config.checkKeys) { if (config.checkKeys) {
const cloneFn = this.clone.bind(this); const cloneFn = this.clone.bind(this);
const updateFn = this.update.bind(this); const updateFn = this.update.bind(this);
@@ -132,29 +93,16 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
} }
} }
/**
* Returns a list of all keys, with the key field removed.
* Don't mutate returned keys, use a KeyPool method instead.
**/
public list() {
return this.keys.map((key) => {
return Object.freeze({
...key,
key: undefined,
});
});
}
public get(model: Model) { public get(model: Model) {
const neededFamily = getOpenAIModelFamily(model); const neededFamily = getOpenAIModelFamily(model);
const excludeTrials = model === "text-embedding-ada-002"; const excludeTrials = model === "text-embedding-ada-002";
const availableKeys = this.keys.filter( const availableKeys = this.keys.filter(
// Allow keys which // Allow keys which...
(key) => (key) =>
!key.isDisabled && // are not disabled !key.isDisabled && // ...are not disabled
key.modelFamilies.includes(neededFamily) && // have access to the model key.modelFamilies.includes(neededFamily) && // ...have access to the model
(!excludeTrials || !key.isTrial) // and are not trials (if applicable) (!excludeTrials || !key.isTrial) // ...and are not trials (if applicable)
); );
if (availableKeys.length === 0) { if (availableKeys.length === 0) {
@@ -233,13 +181,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
return { ...selectedKey }; return { ...selectedKey };
} }
/** Called by the key checker to update key information. */
public update(keyHash: string, update: OpenAIKeyUpdate) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
// this.writeKeyStatus();
}
/** Called by the key checker to create clones of keys for the given orgs. */ /** Called by the key checker to create clones of keys for the given orgs. */
public clone(keyHash: string, newOrgIds: string[]) { public clone(keyHash: string, newOrgIds: string[]) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
@@ -261,19 +202,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
); );
return clone; return clone;
}); });
this.keys.push(...clones); clones.forEach((clone) => this.addKey(clone));
}
/** Disables a key, or does nothing if the key isn't in this pool. */
public disable(key: Key) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
this.update(key.hash, { isDisabled: true });
this.log.warn({ key: key.hash }, "Key disabled");
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
} }
/** /**
@@ -338,7 +267,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens; key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
} }
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) { public updateRateLimits(keyHash: string, headers: IncomingHttpHeaders) {
const key = this.keys.find((k) => k.hash === keyHash)!; const key = this.keys.find((k) => k.hash === keyHash)!;
const requestsReset = headers["x-ratelimit-reset-requests"]; const requestsReset = headers["x-ratelimit-reset-requests"];
const tokensReset = headers["x-ratelimit-reset-tokens"]; const tokensReset = headers["x-ratelimit-reset-tokens"];
@@ -382,21 +311,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
}); });
this.checker?.scheduleNextCheck(); this.checker?.scheduleNextCheck();
} }
/** Writes key status to disk. */
// public writeKeyStatus() {
// const keys = this.keys.map((key) => ({
// key: key.key,
// isGpt4: key.isGpt4,
// usage: key.usage,
// hardLimit: key.hardLimit,
// isDisabled: key.isDisabled,
// }));
// fs.writeFileSync(
// path.join(__dirname, "..", "keys.json"),
// JSON.stringify(keys, null, 2)
// );
// }
} }
/** /**
@@ -0,0 +1,49 @@
import crypto from "crypto";
import type { OpenAIKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";
/** Fields of an OpenAI key that are persisted to the key store. */
const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
  "key",
  "service",
  "hash",
  "organizationId",
  "promptCount",
  "gpt4Tokens",
  "gpt4-32kTokens",
  "turboTokens",
];

/** A stored OpenAI key: the secret plus any persisted usage fields. */
export type SerializedOpenAIKey = SerializedKey &
  Partial<Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>>;

/** Converts OpenAI keys to and from their persisted representation. */
export class OpenAIKeySerializer extends KeySerializerBase<OpenAIKey> {
  constructor() {
    super(SERIALIZABLE_FIELDS);
  }

  deserialize(serialized: SerializedOpenAIKey): OpenAIKey {
    const { key, ...persisted } = serialized;
    // Start from the defaults for a fresh, unchecked key, then overlay
    // whatever was persisted (hash, usage counters, etc.).
    const defaults: OpenAIKey = {
      key,
      service: "openai",
      modelFamilies: ["turbo" as const, "gpt4" as const],
      isTrial: false,
      isDisabled: false,
      isRevoked: false,
      isOverQuota: false,
      lastUsed: 0,
      lastChecked: 0,
      promptCount: 0,
      hash: `oai-${crypto
        .createHash("sha256")
        .update(key)
        .digest("hex")
        .slice(0, 8)}`,
      rateLimitedAt: 0,
      rateLimitRequestsReset: 0,
      rateLimitTokensReset: 0,
      turboTokens: 0,
      gpt4Tokens: 0,
      "gpt4-32kTokens": 0,
    };
    return { ...defaults, ...persisted };
  }
}
+20 -84
View File
@@ -1,26 +1,15 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models"; import type { GooglePalmModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;
// https://developers.generativeai.google.com/models/language // https://developers.generativeai.google.com/models/language
export const GOOGLE_PALM_SUPPORTED_MODELS = [ export const GOOGLE_PALM_SUPPORTED_MODELS = ["text-bison-001"] as const;
"text-bison-001",
// "chat-bison-001", no adjustable safety settings, so it's useless
] as const;
export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number]; export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];
export type GooglePalmKeyUpdate = Omit<
Partial<GooglePalmKey>,
| "key"
| "hash"
| "lastUsed"
| "promptCount"
| "rateLimitedAt"
| "rateLimitedUntil"
>;
type GooglePalmKeyUsage = { type GooglePalmKeyUsage = {
[K in GooglePalmModelFamily as `${K}Tokens`]: number; [K in GooglePalmModelFamily as `${K}Tokens`]: number;
}; };
@@ -34,62 +23,25 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
rateLimitedUntil: number; rateLimitedUntil: number;
} }
/** export class GooglePalmKeyProvider extends KeyProviderBase<GooglePalmKey> {
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
readonly service = "google-palm"; readonly service = "google-palm";
private keys: GooglePalmKey[] = []; protected keys: GooglePalmKey[] = [];
private log = logger.child({ module: "key-provider", service: this.service }); protected log = logger.child({ module: "key-provider", service: this.service });
constructor() { public async init() {
const keyConfig = config.googlePalmKey?.trim(); const storeName = this.store.constructor.name;
if (!keyConfig) { const loadedKeys = await this.store.load();
this.log.warn(
"GOOGLE_PALM_KEY is not set. PaLM API will not be available." if (loadedKeys.length === 0) {
); return this.log.warn({ via: storeName }, "No Google PaLM keys found.");
return;
} }
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GooglePalmKey = {
key,
service: this.service,
modelFamilies: ["bison"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
}
public init() {} this.keys.push(...loadedKeys);
this.log.info(
public list() { { count: this.keys.length, via: storeName },
return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); "Loaded PaLM keys."
);
} }
public get(_model: GooglePalmModel) { public get(_model: GooglePalmModel) {
@@ -130,22 +82,6 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return { ...selectedKey }; return { ...selectedKey };
} }
public disable(key: GooglePalmKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GooglePalmKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) { public incrementUsage(hash: string, _model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === hash);
if (!key) return; if (!key) return;
@@ -0,0 +1,42 @@
import crypto from "crypto";
import type { GooglePalmKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";
/** Fields of a Google PaLM key that are persisted to the key store. */
const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
  "key",
  "service",
  "hash",
  "promptCount",
  "bisonTokens",
];

/** A stored PaLM key: the secret plus any persisted usage fields. */
export type SerializedGooglePalmKey = SerializedKey &
  Partial<Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>>;

/** Converts Google PaLM keys to and from their persisted representation. */
export class GooglePalmKeySerializer extends KeySerializerBase<GooglePalmKey> {
  constructor() {
    super(SERIALIZABLE_FIELDS);
  }

  deserialize({ key, ...rest }: SerializedGooglePalmKey): GooglePalmKey {
    // Defaults describe a fresh, unchecked key; persisted fields win.
    return {
      key,
      service: "google-palm" as const,
      modelFamilies: ["bison"],
      isDisabled: false,
      isRevoked: false,
      promptCount: 0,
      lastUsed: 0,
      rateLimitedAt: 0,
      rateLimitedUntil: 0,
      hash: `plm-${crypto
        .createHash("sha256")
        .update(key)
        .digest("hex")
        .slice(0, 8)}`,
      lastChecked: 0,
      bisonTokens: 0,
      ...rest,
    };
  }
}
+36
View File
@@ -0,0 +1,36 @@
import { assertNever } from "../utils";
import {
Key,
KeySerializer,
LLMService,
SerializedKey,
ServiceToKey,
} from "./types";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";
/**
 * Asserts that an arbitrary value (e.g. raw data loaded from a key store)
 * has the shape of a serialized key: a non-null object with a string `key`.
 * @throws If the value is not a valid serialized key.
 */
export function assertSerializedKey(k: any): asserts k is SerializedKey {
  const isObject = typeof k === "object" && k !== null;
  if (!isObject || typeof k.key !== "string") {
    throw new Error("Invalid serialized key data");
  }
}
/**
 * Returns the key serializer for the given service. The overload signature
 * narrows the return type so callers get a serializer typed for that
 * service's concrete key shape.
 */
export function getSerializer<S extends LLMService>(
  service: S
): KeySerializer<ServiceToKey[S]>;
export function getSerializer(service: LLMService): KeySerializer<Key> {
  switch (service) {
    case "openai":
      return new OpenAIKeySerializer();
    case "anthropic":
      return new AnthropicKeySerializer();
    case "google-palm":
      return new GooglePalmKeySerializer();
    case "aws":
      return new AwsBedrockKeySerializer();
    default:
      // Exhaustiveness check: fails to compile if a service goes unhandled.
      assertNever(service);
  }
}
@@ -0,0 +1,125 @@
import firebase from "firebase-admin";
import { config, getFirebaseApp } from "../../../config";
import { logger } from "../../../logger";
import { assertSerializedKey } from "../serializers";
import type {
Key,
KeySerializer,
KeyStore,
LLMService,
SerializedKey,
} from "../types";
import { MemoryKeyStore } from "./index";
/**
 * Key store backed by Firebase Realtime Database. Keys for a service live
 * under `keys/<root>/<service>`, keyed by key hash. Writes are buffered in
 * `pendingUpdates` and flushed on a timer, or immediately when forced.
 */
export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
  private readonly db: firebase.database.Database;
  private readonly log: typeof logger;
  /** Serialized writes buffered between flushes, keyed by key hash. */
  private readonly pendingUpdates: Map<string, Partial<SerializedKey>>;
  /** RTDB path prefix for this service's keys. */
  private readonly root: string;
  private readonly serializer: KeySerializer<K>;
  private readonly service: LLMService;
  /** Handle for the next flush; either an interval or an immediate timeout. */
  private flushInterval: NodeJS.Timeout | null = null;
  /** Set once `load` has resolved the database reference. */
  private keysRef: firebase.database.Reference | null = null;

  constructor(
    service: LLMService,
    serializer: KeySerializer<K>,
    app = getFirebaseApp()
  ) {
    this.db = firebase.database(app);
    this.log = logger.child({ module: "firebase-key-store", service });
    this.root = `keys/${config.firebaseRtdbRoot.toLowerCase()}/${service}`;
    this.serializer = serializer;
    this.service = service;
    this.pendingUpdates = new Map();
    this.scheduleFlush();
  }

  /**
   * Loads all keys for this service from Firebase. If the database is empty,
   * migrates keys from the environment into Firebase and reloads once.
   * @param isMigrating - Internal recursion guard; callers omit it.
   */
  public async load(isMigrating = false): Promise<K[]> {
    const keysRef = this.db.ref(this.root);
    const snapshot = await keysRef.once("value");
    const keys = snapshot.val();
    this.keysRef = keysRef;

    if (!keys) {
      if (isMigrating) return [];
      this.log.warn("No keys found in Firebase. Migrating from environment.");
      await this.migrate();
      return this.load(true);
    }

    return Object.values(keys).map((k) => {
      assertSerializedKey(k);
      return this.serializer.deserialize(k);
    });
  }

  /** Queues a newly added key for persistence and flushes immediately. */
  public add(key: K) {
    const serialized = this.serializer.serialize(key);
    this.pendingUpdates.set(key.hash, serialized);
    this.forceFlush();
  }

  /**
   * Queues a partial update for the key with the given hash.
   * NOTE(review): `id` is the key's hash and is forwarded as the first
   * argument of `partialSerialize` -- verify the serializer does not write it
   * into the stored `key` field, which would clobber the secret.
   * @param force - When true, flushes the buffered updates immediately.
   */
  public update(id: string, update: Partial<K>, force = false) {
    const existing = this.pendingUpdates.get(id) ?? {};
    Object.assign(existing, this.serializer.partialSerialize(id, update));
    this.pendingUpdates.set(id, existing);
    if (force) this.forceFlush();
  }

  /** Cancels the pending timer and flushes on the next tick. */
  private forceFlush() {
    // clearInterval also clears timeouts in Node, so one handle suffices.
    if (this.flushInterval) clearInterval(this.flushInterval);
    this.flushInterval = setTimeout(() => this.flush(), 0);
  }

  /** (Re)starts the periodic five-minute flush timer. */
  private scheduleFlush() {
    if (this.flushInterval) clearInterval(this.flushInterval);
    this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5);
  }

  /** Writes buffered updates to Firebase, then reschedules the timer. */
  private async flush() {
    if (!this.keysRef) {
      this.log.warn(
        { pendingUpdates: this.pendingUpdates.size },
        "Database not loaded yet. Skipping flush."
      );
      return this.scheduleFlush();
    }

    if (this.pendingUpdates.size === 0) {
      this.log.debug("No pending key updates to flush.");
      return this.scheduleFlush();
    }

    const updates: Record<string, Partial<SerializedKey>> = {};
    this.pendingUpdates.forEach((v, k) => (updates[k] = v));
    this.pendingUpdates.clear();

    // NOTE(review): RTDB `update` replaces each child at `<hash>` wholesale
    // rather than deep-merging, so partial payloads may drop stored fields --
    // confirm this is the intended semantics.
    await this.keysRef.update(updates);

    this.log.debug(
      { count: Object.keys(updates).length },
      "Flushed pending key updates."
    );
    this.scheduleFlush();
  }

  /**
   * Seeds Firebase with keys loaded from the environment (via a temporary
   * MemoryKeyStore) and returns the serialized records that were written.
   */
  private async migrate(): Promise<SerializedKey[]> {
    const keysRef = this.db.ref(this.root);
    const envStore = new MemoryKeyStore<K>(this.service, this.serializer);
    const keys = await envStore.load();

    if (keys.length === 0) {
      this.log.warn("No keys found in environment or Firebase.");
      return [];
    }

    const updates: Record<string, SerializedKey> = {};
    keys.forEach((k) => (updates[k.hash] = this.serializer.serialize(k)));
    await keysRef.update(updates);

    this.log.info({ count: keys.length }, "Migrated keys from environment.");
    return Object.values(updates);
  }
}
@@ -0,0 +1,2 @@
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
@@ -0,0 +1,41 @@
import { assertNever } from "../../utils";
import { Key, KeySerializer, KeyStore, LLMService } from "../types";
/**
 * Key store that "loads" keys from a comma-separated environment variable.
 * `add` and `update` are no-ops because the environment is the source of
 * truth and is never written back.
 */
export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
  /** Name of the environment variable holding this service's keys. */
  private readonly env: string;
  private readonly serializer: KeySerializer<K>;

  constructor(service: LLMService, serializer: KeySerializer<K>) {
    this.serializer = serializer;
    switch (service) {
      case "openai":
        this.env = "OPENAI_KEY";
        break;
      case "anthropic":
        this.env = "ANTHROPIC_KEY";
        break;
      case "google-palm":
        this.env = "GOOGLE_PALM_KEY";
        break;
      case "aws":
        this.env = "AWS_CREDENTIALS";
        break;
      default:
        assertNever(service);
    }
  }

  /** Parses, dedupes, and deserializes the keys in the environment variable. */
  public async load() {
    const trimmed = process.env[this.env]?.split(",").map((k) => k.trim());
    const uniqueKeys = new Set(trimmed);
    const result: K[] = [];
    for (const bareKey of uniqueKeys) {
      if (!bareKey) continue; // skip empty entries from stray commas
      result.push(this.serializer.deserialize({ key: bareKey }));
    }
    return result;
  }

  /** No-op: environment-backed keys are not persisted. */
  public add() {}

  /** No-op: environment-backed keys are not persisted. */
  public update() {}
}
+64
View File
@@ -0,0 +1,64 @@
import type { OpenAIKey, OpenAIModel } from "./openai/provider";
import type { AnthropicKey, AnthropicModel } from "./anthropic/provider";
import type { GooglePalmKey, GooglePalmModel } from "./palm/provider";
import type { AwsBedrockKey, AwsBedrockModel } from "./aws/provider";
import type { ModelFamily } from "../models";
/** The request and response format used by a model's API. */
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";

/**
 * The service that a model is hosted on; distinct because services like AWS
 * provide APIs from other service providers, but have their own authentication
 * and key management.
 */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";

/** Union of every model identifier across all supported services. */
export type Model =
  | OpenAIModel
  | AnthropicModel
  | GooglePalmModel
  | AwsBedrockModel;

// Union of all concrete key shapes, used to derive the service-to-key map.
type AllKeys = OpenAIKey | AnthropicKey | GooglePalmKey | AwsBedrockKey;

/** Maps each service name to the key type used by that service's provider. */
export type ServiceToKey = {
  [K in AllKeys["service"]]: Extract<AllKeys, { service: K }>;
};

/**
 * Minimal persisted form of a key. Per-service serializers widen this with
 * the optional fields they persist.
 */
export type SerializedKey = { key: string };

export interface Key {
  /** The API key itself. Never log this, use `hash` instead. */
  readonly key: string;
  /** The service that this key is for. */
  service: LLMService;
  /** The model families that this key has access to. */
  modelFamilies: ModelFamily[];
  /** Whether this key is currently disabled for some reason. */
  isDisabled: boolean;
  /**
   * Whether this key specifically has been revoked. This is different from
   * `isDisabled` because a key can be disabled for other reasons, such as
   * exceeding its quota. A revoked key is assumed to be permanently disabled,
   * and KeyStore implementations should not return it when loading keys.
   */
  isRevoked: boolean;
  /** The number of prompts that have been sent with this key. */
  promptCount: number;
  /** The time at which this key was last used. */
  lastUsed: number;
  /** The time at which this key was last checked. */
  lastChecked: number;
  /** Hash of the key, for logging and to find the key in the pool. */
  hash: string;
}

/** Converts keys between their runtime shape and their persisted shape. */
export interface KeySerializer<K> {
  /** Serializes a full key for storage. */
  serialize(keyObj: K): SerializedKey;
  /** Reconstructs a full key from its stored form. */
  deserialize(serializedKey: SerializedKey): K;
  /** Serializes a partial update to an already-stored key. */
  partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey>;
}

/** Persistence backend for a key provider's keys. */
export interface KeyStore<K extends Key> {
  /** Loads all keys for the store's service. */
  load(): Promise<K[]>;
  /** Persists a newly added key. */
  add(key: K): void;
  /** Persists a partial update to an existing key; `id` is the key's hash in current callers. */
  update(id: string, update: Partial<K>, force?: boolean): void;
}
+3 -3
View File
@@ -32,8 +32,8 @@ let quotaRefreshJob: schedule.Job | null = null;
let userCleanupJob: schedule.Job | null = null; let userCleanupJob: schedule.Job | null = null;
export async function init() { export async function init() {
log.info({ store: config.gatekeeperStore }, "Initializing user store..."); log.info({ store: config.persistenceProvider }, "Initializing user store...");
if (config.gatekeeperStore === "firebase_rtdb") { if (config.persistenceProvider === "firebase_rtdb") {
await initFirebase(); await initFirebase();
} }
if (config.quotaRefreshPeriod) { if (config.quotaRefreshPeriod) {
@@ -146,7 +146,7 @@ export function upsertUser(user: UserUpdate) {
usersToFlush.add(user.token); usersToFlush.add(user.token);
// Immediately schedule a flush to the database if we're using Firebase. // Immediately schedule a flush to the database if we're using Firebase.
if (config.gatekeeperStore === "firebase_rtdb") { if (config.persistenceProvider === "firebase_rtdb") {
setImmediate(flushUsers); setImmediate(flushUsers);
} }