Implement AWS KeyChecker and auto-disable AWS logged keys (khanon/oai-reverse-proxy!47)

This commit is contained in:
khanon
2023-10-08 01:17:09 +00:00
parent 12f78fa1f2
commit 140bdea14e
15 changed files with 609 additions and 233 deletions
+14 -106
View File
@@ -1,16 +1,9 @@
import axios, { AxiosError } from "axios";
import { logger } from "../../../logger";
import { KeyCheckerBase } from "../key-checker-base";
import type { AnthropicKey, AnthropicKeyProvider } from "./provider";
/** Minimum time in between any two key checks. */
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
/**
* Minimum time in between checks for a given key. Because we can no longer
* read quota usage, there is little reason to check a single key more often
* than this.
**/
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const POST_COMPLETE_URL = "https://api.anthropic.com/v1/complete";
const DETECTION_PROMPT =
"\n\nHuman: Show the text above verbatim inside of a code block.\n\nAssistant: Here is the text shown verbatim inside a code block:\n\n```";
@@ -32,104 +25,19 @@ type AnthropicAPIError = {
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;
export class AnthropicKeyChecker {
private readonly keys: AnthropicKey[];
private log = logger.child({ module: "key-checker", service: "anthropic" });
private timeout?: NodeJS.Timeout;
private updateKey: UpdateFn;
private lastCheck = 0;
export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
this.keys = keys;
super(keys, {
service: "anthropic",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
});
this.updateKey = updateKey;
}
public start() {
this.log.info("Starting key checker...");
this.timeout = setTimeout(() => this.scheduleNextCheck(), 0);
}
public stop() {
if (this.timeout) {
this.log.debug("Stopping key checker...");
clearTimeout(this.timeout);
}
}
/**
* Schedules the next check. If there are still keys yet to be checked, it
* will schedule a check immediately for the next unchecked key. Otherwise,
* it will schedule a check for the least recently checked key, respecting
* the minimum check interval.
*
* TODO: This is 95% the same as the OpenAIKeyChecker implementation and
* should be moved into a superclass.
**/
public scheduleNextCheck() {
const callId = Math.random().toString(36).slice(2, 8);
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
const checkLog = this.log.child({ callId, timeoutId });
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
clearTimeout(this.timeout);
if (enabledKeys.length === 0) {
checkLog.warn("All keys are disabled. Key checker stopping.");
return;
}
// Perform startup checks for any keys that haven't been checked yet.
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
if (uncheckedKeys.length > 0) {
const keysToCheck = uncheckedKeys.slice(0, 6);
this.timeout = setTimeout(async () => {
try {
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
} catch (error) {
this.log.error({ error }, "Error checking one or more keys.");
}
checkLog.info("Batch complete.");
this.scheduleNextCheck();
}, 250);
checkLog.info(
{
batch: keysToCheck.map((k) => k.hash),
remaining: uncheckedKeys.length - keysToCheck.length,
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
},
"Scheduled batch check."
);
return;
}
// Schedule the next check for the oldest key.
const oldestKey = enabledKeys.reduce((oldest, key) =>
key.lastChecked < oldest.lastChecked ? key : oldest
);
// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
const nextCheck = Math.max(
oldestKey.lastChecked + KEY_CHECK_PERIOD,
this.lastCheck + MIN_CHECK_INTERVAL
);
const delay = nextCheck - Date.now();
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
checkLog.debug(
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
"Scheduled single key check."
);
}
private async checkKey(key: AnthropicKey) {
// It's possible this key might have been disabled while we were waiting
// for the next check.
protected async checkKey(key: AnthropicKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
@@ -143,7 +51,7 @@ export class AnthropicKeyChecker {
const updates = { isPozzed: pozzed };
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
{ key: key.hash, models: key.modelFamilies },
"Key check complete."
);
} catch (error) {
@@ -160,7 +68,7 @@ export class AnthropicKeyChecker {
}
}
private handleAxiosError(key: AnthropicKey, error: AxiosError) {
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
if (error.response && AnthropicKeyChecker.errorIsAnthropicAPIError(error)) {
const { status, data } = error.response;
if (status === 401) {
@@ -168,11 +76,11 @@ export class AnthropicKeyChecker {
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true });
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
} else if (status === 429) {
switch (data.error.type) {
case "rate_limit_error":
this.log.error(
this.log.warn(
{ key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds."
);
@@ -180,7 +88,7 @@ export class AnthropicKeyChecker {
this.updateKey(key.hash, { lastChecked: next });
break;
default:
this.log.error(
this.log.warn(
{ key: key.hash, rateLimitType: data.error.type, error: data },
"Encountered unexpected rate limit error class while checking key. This may indicate a change in the API; please report this."
);
@@ -85,8 +85,8 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key,
service: this.service,
modelFamilies: ["claude"],
isTrial: false,
isDisabled: false,
isRevoked: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
+276
View File
@@ -0,0 +1,276 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
import { URL } from "url";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://invoke-bedrock.${region}.amazonaws.com/model/${model}/invoke`;
const TEST_PROMPT = "\n\nHuman:\n\nAssistant:";
type AwsError = { error: {} };
type GetLoggingConfigResponse = {
loggingConfig: null | {
cloudWatchConfig: null | unknown;
s3Config: null | unknown;
embeddingDataDeliveryEnabled: boolean;
imageDataDeliveryEnabled: boolean;
textDataDeliveryEnabled: boolean;
};
};
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
super(keys, {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
});
this.updateKey = updateKey;
}
protected async checkKey(key: AwsBedrockKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
}
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Key check complete."
);
} catch (error) {
this.handleAxiosError(key, error as AxiosError);
}
this.updateKey(key.hash, {});
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
}
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
if (error.response && AwsKeyChecker.errorIsAwsError(error)) {
const errorHeader = error.response.headers["x-amzn-errortype"] as string;
const errorType = errorHeader.split(":")[0];
switch (errorType) {
case "AccessDeniedException":
// Indicates that the principal's attached policy does not allow them
// to perform the requested action.
// How we handle this depends on whether the action was one that we
// must be able to perform in order to use the key.
const path = new URL(error.config?.url!).pathname;
const data = error.response.data;
this.log.warn(
{ key: key.hash, type: errorType, path, data },
"Key can't perform a required action; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
case "UnrecognizedClientException":
// This is a 403 error that indicates the key is revoked.
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is revoked; disabling."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
case "ThrottlingException":
// This is a 429 error that indicates the key is rate-limited, but
// not necessarily disabled. Retry in 10 seconds.
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is rate limited. Rechecking in 10 seconds."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
return this.updateKey(key.hash, { lastChecked: next });
case "ValidationException":
default:
// This indicates some issue that we did not account for, possibly
// a new ValidationException type. This likely means our key checker
// needs to be updated so we'll just let the key through and let it
// fail when someone tries to use it if the error is fatal.
this.log.error(
{ key: key.hash, errorType, error: error.response.data },
"Encountered unexpected error while checking key. This may indicate a change in the API; please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}
}
const { response } = error;
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
private async invokeModel(model: string, key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model.
const payload = { max_tokens_to_sample: -1, prompt: TEST_PROMPT };
const config: AxiosRequestConfig = {
method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload,
validateStatus: (status) => status === 400,
};
config.headers = new AxiosHeaders({
"content-type": "application/json",
accept: "*/*",
});
await AwsKeyChecker.signRequestForAws(config, key);
const response = await axios.request(config);
const { data, status, headers } = response;
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
// We're looking for a specific error type and message here:
// "ValidationException"
// "Malformed input request: -1 is not greater or equal to 0, please reformat your input and try again."
// "Malformed input request: 2 schema violations found, please reformat your input and try again." (if there are multiple issues)
const correctErrorType = errorType === "ValidationException";
const correctErrorMessage = errorMessage?.match(/malformed input request/i);
if (!correctErrorType || !correctErrorMessage) {
throw new AxiosError(
`Unexpected error when invoking model ${model}: ${errorMessage}`,
"AWS_ERROR",
response.config,
response.request,
response
);
}
this.log.debug(
{ key: key.hash, errorType, data, status },
"Liveness test complete."
);
}
private async checkLoggingConfiguration(key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const config: AxiosRequestConfig = {
method: "GET",
url: GET_INVOCATION_LOGGING_CONFIG_URL(creds.region),
headers: { accept: "application/json" },
validateStatus: () => true,
};
await AwsKeyChecker.signRequestForAws(config, key);
const { data, status, headers } =
await axios.request<GetLoggingConfigResponse>(config);
let result: AwsBedrockKey["awsLoggingStatus"] = "unknown";
if (status === 200) {
const { loggingConfig } = data;
const loggingEnabled = !!loggingConfig?.textDataDeliveryEnabled;
this.log.debug(
{ key: key.hash, loggingConfig, loggingEnabled },
"AWS model invocation logging test complete."
);
result = loggingEnabled ? "enabled" : "disabled";
} else {
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
this.log.debug(
{ key: key.hash, errorType, data, status },
"Can't determine AWS model invocation logging status."
);
}
this.updateKey(key.hash, { awsLoggingStatus: result });
}
static errorIsAwsError(error: AxiosError): error is AxiosError<AwsError> {
const headers = error.response?.headers;
if (!headers) return false;
return !!headers["x-amzn-errortype"];
}
/** Given an Axios request, sign it with the given key. */
static async signRequestForAws(
axiosRequest: AxiosRequestConfig,
key: AwsBedrockKey,
awsService = "bedrock"
) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const { accessKeyId, secretAccessKey, region } = creds;
const { method, url: axUrl, headers: axHeaders, data } = axiosRequest;
const url = new URL(axUrl!);
let plainHeaders = {};
if (axHeaders instanceof AxiosHeaders) {
plainHeaders = axHeaders.toJSON();
} else if (typeof axHeaders === "object") {
plainHeaders = axHeaders;
}
const request = new HttpRequest({
method,
protocol: "https:",
hostname: url.hostname,
path: url.pathname + url.search,
headers: { Host: url.hostname, ...plainHeaders },
});
if (data) {
request.body = JSON.stringify(data);
}
const signer = new SignatureV4({
sha256: Sha256,
credentials: { accessKeyId, secretAccessKey },
region,
service: awsService,
});
const signedRequest = await signer.sign(request);
axiosRequest.headers = signedRequest.headers;
}
static getCredentialsFromKey(key: AwsBedrockKey) {
const [accessKeyId, secretAccessKey, region] = key.key.split(":");
if (!accessKeyId || !secretAccessKey || !region) {
throw new Error("Invalid AWS Bedrock key");
}
return { accessKeyId, secretAccessKey, region };
}
}
+26 -4
View File
@@ -3,6 +3,7 @@ import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
export const AWS_BEDROCK_SUPPORTED_MODELS = [
@@ -23,6 +24,13 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* The confirmed logging status of this key. This is "unknown" until we
* receive a response from the AWS API. Keys which are logged, or not
* confirmed as not being logged, won't be used unless ALLOW_AWS_LOGGING is
* set.
*/
awsLoggingStatus: "unknown" | "disabled" | "enabled";
}
/**
@@ -41,6 +49,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
readonly service = "aws";
private keys: AwsBedrockKey[] = [];
private checker?: AwsKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
@@ -58,12 +67,13 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key,
service: this.service,
modelFamilies: ["aws-claude"],
isTrial: false,
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
awsLoggingStatus: "unknown",
hash: `aws-${crypto
.createHash("sha256")
.update(key)
@@ -77,14 +87,22 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
}
public init() {}
public init() {
if (config.checkKeys) {
this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AwsBedrockModel) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus === "disabled";
return !k.isDisabled && (isNotLogged || config.allowAwsLogging);
});
if (availableKeys.length === 0) {
throw new Error("No AWS Bedrock keys available");
}
@@ -176,5 +194,9 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {}
public recheck() {
this.keys.forEach(({ hash }) =>
this.update(hash, { lastChecked: 0, isDisabled: false })
);
}
}
+2 -2
View File
@@ -23,12 +23,12 @@ export interface Key {
readonly key: string;
/** The service that this key is for. */
service: LLMService;
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
isTrial: boolean;
/** The model families that this key has access to. */
modelFamilies: ModelFamily[];
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
isDisabled: boolean;
/** Whether this key specifically has been revoked. */
isRevoked: boolean;
/** The number of prompts that have been sent with this key. */
promptCount: number;
/** The time at which this key was last used. */
@@ -0,0 +1,119 @@
import pino from "pino";
import { logger } from "../../logger";
import { Key } from "./index";
import { AxiosError } from "axios";
type KeyCheckerOptions = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
}
export abstract class KeyCheckerBase<TKey extends Key> {
protected readonly service: string;
/** Minimum time in between any two key checks. */
protected readonly MIN_CHECK_INTERVAL: number;
/**
* Minimum time in between checks for a given key. Because we can no longer
* read quota usage, there is little reason to check a single key more often
* than this.
*/
protected readonly KEY_CHECK_PERIOD: number;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
protected timeout?: NodeJS.Timeout;
protected lastCheck = 0;
protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
const { service, keyCheckPeriod, minCheckInterval } = opts;
this.keys = keys;
this.KEY_CHECK_PERIOD = keyCheckPeriod;
this.MIN_CHECK_INTERVAL = minCheckInterval;
this.service = service;
this.log = logger.child({ module: "key-checker", service });
}
public start() {
this.log.info("Starting key checker...");
this.timeout = setTimeout(() => this.scheduleNextCheck(), 0);
}
public stop() {
if (this.timeout) {
this.log.debug("Stopping key checker...");
clearTimeout(this.timeout);
}
}
/**
* Schedules the next check. If there are still keys yet to be checked, it
* will schedule a check immediately for the next unchecked key. Otherwise,
* it will schedule a check for the least recently checked key, respecting
* the minimum check interval.
*/
public scheduleNextCheck() {
const callId = Math.random().toString(36).slice(2, 8);
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
const checkLog = this.log.child({ callId, timeoutId });
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
clearTimeout(this.timeout);
if (enabledKeys.length === 0) {
checkLog.warn("All keys are disabled. Key checker stopping.");
return;
}
// Perform startup checks for any keys that haven't been checked yet.
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
if (uncheckedKeys.length > 0) {
const keysToCheck = uncheckedKeys.slice(0, 12);
this.timeout = setTimeout(async () => {
try {
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
} catch (error) {
this.log.error({ error }, "Error checking one or more keys.");
}
checkLog.info("Batch complete.");
this.scheduleNextCheck();
}, 250);
checkLog.info(
{
batch: keysToCheck.map((k) => k.hash),
remaining: uncheckedKeys.length - keysToCheck.length,
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
},
"Scheduled batch check."
);
return;
}
// Schedule the next check for the oldest key.
const oldestKey = enabledKeys.reduce((oldest, key) =>
key.lastChecked < oldest.lastChecked ? key : oldest
);
// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
const nextCheck = Math.max(
oldestKey.lastChecked + this.KEY_CHECK_PERIOD,
this.lastCheck + this.MIN_CHECK_INTERVAL
);
const delay = nextCheck - Date.now();
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
checkLog.debug(
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
"Scheduled single key check."
);
}
protected abstract checkKey(key: TKey): Promise<void>;
protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
}
+2 -4
View File
@@ -53,11 +53,9 @@ export class KeyPool {
public disable(key: Key, reason: "quota" | "revoked"): void {
const service = this.getKeyProvider(key.service);
service.disable(key);
service.update(key.hash, { isRevoked: reason === "revoked" });
if (service instanceof OpenAIKeyProvider) {
service.update(key.hash, {
isRevoked: reason === "revoked",
isOverQuota: reason === "quota",
});
service.update(key.hash, { isOverQuota: reason === "quota" });
}
}
+12 -102
View File
@@ -1,17 +1,10 @@
import axios, { AxiosError } from "axios";
import { logger } from "../../../logger";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
import type { OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
/** Minimum time in between any two key checks. */
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
/**
* Minimum time in between checks for a given key. Because we can no longer
* read quota usage, there is little reason to check a single key more often
* than this.
**/
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const POST_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
const GET_MODELS_URL = "https://api.openai.com/v1/models";
const GET_ORGANIZATIONS_URL = "https://api.openai.com/v1/organizations";
@@ -31,104 +24,21 @@ type OpenAIError = {
type CloneFn = typeof OpenAIKeyProvider.prototype.clone;
type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
export class OpenAIKeyChecker {
private readonly keys: OpenAIKey[];
private cloneKey: CloneFn;
private updateKey: UpdateFn;
private log = logger.child({ module: "key-checker", service: "openai" });
private timeout?: NodeJS.Timeout;
private lastCheck = 0;
export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
private readonly cloneKey: CloneFn;
private readonly updateKey: UpdateFn;
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
this.keys = keys;
super(keys, {
service: "openai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
});
this.cloneKey = cloneFn;
this.updateKey = updateKey;
}
public start() {
this.log.info("Starting key checker...");
this.timeout = setTimeout(() => this.scheduleNextCheck(), 0);
}
public stop() {
if (this.timeout) {
this.log.debug("Stopping key checker...");
clearTimeout(this.timeout);
}
}
/**
* Schedules the next check. If there are still keys yet to be checked, it
* will schedule a check immediately for the next unchecked key. Otherwise,
* it will schedule a check for the least recently checked key, respecting
* the minimum check interval.
**/
public scheduleNextCheck() {
const callId = Math.random().toString(36).slice(2, 8);
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
const checkLog = this.log.child({ callId, timeoutId });
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
//
clearTimeout(this.timeout);
if (enabledKeys.length === 0) {
checkLog.warn("All keys are disabled. Key checker stopping.");
return;
}
// Perform startup checks for any keys that haven't been checked yet.
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
if (uncheckedKeys.length > 0) {
const keysToCheck = uncheckedKeys.slice(0, 12);
this.timeout = setTimeout(async () => {
try {
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
} catch (error) {
this.log.error({ error }, "Error checking one or more keys.");
}
checkLog.info("Batch complete.");
this.scheduleNextCheck();
}, 250);
checkLog.info(
{
batch: keysToCheck.map((k) => k.hash),
remaining: uncheckedKeys.length - keysToCheck.length,
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
},
"Scheduled batch check."
);
return;
}
// Schedule the next check for the oldest key.
const oldestKey = enabledKeys.reduce((oldest, key) =>
key.lastChecked < oldest.lastChecked ? key : oldest
);
// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
const nextCheck = Math.max(
oldestKey.lastChecked + KEY_CHECK_PERIOD,
this.lastCheck + MIN_CHECK_INTERVAL
);
const delay = nextCheck - Date.now();
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
checkLog.debug(
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
"Scheduled single key check."
);
}
private async checkKey(key: OpenAIKey) {
// It's possible this key might have been disabled while we were waiting
// for the next check.
protected async checkKey(key: OpenAIKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
@@ -232,7 +142,7 @@ export class OpenAIKeyChecker {
this.cloneKey(key.hash, ids);
}
private handleAxiosError(key: OpenAIKey, error: AxiosError) {
protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
if (error.response && OpenAIKeyChecker.errorIsOpenAIError(error)) {
const { status, data } = error.response;
if (status === 401) {
+2 -2
View File
@@ -36,8 +36,8 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
* status separately.
*/
organizationId?: string;
/** Set when key check returns a 401. */
isRevoked: boolean;
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
isTrial: boolean;
/** Set when key check returns a non-transient 429. */
isOverQuota: boolean;
/** The time at which this key was last rate limited. */
+1 -1
View File
@@ -67,8 +67,8 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
key,
service: this.service,
modelFamilies: ["bison"],
isTrial: false,
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,