Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b8cc5e563e | |||
| 00402c8310 | |||
| df2e986366 | |||
| f9620991e7 | |||
| dd511fe60d | |||
| ea2bfb9eef | |||
| 39436e7492 | |||
| 3b9013cd1e | |||
| 8884544b05 | |||
| 05ab8c37eb | |||
| f53e328398 | |||
| 21af866fd9 | |||
| 5d3433268f | |||
| 4114dba4f5 | |||
| e44d24a3af | |||
| d611aeee18 |
@@ -0,0 +1,4 @@
|
|||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
end_of_line = crlf
|
||||||
+32
-10
@@ -1,5 +1,6 @@
|
|||||||
import dotenv from "dotenv";
|
import dotenv from "dotenv";
|
||||||
import type firebase from "firebase-admin";
|
import type firebase from "firebase-admin";
|
||||||
|
import { hostname } from "os";
|
||||||
import pino from "pino";
|
import pino from "pino";
|
||||||
import type { ModelFamily } from "./shared/models";
|
import type { ModelFamily } from "./shared/models";
|
||||||
dotenv.config();
|
dotenv.config();
|
||||||
@@ -50,12 +51,12 @@ type Config = {
|
|||||||
*/
|
*/
|
||||||
gatekeeper: "none" | "proxy_key" | "user_token";
|
gatekeeper: "none" | "proxy_key" | "user_token";
|
||||||
/**
|
/**
|
||||||
* Persistence layer to use for user management.
|
* Persistence layer to use for user and key management.
|
||||||
* - `memory`: Users are stored in memory and are lost on restart (default)
|
* - `memory`: Data is stored in memory and lost on restart (default)
|
||||||
* - `firebase_rtdb`: Users are stored in a Firebase Realtime Database;
|
* - `firebase_rtdb`: Data is stored in Firebase Realtime Database; requires
|
||||||
* requires `firebaseKey` and `firebaseRtdbUrl` to be set.
|
* `firebaseKey` and `firebaseRtdbUrl` to be set.
|
||||||
*/
|
*/
|
||||||
gatekeeperStore: "memory" | "firebase_rtdb";
|
persistenceProvider: "memory" | "firebase_rtdb";
|
||||||
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
|
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
|
||||||
firebaseRtdbUrl?: string;
|
firebaseRtdbUrl?: string;
|
||||||
/**
|
/**
|
||||||
@@ -64,6 +65,19 @@ type Config = {
|
|||||||
* `private_key` field inside it.
|
* `private_key` field inside it.
|
||||||
*/
|
*/
|
||||||
firebaseKey?: string;
|
firebaseKey?: string;
|
||||||
|
/**
|
||||||
|
* The root key under which data will be stored in the Firebase RTDB. This
|
||||||
|
* allows multiple instances of the proxy to share the same database while
|
||||||
|
* keeping their data separate.
|
||||||
|
*
|
||||||
|
* If you want multiple proxies to share the same data, set all of their
|
||||||
|
* `firebaseRtdbRoot` to the same value. Beware that there will likely
|
||||||
|
* be conflicts because concurrent writes are not yet supported and proxies
|
||||||
|
* currently assume they have exclusive access to the database.
|
||||||
|
*
|
||||||
|
* Defaults to the system hostname so that data is kept separate.
|
||||||
|
*/
|
||||||
|
firebaseRtdbRoot: string;
|
||||||
/**
|
/**
|
||||||
* Maximum number of IPs per user, after which their token is disabled.
|
* Maximum number of IPs per user, after which their token is disabled.
|
||||||
* Users with the manually-assigned `special` role are exempt from this limit.
|
* Users with the manually-assigned `special` role are exempt from this limit.
|
||||||
@@ -165,10 +179,11 @@ export const config: Config = {
|
|||||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||||
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
||||||
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
|
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
|
||||||
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
|
persistenceProvider: getEnvWithDefault("PERSISTENCE_PROVIDER", "memory"),
|
||||||
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
|
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
|
||||||
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
|
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
|
||||||
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
|
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
|
||||||
|
firebaseRtdbRoot: getEnvWithDefault("FIREBASE_RTDB_ROOT", hostname()),
|
||||||
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
|
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
|
||||||
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0),
|
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0),
|
||||||
maxContextTokensAnthropic: getEnvWithDefault(
|
maxContextTokensAnthropic: getEnvWithDefault(
|
||||||
@@ -247,6 +262,13 @@ export async function assertConfigIsValid() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!!process.env.GATEKEEPER_STORE) {
|
||||||
|
startupLogger.warn(
|
||||||
|
"GATEKEEPER_STORE is deprecated. Use PERSISTENCE_PROVIDER instead. Configuration will be migrated."
|
||||||
|
);
|
||||||
|
config.persistenceProvider = process.env.GATEKEEPER_STORE as any;
|
||||||
|
}
|
||||||
|
|
||||||
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
|
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
|
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
|
||||||
@@ -272,11 +294,11 @@ export async function assertConfigIsValid() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
config.gatekeeperStore === "firebase_rtdb" &&
|
config.persistenceProvider === "firebase_rtdb" &&
|
||||||
(!config.firebaseKey || !config.firebaseRtdbUrl)
|
(!config.firebaseKey || !config.firebaseRtdbUrl)
|
||||||
) {
|
) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Firebase RTDB store requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
|
"Firebase RTDB persistence requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,9 +340,9 @@ export const OMITTED_KEYS: (keyof Config)[] = [
|
|||||||
"checkKeys",
|
"checkKeys",
|
||||||
"showTokenCosts",
|
"showTokenCosts",
|
||||||
"googleSheetsKey",
|
"googleSheetsKey",
|
||||||
|
"persistenceProvider",
|
||||||
"firebaseKey",
|
"firebaseKey",
|
||||||
"firebaseRtdbUrl",
|
"firebaseRtdbUrl",
|
||||||
"gatekeeperStore",
|
|
||||||
"maxIpsPerUser",
|
"maxIpsPerUser",
|
||||||
"blockedOrigins",
|
"blockedOrigins",
|
||||||
"blockMessage",
|
"blockMessage",
|
||||||
@@ -393,7 +415,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
|||||||
let firebaseApp: firebase.app.App | undefined;
|
let firebaseApp: firebase.app.App | undefined;
|
||||||
|
|
||||||
async function maybeInitializeFirebase() {
|
async function maybeInitializeFirebase() {
|
||||||
if (!config.gatekeeperStore.startsWith("firebase")) {
|
if (!config.persistenceProvider.startsWith("firebase")) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -4,9 +4,9 @@ import showdown from "showdown";
|
|||||||
import { config, listConfig } from "./config";
|
import { config, listConfig } from "./config";
|
||||||
import {
|
import {
|
||||||
AnthropicKey,
|
AnthropicKey,
|
||||||
|
AwsBedrockKey,
|
||||||
GooglePalmKey,
|
GooglePalmKey,
|
||||||
OpenAIKey,
|
OpenAIKey,
|
||||||
AwsBedrockKey,
|
|
||||||
keyPool,
|
keyPool,
|
||||||
} from "./shared/key-management";
|
} from "./shared/key-management";
|
||||||
import { ModelFamily, OpenAIModelFamily } from "./shared/models";
|
import { ModelFamily, OpenAIModelFamily } from "./shared/models";
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { RequestPreprocessor } from ".";
|
|||||||
export const setApiFormat = (api: {
|
export const setApiFormat = (api: {
|
||||||
inApi: Request["inboundApi"];
|
inApi: Request["inboundApi"];
|
||||||
outApi: APIFormat;
|
outApi: APIFormat;
|
||||||
service: LLMService,
|
service: LLMService;
|
||||||
}): RequestPreprocessor => {
|
}): RequestPreprocessor => {
|
||||||
return function configureRequestApiFormat(req) {
|
return function configureRequestApiFormat(req) {
|
||||||
req.inboundApi = api.inApi;
|
req.inboundApi = api.inApi;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import zlib from "zlib";
|
|||||||
import { logger } from "../../../logger";
|
import { logger } from "../../../logger";
|
||||||
import { enqueue, trackWaitTime } from "../../queue";
|
import { enqueue, trackWaitTime } from "../../queue";
|
||||||
import { HttpError } from "../../../shared/errors";
|
import { HttpError } from "../../../shared/errors";
|
||||||
import { keyPool } from "../../../shared/key-management";
|
import { AnthropicKey, keyPool } from "../../../shared/key-management";
|
||||||
import { getOpenAIModelFamily } from "../../../shared/models";
|
import { getOpenAIModelFamily } from "../../../shared/models";
|
||||||
import { countTokens } from "../../../shared/tokenization";
|
import { countTokens } from "../../../shared/tokenization";
|
||||||
import {
|
import {
|
||||||
@@ -407,7 +407,7 @@ function maybeHandleMissingPreambleError(
|
|||||||
{ key: req.key?.hash },
|
{ key: req.key?.hash },
|
||||||
"Request failed due to missing preamble. Key will be marked as such for subsequent requests."
|
"Request failed due to missing preamble. Key will be marked as such for subsequent requests."
|
||||||
);
|
);
|
||||||
keyPool.update(req.key!, { requiresPreamble: true });
|
keyPool.update(req.key as AnthropicKey, { requiresPreamble: true });
|
||||||
reenqueueRequest(req);
|
reenqueueRequest(req);
|
||||||
throw new RetryableError("Claude request re-enqueued to add preamble.");
|
throw new RetryableError("Claude request re-enqueued to add preamble.");
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import {
|
|||||||
mergeEventsForAnthropic,
|
mergeEventsForAnthropic,
|
||||||
mergeEventsForOpenAIChat,
|
mergeEventsForOpenAIChat,
|
||||||
mergeEventsForOpenAIText,
|
mergeEventsForOpenAIText,
|
||||||
OpenAIChatCompletionStreamEvent
|
OpenAIChatCompletionStreamEvent,
|
||||||
} from "./index";
|
} from "./index";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
+2
-2
@@ -16,7 +16,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import type { Handler, Request } from "express";
|
import type { Handler, Request } from "express";
|
||||||
import { keyPool, SupportedModel } from "../shared/key-management";
|
import { keyPool } from "../shared/key-management";
|
||||||
import {
|
import {
|
||||||
getClaudeModelFamily,
|
getClaudeModelFamily,
|
||||||
getGooglePalmModelFamily,
|
getGooglePalmModelFamily,
|
||||||
@@ -138,7 +138,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
|
|||||||
// There is a single request queue, but it is partitioned by model family.
|
// There is a single request queue, but it is partitioned by model family.
|
||||||
// Model families are typically separated on cost/rate limit boundaries so
|
// Model families are typically separated on cost/rate limit boundaries so
|
||||||
// they should be treated as separate queues.
|
// they should be treated as separate queues.
|
||||||
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
|
const model = req.body.model ?? "gpt-3.5-turbo";
|
||||||
|
|
||||||
// Weird special case for AWS because they serve multiple models from
|
// Weird special case for AWS because they serve multiple models from
|
||||||
// different vendors, even if currently only one is supported.
|
// different vendors, even if currently only one is supported.
|
||||||
|
|||||||
+10
-9
@@ -5,16 +5,16 @@ import cors from "cors";
|
|||||||
import path from "path";
|
import path from "path";
|
||||||
import pinoHttp from "pino-http";
|
import pinoHttp from "pino-http";
|
||||||
import childProcess from "child_process";
|
import childProcess from "child_process";
|
||||||
import { logger } from "./logger";
|
|
||||||
import { keyPool } from "./shared/key-management";
|
|
||||||
import { adminRouter } from "./admin/routes";
|
|
||||||
import { proxyRouter } from "./proxy/routes";
|
|
||||||
import { handleInfoPage } from "./info-page";
|
import { handleInfoPage } from "./info-page";
|
||||||
import { logQueue } from "./shared/prompt-logging";
|
import { logger } from "./logger";
|
||||||
import { start as startRequestQueue } from "./proxy/queue";
|
import { adminRouter } from "./admin/routes";
|
||||||
import { init as initUserStore } from "./shared/users/user-store";
|
|
||||||
import { init as initTokenizers } from "./shared/tokenization";
|
|
||||||
import { checkOrigin } from "./proxy/check-origin";
|
import { checkOrigin } from "./proxy/check-origin";
|
||||||
|
import { start as startRequestQueue } from "./proxy/queue";
|
||||||
|
import { proxyRouter } from "./proxy/routes";
|
||||||
|
import { init as initKeyPool } from "./shared/key-management/key-pool";
|
||||||
|
import { logQueue } from "./shared/prompt-logging";
|
||||||
|
import { init as initTokenizers } from "./shared/tokenization";
|
||||||
|
import { init as initUserStore } from "./shared/users/user-store";
|
||||||
import { userRouter } from "./user/routes";
|
import { userRouter } from "./user/routes";
|
||||||
|
|
||||||
const PORT = config.port;
|
const PORT = config.port;
|
||||||
@@ -92,7 +92,8 @@ async function start() {
|
|||||||
logger.info("Checking configs and external dependencies...");
|
logger.info("Checking configs and external dependencies...");
|
||||||
await assertConfigIsValid();
|
await assertConfigIsValid();
|
||||||
|
|
||||||
keyPool.init();
|
logger.info("Starting key pool...");
|
||||||
|
await initKeyPool();
|
||||||
|
|
||||||
await initTokenizers();
|
await initTokenizers();
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ export const injectLocals: RequestHandler = (req, res, next) => {
|
|||||||
quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0;
|
quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0;
|
||||||
res.locals.quota = quota;
|
res.locals.quota = quota;
|
||||||
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
|
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
|
||||||
res.locals.persistenceEnabled = config.gatekeeperStore !== "memory";
|
res.locals.persistenceEnabled = config.persistenceProvider !== "memory";
|
||||||
res.locals.showTokenCosts = config.showTokenCosts;
|
res.locals.showTokenCosts = config.showTokenCosts;
|
||||||
res.locals.maxIps = config.maxIpsPerUser;
|
res.locals.maxIps = config.maxIpsPerUser;
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
import crypto from "crypto";
|
|
||||||
import { Key, KeyProvider } from "..";
|
|
||||||
import { config } from "../../../config";
|
import { config } from "../../../config";
|
||||||
import { logger } from "../../../logger";
|
import { logger } from "../../../logger";
|
||||||
import type { AnthropicModelFamily } from "../../models";
|
import type { AnthropicModelFamily } from "../../models";
|
||||||
|
import { KeyProviderBase } from "../key-provider-base";
|
||||||
|
import { Key } from "../types";
|
||||||
import { AnthropicKeyChecker } from "./checker";
|
import { AnthropicKeyChecker } from "./checker";
|
||||||
|
|
||||||
|
const RATE_LIMIT_LOCKOUT = 2000;
|
||||||
|
const KEY_REUSE_DELAY = 500;
|
||||||
|
|
||||||
// https://docs.anthropic.com/claude/reference/selecting-a-model
|
// https://docs.anthropic.com/claude/reference/selecting-a-model
|
||||||
export const ANTHROPIC_SUPPORTED_MODELS = [
|
export const ANTHROPIC_SUPPORTED_MODELS = [
|
||||||
"claude-instant-v1",
|
"claude-instant-v1",
|
||||||
@@ -15,16 +18,6 @@ export const ANTHROPIC_SUPPORTED_MODELS = [
|
|||||||
] as const;
|
] as const;
|
||||||
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
|
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
|
||||||
|
|
||||||
export type AnthropicKeyUpdate = Omit<
|
|
||||||
Partial<AnthropicKey>,
|
|
||||||
| "key"
|
|
||||||
| "hash"
|
|
||||||
| "lastUsed"
|
|
||||||
| "promptCount"
|
|
||||||
| "rateLimitedAt"
|
|
||||||
| "rateLimitedUntil"
|
|
||||||
>;
|
|
||||||
|
|
||||||
type AnthropicKeyUsage = {
|
type AnthropicKeyUsage = {
|
||||||
[K in AnthropicModelFamily as `${K}Tokens`]: number;
|
[K in AnthropicModelFamily as `${K}Tokens`]: number;
|
||||||
};
|
};
|
||||||
@@ -51,72 +44,33 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
|
|||||||
isPozzed: boolean;
|
isPozzed: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export class AnthropicKeyProvider extends KeyProviderBase<AnthropicKey> {
|
||||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
readonly service = "anthropic" as const;
|
||||||
* while we wait for other concurrent requests to finish.
|
|
||||||
*/
|
|
||||||
const RATE_LIMIT_LOCKOUT = 2000;
|
|
||||||
/**
|
|
||||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
|
||||||
* to be used again. This is to prevent the queue from flooding a key with too
|
|
||||||
* many requests while we wait to learn whether previous ones succeeded.
|
|
||||||
*/
|
|
||||||
const KEY_REUSE_DELAY = 500;
|
|
||||||
|
|
||||||
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
protected readonly keys: AnthropicKey[] = [];
|
||||||
readonly service = "anthropic";
|
|
||||||
|
|
||||||
private keys: AnthropicKey[] = [];
|
|
||||||
private checker?: AnthropicKeyChecker;
|
private checker?: AnthropicKeyChecker;
|
||||||
private log = logger.child({ module: "key-provider", service: this.service });
|
protected log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
constructor() {
|
public async init() {
|
||||||
const keyConfig = config.anthropicKey?.trim();
|
const storeName = this.store.constructor.name;
|
||||||
if (!keyConfig) {
|
const loadedKeys = await this.store.load();
|
||||||
this.log.warn(
|
|
||||||
"ANTHROPIC_KEY is not set. Anthropic API will not be available."
|
if (loadedKeys.length === 0) {
|
||||||
|
return this.log.warn({ via: storeName }, "No Anthropic keys found.");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.keys.push(...loadedKeys);
|
||||||
|
this.log.info(
|
||||||
|
{ count: this.keys.length, via: storeName },
|
||||||
|
"Loaded Anthropic keys."
|
||||||
);
|
);
|
||||||
return;
|
|
||||||
}
|
|
||||||
let bareKeys: string[];
|
|
||||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
|
||||||
for (const key of bareKeys) {
|
|
||||||
const newKey: AnthropicKey = {
|
|
||||||
key,
|
|
||||||
service: this.service,
|
|
||||||
modelFamilies: ["claude"],
|
|
||||||
isDisabled: false,
|
|
||||||
isRevoked: false,
|
|
||||||
isPozzed: false,
|
|
||||||
promptCount: 0,
|
|
||||||
lastUsed: 0,
|
|
||||||
rateLimitedAt: 0,
|
|
||||||
rateLimitedUntil: 0,
|
|
||||||
requiresPreamble: false,
|
|
||||||
hash: `ant-${crypto
|
|
||||||
.createHash("sha256")
|
|
||||||
.update(key)
|
|
||||||
.digest("hex")
|
|
||||||
.slice(0, 8)}`,
|
|
||||||
lastChecked: 0,
|
|
||||||
claudeTokens: 0,
|
|
||||||
};
|
|
||||||
this.keys.push(newKey);
|
|
||||||
}
|
|
||||||
this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
|
|
||||||
}
|
|
||||||
|
|
||||||
public init() {
|
|
||||||
if (config.checkKeys) {
|
if (config.checkKeys) {
|
||||||
this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this));
|
this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this));
|
||||||
this.checker.start();
|
this.checker.start();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public list() {
|
|
||||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
|
||||||
}
|
|
||||||
|
|
||||||
public get(_model: AnthropicModel) {
|
public get(_model: AnthropicModel) {
|
||||||
// Currently, all Anthropic keys have access to all models. This will almost
|
// Currently, all Anthropic keys have access to all models. This will almost
|
||||||
// certainly change when they move out of beta later this year.
|
// certainly change when they move out of beta later this year.
|
||||||
@@ -161,22 +115,6 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
|||||||
return { ...selectedKey };
|
return { ...selectedKey };
|
||||||
}
|
}
|
||||||
|
|
||||||
public disable(key: AnthropicKey) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
|
||||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
|
||||||
keyFromPool.isDisabled = true;
|
|
||||||
this.log.warn({ key: key.hash }, "Key disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
public update(hash: string, update: Partial<AnthropicKey>) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
|
||||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
|
||||||
}
|
|
||||||
|
|
||||||
public available() {
|
|
||||||
return this.keys.filter((k) => !k.isDisabled).length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||||
const key = this.keys.find((k) => k.hash === hash);
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
if (!key) return;
|
if (!key) return;
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import type { AnthropicKey, SerializedKey } from "../index";
|
||||||
|
import { KeySerializerBase } from "../key-serializer-base";
|
||||||
|
|
||||||
|
const SERIALIZABLE_FIELDS: (keyof AnthropicKey)[] = [
|
||||||
|
"key",
|
||||||
|
"service",
|
||||||
|
"hash",
|
||||||
|
"promptCount",
|
||||||
|
"claudeTokens",
|
||||||
|
];
|
||||||
|
export type SerializedAnthropicKey = SerializedKey &
|
||||||
|
Partial<Pick<AnthropicKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
|
||||||
|
|
||||||
|
export class AnthropicKeySerializer extends KeySerializerBase<AnthropicKey> {
|
||||||
|
constructor() {
|
||||||
|
super(SERIALIZABLE_FIELDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey {
|
||||||
|
return {
|
||||||
|
key,
|
||||||
|
service: "anthropic" as const,
|
||||||
|
modelFamilies: ["claude" as const],
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
isPozzed: false,
|
||||||
|
promptCount: 0,
|
||||||
|
lastUsed: 0,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitedUntil: 0,
|
||||||
|
requiresPreamble: false,
|
||||||
|
hash: `ant-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
lastChecked: 0,
|
||||||
|
claudeTokens: 0,
|
||||||
|
...rest,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
import crypto from "crypto";
|
|
||||||
import { Key, KeyProvider } from "..";
|
|
||||||
import { config } from "../../../config";
|
import { config } from "../../../config";
|
||||||
import { logger } from "../../../logger";
|
import { logger } from "../../../logger";
|
||||||
import type { AwsBedrockModelFamily } from "../../models";
|
import type { AwsBedrockModelFamily } from "../../models";
|
||||||
|
import { KeyProviderBase } from "../key-provider-base";
|
||||||
|
import { Key } from "../types";
|
||||||
import { AwsKeyChecker } from "./checker";
|
import { AwsKeyChecker } from "./checker";
|
||||||
|
|
||||||
|
const RATE_LIMIT_LOCKOUT = 2000;
|
||||||
|
const KEY_REUSE_DELAY = 500;
|
||||||
|
|
||||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||||
export const AWS_BEDROCK_SUPPORTED_MODELS = [
|
export const AWS_BEDROCK_SUPPORTED_MODELS = [
|
||||||
"anthropic.claude-v1",
|
"anthropic.claude-v1",
|
||||||
@@ -33,71 +36,33 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
|||||||
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export class AwsBedrockKeyProvider extends KeyProviderBase<AwsBedrockKey> {
|
||||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
readonly service = "aws" as const;
|
||||||
* while we wait for other concurrent requests to finish.
|
|
||||||
*/
|
|
||||||
const RATE_LIMIT_LOCKOUT = 300;
|
|
||||||
/**
|
|
||||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
|
||||||
* to be used again. This is to prevent the queue from flooding a key with too
|
|
||||||
* many requests while we wait to learn whether previous ones succeeded.
|
|
||||||
*/
|
|
||||||
const KEY_REUSE_DELAY = 500;
|
|
||||||
|
|
||||||
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
protected readonly keys: AwsBedrockKey[] = [];
|
||||||
readonly service = "aws";
|
|
||||||
|
|
||||||
private keys: AwsBedrockKey[] = [];
|
|
||||||
private checker?: AwsKeyChecker;
|
private checker?: AwsKeyChecker;
|
||||||
private log = logger.child({ module: "key-provider", service: this.service });
|
protected log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
constructor() {
|
public async init() {
|
||||||
const keyConfig = config.awsCredentials?.trim();
|
const storeName = this.store.constructor.name;
|
||||||
if (!keyConfig) {
|
const loadedKeys = await this.store.load();
|
||||||
this.log.warn(
|
|
||||||
"AWS_CREDENTIALS is not set. AWS Bedrock API will not be available."
|
if (loadedKeys.length === 0) {
|
||||||
|
return this.log.warn({ via: storeName }, "No AWS credentials found.");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.keys.push(...loadedKeys);
|
||||||
|
this.log.info(
|
||||||
|
{ count: this.keys.length, via: storeName },
|
||||||
|
"Loaded AWS Bedrock keys."
|
||||||
);
|
);
|
||||||
return;
|
|
||||||
}
|
|
||||||
let bareKeys: string[];
|
|
||||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
|
||||||
for (const key of bareKeys) {
|
|
||||||
const newKey: AwsBedrockKey = {
|
|
||||||
key,
|
|
||||||
service: this.service,
|
|
||||||
modelFamilies: ["aws-claude"],
|
|
||||||
isDisabled: false,
|
|
||||||
isRevoked: false,
|
|
||||||
promptCount: 0,
|
|
||||||
lastUsed: 0,
|
|
||||||
rateLimitedAt: 0,
|
|
||||||
rateLimitedUntil: 0,
|
|
||||||
awsLoggingStatus: "unknown",
|
|
||||||
hash: `aws-${crypto
|
|
||||||
.createHash("sha256")
|
|
||||||
.update(key)
|
|
||||||
.digest("hex")
|
|
||||||
.slice(0, 8)}`,
|
|
||||||
lastChecked: 0,
|
|
||||||
["aws-claudeTokens"]: 0,
|
|
||||||
};
|
|
||||||
this.keys.push(newKey);
|
|
||||||
}
|
|
||||||
this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
|
|
||||||
}
|
|
||||||
|
|
||||||
public init() {
|
|
||||||
if (config.checkKeys) {
|
if (config.checkKeys) {
|
||||||
this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
|
this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
|
||||||
this.checker.start();
|
this.checker.start();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public list() {
|
|
||||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
|
||||||
}
|
|
||||||
|
|
||||||
public get(_model: AwsBedrockModel) {
|
public get(_model: AwsBedrockModel) {
|
||||||
const availableKeys = this.keys.filter((k) => {
|
const availableKeys = this.keys.filter((k) => {
|
||||||
const isNotLogged = k.awsLoggingStatus === "disabled";
|
const isNotLogged = k.awsLoggingStatus === "disabled";
|
||||||
@@ -139,22 +104,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
|||||||
return { ...selectedKey };
|
return { ...selectedKey };
|
||||||
}
|
}
|
||||||
|
|
||||||
public disable(key: AwsBedrockKey) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
|
||||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
|
||||||
keyFromPool.isDisabled = true;
|
|
||||||
this.log.warn({ key: key.hash }, "Key disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
public update(hash: string, update: Partial<AwsBedrockKey>) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
|
||||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
|
||||||
}
|
|
||||||
|
|
||||||
public available() {
|
|
||||||
return this.keys.filter((k) => !k.isDisabled).length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||||
const key = this.keys.find((k) => k.hash === hash);
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
if (!key) return;
|
if (!key) return;
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import type { AwsBedrockKey, SerializedKey } from "../index";
|
||||||
|
import { KeySerializerBase } from "../key-serializer-base";
|
||||||
|
|
||||||
|
const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
|
||||||
|
"key",
|
||||||
|
"service",
|
||||||
|
"hash",
|
||||||
|
"promptCount",
|
||||||
|
"aws-claudeTokens",
|
||||||
|
];
|
||||||
|
export type SerializedAwsBedrockKey = SerializedKey &
|
||||||
|
Partial<Pick<AwsBedrockKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
|
||||||
|
|
||||||
|
export class AwsBedrockKeySerializer extends KeySerializerBase<AwsBedrockKey> {
|
||||||
|
constructor() {
|
||||||
|
super(SERIALIZABLE_FIELDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey {
|
||||||
|
const { key, ...rest } = serializedKey;
|
||||||
|
return {
|
||||||
|
key,
|
||||||
|
service: "aws",
|
||||||
|
modelFamilies: ["aws-claude"],
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
promptCount: 0,
|
||||||
|
lastUsed: 0,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitedUntil: 0,
|
||||||
|
awsLoggingStatus: "unknown",
|
||||||
|
hash: `aws-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
lastChecked: 0,
|
||||||
|
["aws-claudeTokens"]: 0,
|
||||||
|
...rest,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,83 +1,10 @@
|
|||||||
import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider";
|
export { keyPool } from "./key-pool";
|
||||||
import {
|
export { OPENAI_SUPPORTED_MODELS } from "./openai/provider";
|
||||||
ANTHROPIC_SUPPORTED_MODELS,
|
export { ANTHROPIC_SUPPORTED_MODELS } from "./anthropic/provider";
|
||||||
AnthropicModel,
|
export { GOOGLE_PALM_SUPPORTED_MODELS } from "./palm/provider";
|
||||||
} from "./anthropic/provider";
|
export { AWS_BEDROCK_SUPPORTED_MODELS } from "./aws/provider";
|
||||||
import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider";
|
export type { AnthropicKey } from "./anthropic/provider";
|
||||||
import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider";
|
export type { OpenAIKey } from "./openai/provider";
|
||||||
import { KeyPool } from "./key-pool";
|
export type { GooglePalmKey } from "./palm/provider";
|
||||||
import type { ModelFamily } from "../models";
|
export type { AwsBedrockKey } from "./aws/provider";
|
||||||
|
export * from "./types";
|
||||||
/** The request and response format used by a model's API. */
|
|
||||||
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
|
|
||||||
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
|
|
||||||
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
|
|
||||||
export type Model =
|
|
||||||
| OpenAIModel
|
|
||||||
| AnthropicModel
|
|
||||||
| GooglePalmModel
|
|
||||||
| AwsBedrockModel;
|
|
||||||
|
|
||||||
export interface Key {
|
|
||||||
/** The API key itself. Never log this, use `hash` instead. */
|
|
||||||
readonly key: string;
|
|
||||||
/** The service that this key is for. */
|
|
||||||
service: LLMService;
|
|
||||||
/** The model families that this key has access to. */
|
|
||||||
modelFamilies: ModelFamily[];
|
|
||||||
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
|
|
||||||
isDisabled: boolean;
|
|
||||||
/** Whether this key specifically has been revoked. */
|
|
||||||
isRevoked: boolean;
|
|
||||||
/** The number of prompts that have been sent with this key. */
|
|
||||||
promptCount: number;
|
|
||||||
/** The time at which this key was last used. */
|
|
||||||
lastUsed: number;
|
|
||||||
/** The time at which this key was last checked. */
|
|
||||||
lastChecked: number;
|
|
||||||
/** Hash of the key, for logging and to find the key in the pool. */
|
|
||||||
hash: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
KeyPool and KeyProvider's similarities are a relic of the old design where
|
|
||||||
there was only a single KeyPool for OpenAI keys. Now that there are multiple
|
|
||||||
supported services, the service-specific functionality has been moved to
|
|
||||||
KeyProvider and KeyPool is just a wrapper around multiple KeyProviders,
|
|
||||||
delegating to the appropriate one based on the model requested.
|
|
||||||
|
|
||||||
Existing code will continue to call methods on KeyPool, which routes them to
|
|
||||||
the appropriate KeyProvider or returns data aggregated across all KeyProviders
|
|
||||||
for service-agnostic functionality.
|
|
||||||
*/
|
|
||||||
|
|
||||||
export interface KeyProvider<T extends Key = Key> {
|
|
||||||
readonly service: LLMService;
|
|
||||||
init(): void;
|
|
||||||
get(model: Model): T;
|
|
||||||
list(): Omit<T, "key">[];
|
|
||||||
disable(key: T): void;
|
|
||||||
update(hash: string, update: Partial<T>): void;
|
|
||||||
available(): number;
|
|
||||||
incrementUsage(hash: string, model: string, tokens: number): void;
|
|
||||||
getLockoutPeriod(model: Model): number;
|
|
||||||
markRateLimited(hash: string): void;
|
|
||||||
recheck(): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const keyPool = new KeyPool();
|
|
||||||
export const SUPPORTED_MODELS = [
|
|
||||||
...OPENAI_SUPPORTED_MODELS,
|
|
||||||
...ANTHROPIC_SUPPORTED_MODELS,
|
|
||||||
] as const;
|
|
||||||
export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
|
|
||||||
export {
|
|
||||||
OPENAI_SUPPORTED_MODELS,
|
|
||||||
ANTHROPIC_SUPPORTED_MODELS,
|
|
||||||
GOOGLE_PALM_SUPPORTED_MODELS,
|
|
||||||
AWS_BEDROCK_SUPPORTED_MODELS,
|
|
||||||
};
|
|
||||||
export { AnthropicKey } from "./anthropic/provider";
|
|
||||||
export { OpenAIKey } from "./openai/provider";
|
|
||||||
export { GooglePalmKey } from "./palm/provider";
|
|
||||||
export { AwsBedrockKey } from "./aws/provider";
|
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
|
import { AxiosError } from "axios";
|
||||||
import pino from "pino";
|
import pino from "pino";
|
||||||
import { logger } from "../../logger";
|
import { logger } from "../../logger";
|
||||||
import { Key } from "./index";
|
import { Key } from "./types";
|
||||||
import { AxiosError } from "axios";
|
|
||||||
|
|
||||||
type KeyCheckerOptions = {
|
type KeyCheckerOptions = {
|
||||||
service: string;
|
service: string;
|
||||||
keyCheckPeriod: number;
|
keyCheckPeriod: number;
|
||||||
minCheckInterval: number;
|
minCheckInterval: number;
|
||||||
}
|
};
|
||||||
|
|
||||||
export abstract class KeyCheckerBase<TKey extends Key> {
|
export abstract class KeyCheckerBase<TKey extends Key> {
|
||||||
protected readonly service: string;
|
protected readonly service: string;
|
||||||
|
|||||||
@@ -4,34 +4,36 @@ import os from "os";
|
|||||||
import schedule from "node-schedule";
|
import schedule from "node-schedule";
|
||||||
import { config } from "../../config";
|
import { config } from "../../config";
|
||||||
import { logger } from "../../logger";
|
import { logger } from "../../logger";
|
||||||
import { Key, Model, KeyProvider, LLMService } from "./index";
|
import { KeyProviderBase } from "./key-provider-base";
|
||||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
import { getSerializer } from "./serializers";
|
||||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
|
||||||
|
import { AnthropicKeyProvider } from "./anthropic/provider";
|
||||||
|
import { OpenAIKeyProvider } from "./openai/provider";
|
||||||
import { GooglePalmKeyProvider } from "./palm/provider";
|
import { GooglePalmKeyProvider } from "./palm/provider";
|
||||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||||
|
import { Key, KeyStore, LLMService, Model, ServiceToKey } from "./types";
|
||||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
|
||||||
|
|
||||||
export class KeyPool {
|
export class KeyPool {
|
||||||
private keyProviders: KeyProvider[] = [];
|
private keyProviders: KeyProviderBase[] = [];
|
||||||
private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
|
private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
|
||||||
openai: null,
|
openai: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.keyProviders.push(new OpenAIKeyProvider());
|
this.keyProviders.push(
|
||||||
this.keyProviders.push(new AnthropicKeyProvider());
|
new OpenAIKeyProvider(createKeyStore("openai")),
|
||||||
this.keyProviders.push(new GooglePalmKeyProvider());
|
new AnthropicKeyProvider(createKeyStore("anthropic")),
|
||||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
new GooglePalmKeyProvider(createKeyStore("google-palm")),
|
||||||
|
new AwsBedrockKeyProvider(createKeyStore("aws"))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public init() {
|
public async init() {
|
||||||
this.keyProviders.forEach((provider) => provider.init());
|
await Promise.all(this.keyProviders.map((p) => p.init()));
|
||||||
|
|
||||||
const availableKeys = this.available("all");
|
const availableKeys = this.available("all");
|
||||||
if (availableKeys === 0) {
|
if (availableKeys === 0) {
|
||||||
throw new Error(
|
throw new Error("No keys loaded, the application cannot start.");
|
||||||
"No keys loaded. Ensure that at least one key is configured."
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
this.scheduleRecheck();
|
this.scheduleRecheck();
|
||||||
}
|
}
|
||||||
@@ -59,7 +61,7 @@ export class KeyPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public update(key: Key, props: AllowedPartial): void {
|
public update<T extends Key>(key: T, props: Partial<T>): void {
|
||||||
const service = this.getKeyProvider(key.service);
|
const service = this.getKeyProvider(key.service);
|
||||||
service.update(key.hash, props);
|
service.update(key.hash, props);
|
||||||
}
|
}
|
||||||
@@ -122,7 +124,7 @@ export class KeyPool {
|
|||||||
throw new Error(`Unknown service for model '${model}'`);
|
throw new Error(`Unknown service for model '${model}'`);
|
||||||
}
|
}
|
||||||
|
|
||||||
private getKeyProvider(service: LLMService): KeyProvider {
|
private getKeyProvider(service: LLMService): KeyProviderBase {
|
||||||
return this.keyProviders.find((provider) => provider.service === service)!;
|
return this.keyProviders.find((provider) => provider.service === service)!;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,3 +153,25 @@ export class KeyPool {
|
|||||||
this.recheckJobs.openai = job;
|
this.recheckJobs.openai = job;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function createKeyStore<S extends LLMService>(
|
||||||
|
service: S
|
||||||
|
): KeyStore<ServiceToKey[S]> {
|
||||||
|
const serializer = getSerializer(service);
|
||||||
|
|
||||||
|
switch (config.persistenceProvider) {
|
||||||
|
case "memory":
|
||||||
|
return new MemoryKeyStore(service, serializer);
|
||||||
|
case "firebase_rtdb":
|
||||||
|
return new FirebaseKeyStore(service, serializer);
|
||||||
|
default:
|
||||||
|
throw new Error(`Unknown store type: ${config.persistenceProvider}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export let keyPool: KeyPool;
|
||||||
|
|
||||||
|
export async function init() {
|
||||||
|
keyPool = new KeyPool();
|
||||||
|
await keyPool.init();
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
import { logger } from "../../logger";
|
||||||
|
import { Key, KeyStore, LLMService, Model } from "./types";
|
||||||
|
|
||||||
|
export abstract class KeyProviderBase<K extends Key = Key> {
|
||||||
|
public abstract readonly service: LLMService;
|
||||||
|
|
||||||
|
protected abstract readonly keys: K[];
|
||||||
|
protected abstract log: typeof logger;
|
||||||
|
protected readonly store: KeyStore<K>;
|
||||||
|
|
||||||
|
public constructor(keyStore: KeyStore<K>) {
|
||||||
|
this.store = keyStore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract init(): Promise<void>;
|
||||||
|
|
||||||
|
public addKey(key: K): void {
|
||||||
|
this.keys.push(key);
|
||||||
|
this.store.add(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract get(model: Model): K;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a list of all keys, with the actual key value removed. Don't
|
||||||
|
* mutate the returned objects; use `update` instead to ensure the changes
|
||||||
|
* are synced to the key store.
|
||||||
|
*/
|
||||||
|
public list(): Omit<K, "key">[] {
|
||||||
|
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public disable(key: K): void {
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||||
|
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||||
|
this.update(key.hash, { isDisabled: true } as Partial<K>, true);
|
||||||
|
this.log.warn({ key: key.hash }, "Key disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
public update(hash: string, update: Partial<K>, force = false): void {
|
||||||
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
|
if (!key) {
|
||||||
|
throw new Error(`No key with hash ${hash}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.assign(key, { lastChecked: Date.now(), ...update });
|
||||||
|
this.store.update(hash, update, force);
|
||||||
|
}
|
||||||
|
|
||||||
|
public available(): number {
|
||||||
|
return this.keys.filter((k) => !k.isDisabled).length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract incrementUsage(
|
||||||
|
hash: string,
|
||||||
|
model: string,
|
||||||
|
tokens: number
|
||||||
|
): void;
|
||||||
|
|
||||||
|
public abstract getLockoutPeriod(model: Model): number;
|
||||||
|
|
||||||
|
public abstract markRateLimited(hash: string): void;
|
||||||
|
|
||||||
|
public abstract recheck(): void;
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import { Key, KeySerializer, SerializedKey } from "./types";
|
||||||
|
|
||||||
|
export abstract class KeySerializerBase<K extends Key>
|
||||||
|
implements KeySerializer<K>
|
||||||
|
{
|
||||||
|
protected constructor(protected serializableFields: (keyof K)[]) {}
|
||||||
|
|
||||||
|
serialize(keyObj: K): SerializedKey {
|
||||||
|
return {
|
||||||
|
...Object.fromEntries(
|
||||||
|
this.serializableFields
|
||||||
|
.map((f) => [f, keyObj[f]])
|
||||||
|
.filter(([, v]) => v !== undefined)
|
||||||
|
),
|
||||||
|
key: keyObj.key,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey> {
|
||||||
|
return {
|
||||||
|
...Object.fromEntries(
|
||||||
|
this.serializableFields
|
||||||
|
.map((f) => [f, update[f]])
|
||||||
|
.filter(([, v]) => v !== undefined)
|
||||||
|
),
|
||||||
|
key,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract deserialize(serializedKey: SerializedKey): K;
|
||||||
|
}
|
||||||
@@ -1,28 +1,23 @@
|
|||||||
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
|
|
||||||
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
|
|
||||||
variable as a comma-separated list of keys. */
|
|
||||||
import crypto from "crypto";
|
import crypto from "crypto";
|
||||||
import http from "http";
|
import { IncomingHttpHeaders } from "http";
|
||||||
import { Key, KeyProvider, Model } from "../index";
|
|
||||||
import { config } from "../../../config";
|
import { config } from "../../../config";
|
||||||
import { logger } from "../../../logger";
|
import { logger } from "../../../logger";
|
||||||
import { OpenAIKeyChecker } from "./checker";
|
|
||||||
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
|
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
|
||||||
|
import { Key, Model } from "../types";
|
||||||
|
import { OpenAIKeyChecker } from "./checker";
|
||||||
|
import { KeyProviderBase } from "../key-provider-base";
|
||||||
|
|
||||||
export type OpenAIModel =
|
const KEY_REUSE_DELAY = 1000;
|
||||||
| "gpt-3.5-turbo"
|
|
||||||
| "gpt-3.5-turbo-instruct"
|
export const OPENAI_SUPPORTED_MODELS = [
|
||||||
| "gpt-4"
|
|
||||||
| "gpt-4-32k"
|
|
||||||
| "text-embedding-ada-002";
|
|
||||||
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
|
|
||||||
"gpt-3.5-turbo",
|
"gpt-3.5-turbo",
|
||||||
"gpt-3.5-turbo-instruct",
|
"gpt-3.5-turbo-instruct",
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"text-embedding-ada-002",
|
||||||
] as const;
|
] as const;
|
||||||
|
export type OpenAIModel = (typeof OPENAI_SUPPORTED_MODELS)[number];
|
||||||
|
|
||||||
// Flattening model families instead of using a nested object for easier
|
|
||||||
// cloning.
|
|
||||||
type OpenAIKeyUsage = {
|
type OpenAIKeyUsage = {
|
||||||
[K in OpenAIModelFamily as `${K}Tokens`]: number;
|
[K in OpenAIModelFamily as `${K}Tokens`]: number;
|
||||||
};
|
};
|
||||||
@@ -66,64 +61,30 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
|
|||||||
rateLimitTokensReset: number;
|
rateLimitTokensReset: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type OpenAIKeyUpdate = Omit<
|
export class OpenAIKeyProvider extends KeyProviderBase<OpenAIKey> {
|
||||||
Partial<OpenAIKey>,
|
|
||||||
"key" | "hash" | "promptCount"
|
|
||||||
>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
|
||||||
* to be used again. This is to prevent the queue from flooding a key with too
|
|
||||||
* many requests while we wait to learn whether previous ones succeeded.
|
|
||||||
*/
|
|
||||||
const KEY_REUSE_DELAY = 1000;
|
|
||||||
|
|
||||||
export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|
||||||
readonly service = "openai" as const;
|
readonly service = "openai" as const;
|
||||||
|
|
||||||
private keys: OpenAIKey[] = [];
|
protected readonly keys: OpenAIKey[] = [];
|
||||||
private checker?: OpenAIKeyChecker;
|
private checker?: OpenAIKeyChecker;
|
||||||
private log = logger.child({ module: "key-provider", service: this.service });
|
protected log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
constructor() {
|
public async init() {
|
||||||
const keyString = config.openaiKey?.trim();
|
const storeName = this.store.constructor.name;
|
||||||
if (!keyString) {
|
const loadedKeys = await this.store.load();
|
||||||
this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
|
|
||||||
return;
|
// TODO: after key management UI, keychecker should always be enabled
|
||||||
}
|
// because keys may be added after initialization.
|
||||||
let bareKeys: string[];
|
|
||||||
bareKeys = keyString.split(",").map((k) => k.trim());
|
if (loadedKeys.length === 0) {
|
||||||
bareKeys = [...new Set(bareKeys)];
|
return this.log.warn({ via: storeName }, "No OpenAI keys found.");
|
||||||
for (const k of bareKeys) {
|
|
||||||
const newKey: OpenAIKey = {
|
|
||||||
key: k,
|
|
||||||
service: "openai" as const,
|
|
||||||
modelFamilies: ["turbo" as const, "gpt4" as const],
|
|
||||||
isTrial: false,
|
|
||||||
isDisabled: false,
|
|
||||||
isRevoked: false,
|
|
||||||
isOverQuota: false,
|
|
||||||
lastUsed: 0,
|
|
||||||
lastChecked: 0,
|
|
||||||
promptCount: 0,
|
|
||||||
hash: `oai-${crypto
|
|
||||||
.createHash("sha256")
|
|
||||||
.update(k)
|
|
||||||
.digest("hex")
|
|
||||||
.slice(0, 8)}`,
|
|
||||||
rateLimitedAt: 0,
|
|
||||||
rateLimitRequestsReset: 0,
|
|
||||||
rateLimitTokensReset: 0,
|
|
||||||
turboTokens: 0,
|
|
||||||
gpt4Tokens: 0,
|
|
||||||
"gpt4-32kTokens": 0,
|
|
||||||
};
|
|
||||||
this.keys.push(newKey);
|
|
||||||
}
|
|
||||||
this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public init() {
|
this.keys.push(...loadedKeys);
|
||||||
|
this.log.info(
|
||||||
|
{ count: this.keys.length, via: storeName },
|
||||||
|
"Loaded OpenAI keys."
|
||||||
|
);
|
||||||
|
|
||||||
if (config.checkKeys) {
|
if (config.checkKeys) {
|
||||||
const cloneFn = this.clone.bind(this);
|
const cloneFn = this.clone.bind(this);
|
||||||
const updateFn = this.update.bind(this);
|
const updateFn = this.update.bind(this);
|
||||||
@@ -132,29 +93,16 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a list of all keys, with the key field removed.
|
|
||||||
* Don't mutate returned keys, use a KeyPool method instead.
|
|
||||||
**/
|
|
||||||
public list() {
|
|
||||||
return this.keys.map((key) => {
|
|
||||||
return Object.freeze({
|
|
||||||
...key,
|
|
||||||
key: undefined,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
public get(model: Model) {
|
public get(model: Model) {
|
||||||
const neededFamily = getOpenAIModelFamily(model);
|
const neededFamily = getOpenAIModelFamily(model);
|
||||||
const excludeTrials = model === "text-embedding-ada-002";
|
const excludeTrials = model === "text-embedding-ada-002";
|
||||||
|
|
||||||
const availableKeys = this.keys.filter(
|
const availableKeys = this.keys.filter(
|
||||||
// Allow keys which
|
// Allow keys which...
|
||||||
(key) =>
|
(key) =>
|
||||||
!key.isDisabled && // are not disabled
|
!key.isDisabled && // ...are not disabled
|
||||||
key.modelFamilies.includes(neededFamily) && // have access to the model
|
key.modelFamilies.includes(neededFamily) && // ...have access to the model
|
||||||
(!excludeTrials || !key.isTrial) // and are not trials (if applicable)
|
(!excludeTrials || !key.isTrial) // ...and are not trials (if applicable)
|
||||||
);
|
);
|
||||||
|
|
||||||
if (availableKeys.length === 0) {
|
if (availableKeys.length === 0) {
|
||||||
@@ -233,13 +181,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||||||
return { ...selectedKey };
|
return { ...selectedKey };
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Called by the key checker to update key information. */
|
|
||||||
public update(keyHash: string, update: OpenAIKeyUpdate) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
|
|
||||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
|
||||||
// this.writeKeyStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Called by the key checker to create clones of keys for the given orgs. */
|
/** Called by the key checker to create clones of keys for the given orgs. */
|
||||||
public clone(keyHash: string, newOrgIds: string[]) {
|
public clone(keyHash: string, newOrgIds: string[]) {
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
|
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
|
||||||
@@ -261,19 +202,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||||||
);
|
);
|
||||||
return clone;
|
return clone;
|
||||||
});
|
});
|
||||||
this.keys.push(...clones);
|
clones.forEach((clone) => this.addKey(clone));
|
||||||
}
|
|
||||||
|
|
||||||
/** Disables a key, or does nothing if the key isn't in this pool. */
|
|
||||||
public disable(key: Key) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
|
||||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
|
||||||
this.update(key.hash, { isDisabled: true });
|
|
||||||
this.log.warn({ key: key.hash }, "Key disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
public available() {
|
|
||||||
return this.keys.filter((k) => !k.isDisabled).length;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -338,7 +267,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||||||
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
|
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
|
public updateRateLimits(keyHash: string, headers: IncomingHttpHeaders) {
|
||||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||||
const requestsReset = headers["x-ratelimit-reset-requests"];
|
const requestsReset = headers["x-ratelimit-reset-requests"];
|
||||||
const tokensReset = headers["x-ratelimit-reset-tokens"];
|
const tokensReset = headers["x-ratelimit-reset-tokens"];
|
||||||
@@ -382,21 +311,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||||||
});
|
});
|
||||||
this.checker?.scheduleNextCheck();
|
this.checker?.scheduleNextCheck();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Writes key status to disk. */
|
|
||||||
// public writeKeyStatus() {
|
|
||||||
// const keys = this.keys.map((key) => ({
|
|
||||||
// key: key.key,
|
|
||||||
// isGpt4: key.isGpt4,
|
|
||||||
// usage: key.usage,
|
|
||||||
// hardLimit: key.hardLimit,
|
|
||||||
// isDisabled: key.isDisabled,
|
|
||||||
// }));
|
|
||||||
// fs.writeFileSync(
|
|
||||||
// path.join(__dirname, "..", "keys.json"),
|
|
||||||
// JSON.stringify(keys, null, 2)
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import type { OpenAIKey, SerializedKey } from "../index";
|
||||||
|
import { KeySerializerBase } from "../key-serializer-base";
|
||||||
|
|
||||||
|
const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
|
||||||
|
"key",
|
||||||
|
"service",
|
||||||
|
"hash",
|
||||||
|
"organizationId",
|
||||||
|
"promptCount",
|
||||||
|
"gpt4Tokens",
|
||||||
|
"gpt4-32kTokens",
|
||||||
|
"turboTokens",
|
||||||
|
];
|
||||||
|
export type SerializedOpenAIKey = SerializedKey &
|
||||||
|
Partial<Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
|
||||||
|
|
||||||
|
export class OpenAIKeySerializer extends KeySerializerBase<OpenAIKey> {
|
||||||
|
constructor() {
|
||||||
|
super(SERIALIZABLE_FIELDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialize({ key, ...rest }: SerializedOpenAIKey): OpenAIKey {
|
||||||
|
return {
|
||||||
|
key,
|
||||||
|
service: "openai",
|
||||||
|
modelFamilies: ["turbo" as const, "gpt4" as const],
|
||||||
|
isTrial: false,
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
isOverQuota: false,
|
||||||
|
lastUsed: 0,
|
||||||
|
lastChecked: 0,
|
||||||
|
promptCount: 0,
|
||||||
|
hash: `oai-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitRequestsReset: 0,
|
||||||
|
rateLimitTokensReset: 0,
|
||||||
|
turboTokens: 0,
|
||||||
|
gpt4Tokens: 0,
|
||||||
|
"gpt4-32kTokens": 0,
|
||||||
|
...rest,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,26 +1,15 @@
|
|||||||
import crypto from "crypto";
|
|
||||||
import { Key, KeyProvider } from "..";
|
|
||||||
import { config } from "../../../config";
|
|
||||||
import { logger } from "../../../logger";
|
import { logger } from "../../../logger";
|
||||||
import type { GooglePalmModelFamily } from "../../models";
|
import type { GooglePalmModelFamily } from "../../models";
|
||||||
|
import { KeyProviderBase } from "../key-provider-base";
|
||||||
|
import { Key } from "../types";
|
||||||
|
|
||||||
|
const RATE_LIMIT_LOCKOUT = 2000;
|
||||||
|
const KEY_REUSE_DELAY = 500;
|
||||||
|
|
||||||
// https://developers.generativeai.google.com/models/language
|
// https://developers.generativeai.google.com/models/language
|
||||||
export const GOOGLE_PALM_SUPPORTED_MODELS = [
|
export const GOOGLE_PALM_SUPPORTED_MODELS = ["text-bison-001"] as const;
|
||||||
"text-bison-001",
|
|
||||||
// "chat-bison-001", no adjustable safety settings, so it's useless
|
|
||||||
] as const;
|
|
||||||
export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];
|
export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];
|
||||||
|
|
||||||
export type GooglePalmKeyUpdate = Omit<
|
|
||||||
Partial<GooglePalmKey>,
|
|
||||||
| "key"
|
|
||||||
| "hash"
|
|
||||||
| "lastUsed"
|
|
||||||
| "promptCount"
|
|
||||||
| "rateLimitedAt"
|
|
||||||
| "rateLimitedUntil"
|
|
||||||
>;
|
|
||||||
|
|
||||||
type GooglePalmKeyUsage = {
|
type GooglePalmKeyUsage = {
|
||||||
[K in GooglePalmModelFamily as `${K}Tokens`]: number;
|
[K in GooglePalmModelFamily as `${K}Tokens`]: number;
|
||||||
};
|
};
|
||||||
@@ -34,62 +23,25 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
|
|||||||
rateLimitedUntil: number;
|
rateLimitedUntil: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
export class GooglePalmKeyProvider extends KeyProviderBase<GooglePalmKey> {
|
||||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
|
||||||
* while we wait for other concurrent requests to finish.
|
|
||||||
*/
|
|
||||||
const RATE_LIMIT_LOCKOUT = 2000;
|
|
||||||
/**
|
|
||||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
|
||||||
* to be used again. This is to prevent the queue from flooding a key with too
|
|
||||||
* many requests while we wait to learn whether previous ones succeeded.
|
|
||||||
*/
|
|
||||||
const KEY_REUSE_DELAY = 500;
|
|
||||||
|
|
||||||
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
|
||||||
readonly service = "google-palm";
|
readonly service = "google-palm";
|
||||||
|
|
||||||
private keys: GooglePalmKey[] = [];
|
protected keys: GooglePalmKey[] = [];
|
||||||
private log = logger.child({ module: "key-provider", service: this.service });
|
protected log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
constructor() {
|
public async init() {
|
||||||
const keyConfig = config.googlePalmKey?.trim();
|
const storeName = this.store.constructor.name;
|
||||||
if (!keyConfig) {
|
const loadedKeys = await this.store.load();
|
||||||
this.log.warn(
|
|
||||||
"GOOGLE_PALM_KEY is not set. PaLM API will not be available."
|
if (loadedKeys.length === 0) {
|
||||||
|
return this.log.warn({ via: storeName }, "No Google PaLM keys found.");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.keys.push(...loadedKeys);
|
||||||
|
this.log.info(
|
||||||
|
{ count: this.keys.length, via: storeName },
|
||||||
|
"Loaded PaLM keys."
|
||||||
);
|
);
|
||||||
return;
|
|
||||||
}
|
|
||||||
let bareKeys: string[];
|
|
||||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
|
||||||
for (const key of bareKeys) {
|
|
||||||
const newKey: GooglePalmKey = {
|
|
||||||
key,
|
|
||||||
service: this.service,
|
|
||||||
modelFamilies: ["bison"],
|
|
||||||
isDisabled: false,
|
|
||||||
isRevoked: false,
|
|
||||||
promptCount: 0,
|
|
||||||
lastUsed: 0,
|
|
||||||
rateLimitedAt: 0,
|
|
||||||
rateLimitedUntil: 0,
|
|
||||||
hash: `plm-${crypto
|
|
||||||
.createHash("sha256")
|
|
||||||
.update(key)
|
|
||||||
.digest("hex")
|
|
||||||
.slice(0, 8)}`,
|
|
||||||
lastChecked: 0,
|
|
||||||
bisonTokens: 0,
|
|
||||||
};
|
|
||||||
this.keys.push(newKey);
|
|
||||||
}
|
|
||||||
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
|
|
||||||
}
|
|
||||||
|
|
||||||
public init() {}
|
|
||||||
|
|
||||||
public list() {
|
|
||||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public get(_model: GooglePalmModel) {
|
public get(_model: GooglePalmModel) {
|
||||||
@@ -130,22 +82,6 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
|||||||
return { ...selectedKey };
|
return { ...selectedKey };
|
||||||
}
|
}
|
||||||
|
|
||||||
public disable(key: GooglePalmKey) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
|
||||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
|
||||||
keyFromPool.isDisabled = true;
|
|
||||||
this.log.warn({ key: key.hash }, "Key disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
public update(hash: string, update: Partial<GooglePalmKey>) {
|
|
||||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
|
||||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
|
||||||
}
|
|
||||||
|
|
||||||
public available() {
|
|
||||||
return this.keys.filter((k) => !k.isDisabled).length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||||
const key = this.keys.find((k) => k.hash === hash);
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
if (!key) return;
|
if (!key) return;
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import type { GooglePalmKey, SerializedKey } from "../index";
|
||||||
|
import { KeySerializerBase } from "../key-serializer-base";
|
||||||
|
|
||||||
|
const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
|
||||||
|
"key",
|
||||||
|
"service",
|
||||||
|
"hash",
|
||||||
|
"promptCount",
|
||||||
|
"bisonTokens",
|
||||||
|
];
|
||||||
|
export type SerializedGooglePalmKey = SerializedKey &
|
||||||
|
Partial<Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
|
||||||
|
|
||||||
|
export class GooglePalmKeySerializer extends KeySerializerBase<GooglePalmKey> {
|
||||||
|
constructor() {
|
||||||
|
super(SERIALIZABLE_FIELDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialize(serializedKey: SerializedGooglePalmKey): GooglePalmKey {
|
||||||
|
const { key, ...rest } = serializedKey;
|
||||||
|
return {
|
||||||
|
key,
|
||||||
|
service: "google-palm" as const,
|
||||||
|
modelFamilies: ["bison"],
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
promptCount: 0,
|
||||||
|
lastUsed: 0,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitedUntil: 0,
|
||||||
|
hash: `plm-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
lastChecked: 0,
|
||||||
|
bisonTokens: 0,
|
||||||
|
...rest,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import { assertNever } from "../utils";
|
||||||
|
import {
|
||||||
|
Key,
|
||||||
|
KeySerializer,
|
||||||
|
LLMService,
|
||||||
|
SerializedKey,
|
||||||
|
ServiceToKey,
|
||||||
|
} from "./types";
|
||||||
|
import { OpenAIKeySerializer } from "./openai/serializer";
|
||||||
|
import { AnthropicKeySerializer } from "./anthropic/serializer";
|
||||||
|
import { GooglePalmKeySerializer } from "./palm/serializer";
|
||||||
|
import { AwsBedrockKeySerializer } from "./aws/serializer";
|
||||||
|
|
||||||
|
export function assertSerializedKey(k: any): asserts k is SerializedKey {
|
||||||
|
if (typeof k !== "object" || !k || typeof (k as any).key !== "string") {
|
||||||
|
throw new Error("Invalid serialized key data");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSerializer<S extends LLMService>(
|
||||||
|
service: S
|
||||||
|
): KeySerializer<ServiceToKey[S]>;
|
||||||
|
export function getSerializer(service: LLMService): KeySerializer<Key> {
|
||||||
|
switch (service) {
|
||||||
|
case "openai":
|
||||||
|
return new OpenAIKeySerializer();
|
||||||
|
case "anthropic":
|
||||||
|
return new AnthropicKeySerializer();
|
||||||
|
case "google-palm":
|
||||||
|
return new GooglePalmKeySerializer();
|
||||||
|
case "aws":
|
||||||
|
return new AwsBedrockKeySerializer();
|
||||||
|
default:
|
||||||
|
assertNever(service);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
import firebase from "firebase-admin";
|
||||||
|
import { config, getFirebaseApp } from "../../../config";
|
||||||
|
import { logger } from "../../../logger";
|
||||||
|
import { assertSerializedKey } from "../serializers";
|
||||||
|
import type {
|
||||||
|
Key,
|
||||||
|
KeySerializer,
|
||||||
|
KeyStore,
|
||||||
|
LLMService,
|
||||||
|
SerializedKey,
|
||||||
|
} from "../types";
|
||||||
|
import { MemoryKeyStore } from "./index";
|
||||||
|
|
||||||
|
export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
|
||||||
|
private readonly db: firebase.database.Database;
|
||||||
|
private readonly log: typeof logger;
|
||||||
|
private readonly pendingUpdates: Map<string, Partial<SerializedKey>>;
|
||||||
|
private readonly root: string;
|
||||||
|
private readonly serializer: KeySerializer<K>;
|
||||||
|
private readonly service: LLMService;
|
||||||
|
private flushInterval: NodeJS.Timeout | null = null;
|
||||||
|
private keysRef: firebase.database.Reference | null = null;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
service: LLMService,
|
||||||
|
serializer: KeySerializer<K>,
|
||||||
|
app = getFirebaseApp()
|
||||||
|
) {
|
||||||
|
this.db = firebase.database(app);
|
||||||
|
this.log = logger.child({ module: "firebase-key-store", service });
|
||||||
|
this.root = `keys/${config.firebaseRtdbRoot.toLowerCase()}/${service}`;
|
||||||
|
this.serializer = serializer;
|
||||||
|
this.service = service;
|
||||||
|
this.pendingUpdates = new Map();
|
||||||
|
this.scheduleFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async load(isMigrating = false): Promise<K[]> {
|
||||||
|
const keysRef = this.db.ref(this.root);
|
||||||
|
const snapshot = await keysRef.once("value");
|
||||||
|
const keys = snapshot.val();
|
||||||
|
this.keysRef = keysRef;
|
||||||
|
|
||||||
|
if (!keys) {
|
||||||
|
if (isMigrating) return [];
|
||||||
|
this.log.warn("No keys found in Firebase. Migrating from environment.");
|
||||||
|
await this.migrate();
|
||||||
|
return this.load(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Object.values(keys).map((k) => {
|
||||||
|
assertSerializedKey(k);
|
||||||
|
return this.serializer.deserialize(k);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public add(key: K) {
|
||||||
|
const serialized = this.serializer.serialize(key);
|
||||||
|
this.pendingUpdates.set(key.hash, serialized);
|
||||||
|
this.forceFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
public update(id: string, update: Partial<K>, force = false) {
|
||||||
|
const existing = this.pendingUpdates.get(id) ?? {};
|
||||||
|
Object.assign(existing, this.serializer.partialSerialize(id, update));
|
||||||
|
this.pendingUpdates.set(id, existing);
|
||||||
|
if (force) this.forceFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
private forceFlush() {
|
||||||
|
if (this.flushInterval) clearInterval(this.flushInterval);
|
||||||
|
this.flushInterval = setTimeout(() => this.flush(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private scheduleFlush() {
|
||||||
|
if (this.flushInterval) clearInterval(this.flushInterval);
|
||||||
|
this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async flush() {
|
||||||
|
if (!this.keysRef) {
|
||||||
|
this.log.warn(
|
||||||
|
{ pendingUpdates: this.pendingUpdates.size },
|
||||||
|
"Database not loaded yet. Skipping flush."
|
||||||
|
);
|
||||||
|
return this.scheduleFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.pendingUpdates.size === 0) {
|
||||||
|
this.log.debug("No pending key updates to flush.");
|
||||||
|
return this.scheduleFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
const updates: Record<string, Partial<SerializedKey>> = {};
|
||||||
|
this.pendingUpdates.forEach((v, k) => (updates[k] = v));
|
||||||
|
this.pendingUpdates.clear();
|
||||||
|
console.log(updates);
|
||||||
|
|
||||||
|
await this.keysRef.update(updates);
|
||||||
|
|
||||||
|
this.log.debug(
|
||||||
|
{ count: Object.keys(updates).length },
|
||||||
|
"Flushed pending key updates."
|
||||||
|
);
|
||||||
|
this.scheduleFlush();
|
||||||
|
}
|
||||||
|
|
||||||
|
private async migrate(): Promise<SerializedKey[]> {
|
||||||
|
const keysRef = this.db.ref(this.root);
|
||||||
|
const envStore = new MemoryKeyStore<K>(this.service, this.serializer);
|
||||||
|
const keys = await envStore.load();
|
||||||
|
|
||||||
|
if (keys.length === 0) {
|
||||||
|
this.log.warn("No keys found in environment or Firebase.");
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const updates: Record<string, SerializedKey> = {};
|
||||||
|
keys.forEach((k) => (updates[k.hash] = this.serializer.serialize(k)));
|
||||||
|
await keysRef.update(updates);
|
||||||
|
|
||||||
|
this.log.info({ count: keys.length }, "Migrated keys from environment.");
|
||||||
|
return Object.values(updates);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
// Available KeyStore implementations; both satisfy the KeyStore contract
// declared in ../types.
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import { assertNever } from "../../utils";
|
||||||
|
import { Key, KeySerializer, KeyStore, LLMService } from "../types";
|
||||||
|
|
||||||
|
export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
|
||||||
|
private readonly env: string;
|
||||||
|
private readonly serializer: KeySerializer<K>;
|
||||||
|
|
||||||
|
constructor(service: LLMService, serializer: KeySerializer<K>) {
|
||||||
|
switch (service) {
|
||||||
|
case "anthropic":
|
||||||
|
this.env = "ANTHROPIC_KEY";
|
||||||
|
break;
|
||||||
|
case "openai":
|
||||||
|
this.env = "OPENAI_KEY";
|
||||||
|
break;
|
||||||
|
case "google-palm":
|
||||||
|
this.env = "GOOGLE_PALM_KEY";
|
||||||
|
break;
|
||||||
|
case "aws":
|
||||||
|
this.env = "AWS_CREDENTIALS";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assertNever(service);
|
||||||
|
}
|
||||||
|
this.serializer = serializer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async load() {
|
||||||
|
let envKeys: string[];
|
||||||
|
envKeys = [
|
||||||
|
...new Set(process.env[this.env]?.split(",").map((k) => k.trim())),
|
||||||
|
];
|
||||||
|
return envKeys
|
||||||
|
.filter((k) => k)
|
||||||
|
.map((k) => this.serializer.deserialize({ key: k }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public add() {}
|
||||||
|
|
||||||
|
public update() {}
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
import type { OpenAIKey, OpenAIModel } from "./openai/provider";
|
||||||
|
import type { AnthropicKey, AnthropicModel } from "./anthropic/provider";
|
||||||
|
import type { GooglePalmKey, GooglePalmModel } from "./palm/provider";
|
||||||
|
import type { AwsBedrockKey, AwsBedrockModel } from "./aws/provider";
|
||||||
|
import type { ModelFamily } from "../models";
|
||||||
|
|
||||||
|
/**
 * The request and response format used by a model's API. NOTE(review):
 * `openai-text` presumably denotes OpenAI's legacy text-completions format
 * as opposed to the chat format -- confirm against the request translators.
 */
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
/**
 * The service that a model is hosted on; distinct because services like AWS
 * provide APIs from other service providers, but have their own authentication
 * and key management.
 */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
|
||||||
|
|
||||||
|
/** Union of all model identifiers across every supported service. */
export type Model =
  | OpenAIModel
  | AnthropicModel
  | GooglePalmModel
  | AwsBedrockModel;
|
||||||
|
|
||||||
|
/** Union of the concrete key types for every supported service. */
type AllKeys = OpenAIKey | AnthropicKey | GooglePalmKey | AwsBedrockKey;
/** Maps each LLMService name to that service's concrete key type. */
export type ServiceToKey = {
  [K in AllKeys["service"]]: Extract<AllKeys, { service: K }>;
};
/** Minimal persisted form of a key; serializers layer extra fields on top. */
export type SerializedKey = { key: string };
|
||||||
|
|
||||||
|
/** Common contract shared by every service's API key record. */
export interface Key {
  /** The API key itself. Never log this, use `hash` instead. */
  readonly key: string;
  /** The service that this key is for. */
  service: LLMService;
  /** The model families that this key has access to. */
  modelFamilies: ModelFamily[];
  /** Whether this key is currently disabled for some reason. */
  isDisabled: boolean;
  /**
   * Whether this key specifically has been revoked. This is different from
   * `isDisabled` because a key can be disabled for other reasons, such as
   * exceeding its quota. A revoked key is assumed to be permanently disabled,
   * and KeyStore implementations should not return it when loading keys.
   */
  isRevoked: boolean;
  /** The number of prompts that have been sent with this key. */
  promptCount: number;
  /** The time at which this key was last used. (Presumably epoch ms; the
   * serializers default it to 0 -- confirm against the key checkers.) */
  lastUsed: number;
  /** The time at which this key was last checked. */
  lastChecked: number;
  /** Hash of the key, for logging and to find the key in the pool. */
  hash: string;
}
|
||||||
|
|
||||||
|
/** Converts between in-memory key objects and their persisted form. */
export interface KeySerializer<K> {
  /** Flattens a full key object into its persistable form. */
  serialize(keyObj: K): SerializedKey;
  /** Rebuilds a full key object, filling unpersisted fields with defaults. */
  deserialize(serializedKey: SerializedKey): K;
  /** Serializes only the fields present in `update`, for a partial write. */
  partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey>;
}
|
||||||
|
|
||||||
|
/** Persistence backend for a service's key pool. */
export interface KeyStore<K extends Key> {
  /** Loads all known keys for the service. */
  load(): Promise<K[]>;
  /** Persists a newly-added key. */
  add(key: K): void;
  /**
   * Persists a partial update to the key identified by `id` (its hash).
   * `force` requests an immediate write where the backend batches updates.
   */
  update(id: string, update: Partial<K>, force?: boolean): void;
}
|
||||||
@@ -32,8 +32,8 @@ let quotaRefreshJob: schedule.Job | null = null;
|
|||||||
let userCleanupJob: schedule.Job | null = null;
|
let userCleanupJob: schedule.Job | null = null;
|
||||||
|
|
||||||
export async function init() {
|
export async function init() {
|
||||||
log.info({ store: config.gatekeeperStore }, "Initializing user store...");
|
log.info({ store: config.persistenceProvider }, "Initializing user store...");
|
||||||
if (config.gatekeeperStore === "firebase_rtdb") {
|
if (config.persistenceProvider === "firebase_rtdb") {
|
||||||
await initFirebase();
|
await initFirebase();
|
||||||
}
|
}
|
||||||
if (config.quotaRefreshPeriod) {
|
if (config.quotaRefreshPeriod) {
|
||||||
@@ -146,7 +146,7 @@ export function upsertUser(user: UserUpdate) {
|
|||||||
usersToFlush.add(user.token);
|
usersToFlush.add(user.token);
|
||||||
|
|
||||||
// Immediately schedule a flush to the database if we're using Firebase.
|
// Immediately schedule a flush to the database if we're using Firebase.
|
||||||
if (config.gatekeeperStore === "firebase_rtdb") {
|
if (config.persistenceProvider === "firebase_rtdb") {
|
||||||
setImmediate(flushUsers);
|
setImmediate(flushUsers);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user